import os
import subprocess
import time
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from concrete.ml.deployment import FHEModelClient
from concrete.ml.sklearn import RandomForestClassifier as ConcreteRandomForestClassifier

from utils import (
    CLIENT_DIR,
    CURRENT_DIR,
    DEPLOYMENT_DIR,
    INPUT_BROWSER_LIMIT,
    KEYS_DIR,
    SERVER_URL,
    clean_directory,
)

# Start the FastAPI server that hosts the FHE model
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
time.sleep(3)

current_dir = os.path.dirname(os.path.realpath(__file__))
data = pd.read_csv(os.path.join(current_dir, "files/titanic.csv"))


def is_none(obj) -> bool:
    """
    Check if the object is None or empty.

    Args:
        obj (any): The input to be checked.

    Returns:
        bool: True if the object is None or empty, False otherwise.
    """
    return obj is None or (obj is not None and len(obj) < 1)


def encode_age(df):
    # Bucketize the age into ordinal categories, mapping missing values to their own bin
    df.Age = df.Age.fillna(-0.5)
    bins = (-1, 0, 5, 12, 18, 25, 35, 60, 120)
    categories = pd.cut(df.Age, bins, labels=False)
    df.Age = categories
    return df


def encode_fare(df):
    # Bucketize the fare into ordinal categories, mapping missing values to their own bin
    df.Fare = df.Fare.fillna(-0.5)
    bins = (-1, 0, 8, 15, 31, 1000)
    categories = pd.cut(df.Fare, bins, labels=False)
    df.Fare = categories
    return df


def encode_df(df):
    # Encode the raw Titanic columns into the small integer features used by the models
    df = encode_age(df)
    df = encode_fare(df)
    sex_mapping = {"male": 0, "female": 1}
    df = df.replace({"Sex": sex_mapping})
    embark_mapping = {"S": 1, "C": 2, "Q": 3}
    df = df.replace({"Embarked": embark_mapping})
    df.Embarked = df.Embarked.fillna(0)
    df["Company"] = 0
    df.loc[(df["SibSp"] > 0), "Company"] = 1
    df.loc[(df["Parch"] > 0), "Company"] = 2
    df.loc[(df["SibSp"] > 0) & (df["Parch"] > 0), "Company"] = 3
    df = df[
        [
            "PassengerId",
            "Pclass",
            "Sex",
            "Age",
            "Fare",
            "Embarked",
            "Company",
            "Survived",
        ]
    ]
    return df


train = encode_df(data)

X_all = train.drop(["Survived", "PassengerId"], axis=1)
y_all = train["Survived"]

num_test = 0.20
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=num_test, random_state=23
)

# Train a clear (non-encrypted) model as a reference
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

# Train the Concrete ML model used for the FHE inference
n_estimators = 50
max_depth = 4
n_bits = 6
n_jobs_xgb = 1
n_jobs_gridsearch = -1

concrete_clf = ConcreteRandomForestClassifier(
    n_bits=n_bits, n_estimators=n_estimators, max_depth=max_depth, n_jobs=n_jobs_xgb
)
concrete_clf.fit(X_train, y_train)
concrete_predictions = concrete_clf.predict(X_test)


def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
    if passenger_class is None or embark_point is None:
        return None
    df = pd.DataFrame.from_dict(
        {
            "Pclass": [passenger_class + 1],
            "Sex": [0 if is_male else 1],
            "Age": [age],
            "Fare": [fare],
            "Embarked": [embark_point + 1],
            "Company": [
                (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
            ],
        }
    )
    df = encode_age(df)
    df = encode_fare(df)
    pred = clf.predict_proba(df)[0]
    return {"Perishes": float(pred[0]), "Survives": float(pred[1])}


def collect_input(passenger_class, is_male, age, company, fare, embark_point):
    if passenger_class is None or embark_point is None:
        return None
    input_dict = {
        "Pclass": [passenger_class + 1],
        "Sex": [0 if is_male else 1],
        "Age": [age],
        "Fare": [fare],
        "Embarked": [embark_point + 1],
        "Company": [
            (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
        ],
    }
    # print(input_dict)
    return input_dict


def clear_predict_survival(input_dict):
    df = pd.DataFrame.from_dict(input_dict)
    df = encode_age(df)
    df = encode_fare(df)
    pred = clf.predict_proba(df)[0]
    return {"Perishes": float(pred[0]), "Survives": float(pred[1])}


def concrete_predict_survival(input_dict):
    df = pd.DataFrame.from_dict(input_dict)
    df = encode_age(df)
    df = encode_fare(df)
    pred = concrete_clf.predict_proba(df)[0]
    return {"Perishes": float(pred[0]), "Survives": float(pred[1])}


# print("\nclear_test ", clear_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}))
# print("encrypted_test", concrete_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}),"\n")


def key_gen_fn() -> Dict:
    """
    Generate the private and evaluation keys for a given user.

    Returns:
        dict: A dictionary containing the generated keys and related information.
    """
    clean_directory()

    # Generate a random user ID
    user_id = np.random.randint(0, 2**32)
    print(f"Your user ID is: {user_id}....")

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Creates the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key
    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    with evaluation_key_path.open("wb") as f:
        f.write(serialized_evaluation_keys)

    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]

    return {
        error_box2: gr.update(visible=False),
        key_box: gr.update(visible=True, value=serialized_evaluation_keys_shorten_hex),
        user_id_box: gr.update(visible=True, value=user_id),
        key_len_box: gr.update(
            visible=False, value=f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
        ),
    }


def encrypt_fn(user_inputs: np.ndarray, user_id: str) -> Dict:
    """Encrypt the user inputs on the client side.

    Args:
        user_inputs (np.ndarray): The inputs submitted by the user.
        user_id (str): The current user's ID.

    Returns:
        dict: Updates for the Gradio components showing the clear and encrypted inputs.
    """
    if is_none(user_id) or is_none(user_inputs):
        print("Error in encryption step: Provide your inputs and generate the evaluation keys.")
        return {
            error_box3: gr.update(
                visible=True,
                value="⚠️ Please ensure that your inputs have been submitted and "
                "that you have generated the evaluation key.",
            )
        }

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Apply the same pre-processing as for the training data
    user_inputs_df = pd.DataFrame.from_dict(user_inputs)
    user_inputs_df = encode_age(user_inputs_df)
    user_inputs_df = encode_fare(user_inputs_df)

    print("user_inputs to be encrypted =\n", user_inputs_df)
    print("user_inputs to be encrypted =\n", user_inputs_df.to_numpy())

    # Quantize, encrypt and serialize the inputs
    encrypted_quantized_user_inputs = client.quantize_encrypt_serialize(user_inputs_df.to_numpy())
    assert isinstance(encrypted_quantized_user_inputs, bytes)

    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"
    with encrypted_input_path.open("wb") as f:
        f.write(encrypted_quantized_user_inputs)

    encrypted_quantized_user_inputs_shorten_hex = encrypted_quantized_user_inputs.hex()[
        :INPUT_BROWSER_LIMIT
    ]

    return {
        error_box3: gr.update(visible=False),
        input_dict_box: gr.update(visible=False, value=user_inputs),
        enc_dict_box: gr.update(visible=True, value=encrypted_quantized_user_inputs_shorten_hex),
    }


def send_input_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
    """Send the encrypted data and the evaluation key to the server.

    Args:
        user_id (str): The current user's ID.
        user_inputs (np.ndarray): The inputs submitted by the user.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the inputs have been submitted and the evaluation "
                "key has been generated before sending the data to the server.",
            )
        }

    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"

    if not evaluation_key_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: "
            f"The key has not been generated correctly - {evaluation_key_path.is_file()=}"
        )
        return {
            error_box4: gr.update(visible=True, value="⚠️ Please generate the private key first.")
        }

    if not encrypted_input_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: The data has not been encrypted "
            f"correctly on the client side - {encrypted_input_path.is_file()=}"
        )
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please encrypt the data with the private key first.",
            ),
        }

    # Define the data and files to post
    data = {
        "user_id": user_id,
        "input": user_inputs,
    }

    files = [
        ("files", open(encrypted_input_path, "rb")),
        ("files", open(evaluation_key_path, "rb")),
    ]

    # Send the encrypted input and evaluation key to the server
    url = SERVER_URL + "send_input"
    with requests.post(
        url=url,
        data=data,
        files=files,
    ) as response:
        print(f"Sending Data: {response.ok=}")

    return {
        error_box4: gr.update(visible=False),
        srv_resp_send_data_box: "Data sent",
    }


def run_fhe_fn(user_id: str) -> Dict:
    """Run the FHE execution over the encrypted input on the server side.

    Args:
        user_id (str): The current user's ID.
    """
    if is_none(user_id):
        return {
            error_box5: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the inputs have been submitted, the evaluation "
                "key has been generated and the server received the data "
                "before processing the data.",
            ),
            fhe_execution_time_box: None,
        }

    data = {
        "user_id": user_id,
    }

    url = SERVER_URL + "run_fhe"
    with requests.post(
        url=url,
        data=data,
    ) as response:
        if not response.ok:
            return {
                error_box5: gr.update(
                    visible=True,
                    value=(
                        "⚠️ An error occurred on the Server Side. "
                        "Please check connectivity and data transmission."
                    ),
                ),
                fhe_execution_time_box: gr.update(visible=False),
            }
        else:
            time.sleep(10)
            print(f"response.ok: {response.ok}, {response.json()} - Computed")

    return {
        error_box5: gr.update(visible=False),
        fhe_execution_time_box: gr.update(visible=True, value=f"{response.json():.2f} seconds"),
    }


def get_output_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
    """Retrieve the encrypted data from the server.

    Args:
        user_id (str): The current user's ID.
        user_inputs (np.ndarray): The inputs submitted by the user.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box6: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the server has successfully processed and transmitted the data to the client.",
            )
        }

    data = {
        "user_id": user_id,
    }

    # Retrieve the encrypted output
    url = SERVER_URL + "get_output"
    with requests.post(
        url=url,
        data=data,
    ) as response:
        if response.ok:
            print(f"Receive Data: {response.ok=}")
            encrypted_output = response.content

            # Save the encrypted output to bytes in a file as it is too large to pass through
            # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
            encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
            with encrypted_output_path.open("wb") as f:
                f.write(encrypted_output)

    return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}
def decrypt_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
    """Decrypt the data on the Client Side.

    Args:
        user_id (str): The current user's ID.
        user_inputs (np.ndarray): The user inputs.

    Returns:
        Decrypted output.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the client has successfully received the data from the server.",
            )
        }

    # Get the encrypted output path
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"

    if not encrypted_output_path.is_file():
        print("Error in decryption step: Please run the FHE execution, first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please ensure that: \n"
                "- the connectivity \n"
                "- the inputs have been submitted \n"
                "- the evaluation key has been generated \n"
                "- the server processed the encrypted data \n"
                "- the Client received the data from the Server before decrypting the prediction",
            ),
            decrypt_box: None,
        }

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as f:
        encrypted_output = f.read()

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Deserialize, decrypt and post-process the encrypted output
    output = client.deserialize_decrypt_dequantize(encrypted_output)
    print("output =\n", output)

    out = {"Perishes": float(output[0][0]), "Survives": float(output[0][1])}
    print("output =\n", out)

    return {
        error_box7: gr.update(visible=False),
        decrypt_box: out,
        label: out,
    }


with gr.Blocks() as demo:

    # Step 1.1: Provide inputs
    gr.Markdown("### Titanic Survival Prediction with ML and Private Computation")

    with gr.Row():
        inp = [
            gr.Dropdown(
                ["first", "second", "third"], type="index", label="Select Passenger Class"
            ),
            gr.Checkbox(label="Male?"),
            gr.Slider(0, 80, value=25, label="Age", step=1),
            gr.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
            gr.Number(value=20, label="Fare"),
            gr.Radio(
                ["Southampton", "Cherbourg", "Queenstown"], type="index", label="Embark point:"
            ),
        ]

    out = gr.JSON()
    btn = gr.Button("Confirm inputs")
    btn.click(fn=collect_input, inputs=inp, outputs=out)

    # Step 2.1: Key generation
    gen_key_btn = gr.Button("Generate the evaluation key")
    error_box2 = gr.Textbox(label="Error ❌", visible=False)
    user_id_box = gr.Textbox(label="User ID:", visible=True)
    key_len_box = gr.Textbox(label="Evaluation Key Size:", visible=False)
    key_box = gr.Textbox(label="Evaluation key (truncated):", max_lines=3, visible=False)

    gen_key_btn.click(
        key_gen_fn,
        inputs=None,
        outputs=[
            key_box,
            user_id_box,
            key_len_box,
            error_box2,
        ],
    )

    # Step 2.2: Encrypt data locally
    gr.Markdown("### Encrypt the data")
    encrypt_btn = gr.Button("Encrypt the data using the private secret key")
    error_box3 = gr.Textbox(label="Error ❌", visible=False)

    with gr.Row():
        with gr.Column():
            input_dict_box = gr.Textbox(label="input_dict_box:", max_lines=10)
        with gr.Column():
            enc_dict_box = gr.Textbox(label="encrypted input_dict_box:", max_lines=10)

    encrypt_btn.click(
        encrypt_fn,
        inputs=[out, user_id_box],
        outputs=[
            input_dict_box,
            enc_dict_box,
            error_box3,
        ],
    )

    # Step 2.3: Send encrypted data to the server
    gr.Markdown("### Send the encrypted data to the Server Side")
    error_box4 = gr.Textbox(label="Error ❌", visible=False)

    with gr.Row():
        with gr.Column(scale=4):
            send_input_btn = gr.Button("Send data")
        with gr.Column(scale=1):
            srv_resp_send_data_box = gr.Checkbox(label="Data Sent", show_label=False)

    send_input_btn.click(
        send_input_fn,
        inputs=[user_id_box, out],
        outputs=[error_box4, srv_resp_send_data_box],
    )
    # ------------------------- Step 3 -------------------------
    gr.Markdown("\n")
    gr.Markdown("## Step 3: Run the FHE evaluation")
    gr.Markdown("