"""Gradio demo: Titanic survival prediction in the clear and under FHE.

On import this module starts the FHE server (``uvicorn server:app``), trains a
scikit-learn random forest and its Concrete ML counterpart on the Titanic CSV,
and then builds and launches a Gradio interface that walks through key
generation, encryption, server-side evaluation and decryption.
"""

import os
import subprocess
import time
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
from concrete.ml.deployment import FHEModelClient
from concrete.ml.sklearn import RandomForestClassifier as ConcreteRandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from utils import (
    CLIENT_DIR,
    CURRENT_DIR,
    DEPLOYMENT_DIR,
    INPUT_BROWSER_LIMIT,
    KEYS_DIR,
    SERVER_URL,
    clean_directory,
)

# Start the server in the background and give it a moment to come up before
# any request is sent to it.
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
time.sleep(3)

current_dir = os.path.dirname(os.path.realpath(__file__))
data = pd.read_csv(os.path.join(current_dir, "files/titanic.csv"))


def is_none(obj) -> bool:
    """
    Check if the object is None.

    Args:
        obj (any): The input to be checked. Must support ``len()`` when it is
            not None.

    Returns:
        bool: True if the object is None or empty, False otherwise.
    """
    return obj is None or len(obj) < 1


def encode_age(df):
    """Return a copy of ``df`` whose ``Age`` column is replaced by a bin index.

    Missing ages are filled with -0.5 first so that they land in the first
    bin ``(-1, 0]``. The input frame is left untouched.
    """
    df = df.copy()
    bins = (-1, 0, 5, 12, 18, 25, 35, 60, 120)
    df["Age"] = pd.cut(df["Age"].fillna(-0.5), bins, labels=False)
    return df


def encode_fare(df):
    """Return a copy of ``df`` whose ``Fare`` column is replaced by a bin index.

    Missing fares are filled with -0.5 first so that they land in the first
    bin ``(-1, 0]``. The input frame is left untouched.
    """
    df = df.copy()
    bins = (-1, 0, 8, 15, 31, 1000)
    df["Fare"] = pd.cut(df["Fare"].fillna(-0.5), bins, labels=False)
    return df


def encode_df(df):
    """Encode the raw Titanic frame into the feature table used for training.

    Bins ``Age`` and ``Fare``, maps ``Sex`` and ``Embarked`` to integers
    (missing ``Embarked`` becomes 0), derives ``Company`` from ``SibSp`` and
    ``Parch`` (0 = alone, 1 = siblings/spouse only, 2 = parents/children only,
    3 = both) and keeps only the columns listed at the end.
    """
    df = encode_age(df)
    df = encode_fare(df)

    sex_mapping = {"male": 0, "female": 1}
    df = df.replace({"Sex": sex_mapping})

    embark_mapping = {"S": 1, "C": 2, "Q": 3}
    df = df.replace({"Embarked": embark_mapping})
    df.Embarked = df.Embarked.fillna(0)

    # Later assignments override earlier ones, so the "both" case must be last.
    df["Company"] = 0
    df.loc[(df["SibSp"] > 0), "Company"] = 1
    df.loc[(df["Parch"] > 0), "Company"] = 2
    df.loc[(df["SibSp"] > 0) & (df["Parch"] > 0), "Company"] = 3

    df = df[
        [
            "PassengerId",
            "Pclass",
            "Sex",
            "Age",
            "Fare",
            "Embarked",
            "Company",
            "Survived",
        ]
    ]
    return df


train = encode_df(data)

X_all = train.drop(["Survived", "PassengerId"], axis=1)
y_all = train["Survived"]

num_test = 0.20
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=num_test, random_state=23
)

# Clear (non-FHE) reference model.
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

# Concrete ML model.
n_estimators = 50
max_depth = 4
n_bits = 6
n_jobs_xgb = 1
n_jobs_gridsearch = -1
concrete_clf = ConcreteRandomForestClassifier(
    n_bits=n_bits, n_estimators=n_estimators, max_depth=max_depth, n_jobs=n_jobs_xgb
)
concrete_clf.fit(X_train, y_train)
concrete_predictions = concrete_clf.predict(X_test)


def _predict_proba_dict(model, input_dict):
    """Bin the raw inputs, run ``model.predict_proba`` and label the two classes."""
    df = pd.DataFrame.from_dict(input_dict)
    df = encode_age(df)
    df = encode_fare(df)
    pred = model.predict_proba(df)[0]
    return {"Perishes": float(pred[0]), "Survives": float(pred[1])}


def collect_input(passenger_class, is_male, age, company, fare, embark_point):
    """Turn the raw widget values into a one-row feature dictionary.

    Args:
        passenger_class: Zero-based index of the selected class, or None.
        is_male: Truthy when the "Male?" box is ticked.
        age: Age as entered (binned later by ``encode_age``).
        company: Collection of selected companions ("Sibling", "Child").
        fare: Fare as entered (binned later by ``encode_fare``).
        embark_point: Zero-based index of the embark point, or None.

    Returns:
        A dict mapping each feature name to a one-element list, or None when
        the class or the embark point has not been selected.
    """
    if passenger_class is None or embark_point is None:
        return None

    # Same encoding as training: 1 = sibling, 2 = child, 3 = both.
    company_code = (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
    return {
        "Pclass": [passenger_class + 1],
        "Sex": [0 if is_male else 1],
        "Age": [age],
        "Fare": [fare],
        "Embarked": [embark_point + 1],
        "Company": [company_code],
    }


def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
    """Predict survival in the clear directly from the raw widget values.

    Returns:
        ``{"Perishes": p0, "Survives": p1}`` from the scikit-learn model, or
        None when the class or the embark point has not been selected.
    """
    input_dict = collect_input(passenger_class, is_male, age, company, fare, embark_point)
    if input_dict is None:
        return None
    return _predict_proba_dict(clf, input_dict)


def clear_predict_survival(input_dict):
    """Predict survival probabilities with the scikit-learn model."""
    return _predict_proba_dict(clf, input_dict)


def concrete_predict_survival(input_dict):
    """Predict survival probabilities with the Concrete ML model."""
    return _predict_proba_dict(concrete_clf, input_dict)


# Manual smoke test, kept disabled:
# print("\nclear_test    ", clear_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}))
# print("encrypted_test", concrete_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}),"\n")


def _load_client(user_id) -> FHEModelClient:
    """Build and load the FHE client bound to ``user_id``'s key directory."""
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()
    return client


def key_gen_fn() -> Dict:
    """
    Generate keys for a new, randomly identified user.

    Calls ``clean_directory()``, draws a random user ID, generates the private
    and evaluation keys on the client side and writes the serialized
    evaluation key to ``KEYS_DIR/<user_id>/evaluation_key``.

    Returns:
        dict: Gradio updates showing the truncated evaluation key, the user ID
        and the key size, and hiding the error box.
    """
    clean_directory()

    # Generate a random user ID. The explicit 64-bit dtype keeps the upper
    # bound valid on platforms whose default NumPy integer is 32-bit; int()
    # yields a plain Python integer for the UI.
    user_id = int(np.random.randint(0, 2**32, dtype=np.int64))
    print(f"Your user ID is: {user_id}....")

    client = _load_client(user_id)

    # Creates the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key
    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    with evaluation_key_path.open("wb") as f:
        f.write(serialized_evaluation_keys)

    # Only a prefix of the key is sent to the browser.
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]

    return {
        error_box2: gr.update(visible=False),
        key_box: gr.update(visible=True, value=serialized_evaluation_keys_shorten_hex),
        user_id_box: gr.update(visible=True, value=user_id),
        key_len_box: gr.update(
            visible=False, value=f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
        ),
    }


def encrypt_fn(user_inputs: Dict, user_id: str) -> Dict:
    """
    Quantize, encrypt and serialize the confirmed inputs on the client side.

    The encrypted bytes are written to ``KEYS_DIR/<user_id>/encrypted_input``.

    Args:
        user_inputs (Dict): The feature dictionary produced by ``collect_input``.
        user_id (str): The current user's ID.

    Returns:
        dict: Gradio updates with the clear inputs and a truncated hex view of
        the encrypted inputs, or an error update when either argument is
        missing.
    """
    if is_none(user_id) or is_none(user_inputs):
        print("Error in encryption step: Provide your inputs and generate the evaluation keys.")
        return {
            error_box3: gr.update(
                visible=True,
                value="⚠️ Please ensure that your inputs have been submitted and "
                "that you have generated the evaluation key.",
            )
        }

    # Retrieve the client API
    client = _load_client(user_id)

    user_inputs_df = pd.DataFrame.from_dict(user_inputs)
    user_inputs_df = encode_age(user_inputs_df)
    user_inputs_df = encode_fare(user_inputs_df)
    user_inputs_array = user_inputs_df.to_numpy()

    print("user_inputs to be encrypted =\n", user_inputs_df)
    print("user_inputs to be encrypted =\n", user_inputs_array)

    encrypted_quantized_user_inputs = client.quantize_encrypt_serialize(user_inputs_array)
    assert isinstance(encrypted_quantized_user_inputs, bytes)

    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"
    with encrypted_input_path.open("wb") as f:
        f.write(encrypted_quantized_user_inputs)

    # Only a prefix of the ciphertext is sent to the browser.
    encrypted_quantized_user_inputs_shorten_hex = encrypted_quantized_user_inputs.hex()[
        :INPUT_BROWSER_LIMIT
    ]

    return {
        error_box3: gr.update(visible=False),
        input_dict_box: gr.update(visible=False, value=user_inputs),
        enc_dict_box: gr.update(visible=True, value=encrypted_quantized_user_inputs_shorten_hex),
    }


def send_input_fn(user_id: str, user_inputs: Dict) -> Dict:
    """Send the encrypted data and the evaluation key to the server.

    Args:
        user_id (str): The current user's ID.
        user_inputs (Dict): The confirmed inputs.

    Returns:
        dict: Gradio updates marking the data as sent, or an error update when
        an argument is missing, a required file does not exist yet, or the
        server answers with a non-OK status.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the inputs have been submitted and the evaluation "
                "key has been generated before sending the data to the server.",
            )
        }

    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"

    if not evaluation_key_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: "
            f"The key has been generated correctly - {evaluation_key_path.is_file()=}"
        )
        return {
            error_box4: gr.update(visible=True, value="⚠️ Please generate the private key first.")
        }

    if not encrypted_input_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: The data has not been encrypted "
            f"correctly on the client side - {encrypted_input_path.is_file()=}"
        )
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please encrypt the data with the private key first.",
            ),
        }

    # Define the data and files to post
    data = {
        "user_id": user_id,
        "input": user_inputs,
    }
    url = SERVER_URL + "send_input"

    # Both files are closed as soon as the upload has finished.
    with encrypted_input_path.open("rb") as encrypted_input_file:
        with evaluation_key_path.open("rb") as evaluation_key_file:
            files = [
                ("files", encrypted_input_file),
                ("files", evaluation_key_file),
            ]
            # Send the encrypted input and evaluation key to the server
            with requests.post(url=url, data=data, files=files) as response:
                print(f"Sending Data: {response.ok=}")
                sent = response.ok

    if not sent:
        return {
            error_box4: gr.update(
                visible=True,
                value=(
                    "⚠️ An error occurred on the Server Side. "
                    "Please check connectivity and data transmission."
                ),
            )
        }

    return {
        error_box4: gr.update(visible=False),
        srv_resp_send_data_box: "Data sent",
    }


def run_fhe_fn(user_id: str) -> Dict:
    """Ask the server to run the FHE evaluation on the data it received.

    Args:
        user_id (str): The current user's ID.

    Returns:
        dict: Gradio updates showing the value returned by the server formatted
        as seconds, or an error update when the user ID is missing or the
        server answers with a non-OK status.
    """
    if is_none(user_id):
        return {
            error_box5: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the inputs have been submitted, the evaluation "
                "key has been generated and the server received the data "
                "before processing the data.",
            ),
            fhe_execution_time_box: None,
        }

    data = {
        "user_id": user_id,
    }
    url = SERVER_URL + "run_fhe"

    with requests.post(url=url, data=data) as response:
        if not response.ok:
            return {
                error_box5: gr.update(
                    visible=True,
                    value=(
                        "⚠️ An error occurred on the Server Side. "
                        "Please check connectivity and data transmission."
                    ),
                ),
                fhe_execution_time_box: gr.update(visible=False),
            }

        # NOTE(review): fixed pause kept from the original flow; its purpose is
        # not evident from this file - confirm whether it is still needed.
        time.sleep(10)
        # Parse the body once, while the response is still open.
        execution_time = response.json()
        print(f"response.ok: {response.ok}, {execution_time} - Computed")

    return {
        error_box5: gr.update(visible=False),
        fhe_execution_time_box: gr.update(visible=True, value=f"{execution_time:.2f} seconds"),
    }


def get_output_fn(user_id: str, user_inputs: Dict) -> Dict:
    """Retrieve the encrypted output from the server.

    On success the bytes are written to ``CLIENT_DIR/<user_id>_encrypted_output``.

    Args:
        user_id (str): The current user's ID.
        user_inputs (Dict): The confirmed inputs.

    Returns:
        dict: Gradio updates marking the data as received, or an error update
        when an argument is missing or the server answers with a non-OK status.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box6: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the server has successfully processed and transmitted the data to the client.",
            )
        }

    data = {
        "user_id": user_id,
    }

    # Retrieve the encrypted output
    url = SERVER_URL + "get_output"
    with requests.post(url=url, data=data) as response:
        received = response.ok
        if received:
            print(f"Receive Data: {response.ok=}")
            encrypted_output = response.content

            # Save the encrypted output to bytes in a file as it is too large to pass through
            # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
            encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
            with encrypted_output_path.open("wb") as f:
                f.write(encrypted_output)

    if not received:
        # Nothing was saved, so do not claim the data has been received.
        return {
            error_box6: gr.update(
                visible=True,
                value=(
                    "⚠️ An error occurred on the Server Side. "
                    "Please check connectivity and data transmission."
                ),
            )
        }

    return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}


def decrypt_fn(user_id: str, user_inputs: Dict) -> Dict:
    """Decrypt the data on the `Client Side`.

    Args:
        user_id (str): The current user's ID
        user_inputs (Dict): The user inputs

    Returns:
        dict: Gradio updates carrying the decrypted class probabilities, or an
        error update when an argument is missing or no encrypted output has
        been saved for this user yet.
    """
    if is_none(user_id) or is_none(user_inputs):
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the client has successfully received the data from the server.",
            )
        }

    # Get the encrypted output path
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"

    if not encrypted_output_path.is_file():
        print("Error in decryption step: Please run the FHE execution, first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please ensure that: \n"
                "- the connectivity \n"
                "- the inputs have been submitted \n"
                "- the evaluation key has been generated \n"
                "- the server processed the encrypted data \n"
                "- the Client received the data from the Server before decrypting the prediction",
            ),
            decrypt_box: None,
        }

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as f:
        encrypted_output = f.read()

    # Retrieve the client API
    client = _load_client(user_id)

    # Deserialize, decrypt and post-process the encrypted output
    output = client.deserialize_decrypt_dequantize(encrypted_output)
    print("output =\n", output)

    out = {"Perishes": float(output[0][0]), "Survives": float(output[0][1])}
    print("output =\n", out)

    return {
        error_box7: gr.update(visible=False),
        decrypt_box: out,
        label: out,
    }


with gr.Blocks() as demo:
    # Step 1.1: Provide inputs
    gr.Markdown("### Titanic Survival Prediction with ML and Private Computation")

    with gr.Row():
        inp = [
            gr.Dropdown(["first", "second", "third"], type="index", label="Select Passenger Class"),
            gr.Checkbox(label="Male?"),
            gr.Slider(0, 80, value=25, label="Age", step=1),
            gr.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
            gr.Number(value=20, label="Fare"),
            gr.Radio(["Southampton", "Cherbourg", "Queenstown"], type="index", label="Embark point:"),
        ]
        out = gr.JSON()
    btn = gr.Button("Confirm inputs")
    btn.click(fn=collect_input, inputs=inp, outputs=out)

    # Step 2.1: Key generation
    gen_key_btn = gr.Button("Generate the evaluation key")
    error_box2 = gr.Textbox(label="Error ❌", visible=False)
    user_id_box = gr.Textbox(label="User ID:", visible=True)
    key_len_box = gr.Textbox(label="Evaluation Key Size:", visible=False)
    key_box = gr.Textbox(label="Evaluation key (truncated):", max_lines=3, visible=False)

    gen_key_btn.click(
        key_gen_fn,
        inputs=None,
        outputs=[
            key_box,
            user_id_box,
            key_len_box,
            error_box2,
        ],
    )

    # Step 2.2: Encrypt data locally
    gr.Markdown("### Encrypt the data")
    encrypt_btn = gr.Button("Encrypt the data using the private secret key")
    error_box3 = gr.Textbox(label="Error ❌", visible=False)

    with gr.Row():
        with gr.Column():
            input_dict_box = gr.Textbox(label="input_dict_box:", max_lines=10)
        with gr.Column():
            enc_dict_box = gr.Textbox(label="encrypted input_dict_box:", max_lines=10)

    encrypt_btn.click(
        encrypt_fn,
        inputs=[out, user_id_box],
        outputs=[
            input_dict_box,
            enc_dict_box,
            error_box3,
        ],
    )

    # Step 2.3: Send encrypted data to the server
    gr.Markdown(
        "### Send the encrypted data to the Server Side"
    )
    error_box4 = gr.Textbox(label="Error ❌", visible=False)

    with gr.Row():
        with gr.Column(scale=4):
            send_input_btn = gr.Button("Send data")
        with gr.Column(scale=1):
            srv_resp_send_data_box = gr.Checkbox(label="Data Sent", show_label=False)

    send_input_btn.click(
        send_input_fn,
        inputs=[user_id_box, out],
        outputs=[error_box4, srv_resp_send_data_box],
    )

    # ------------------------- Step 3 -------------------------
    gr.Markdown("\n")
    gr.Markdown("## Step 3: Run the FHE evaluation")
    gr.Markdown("\n")
    gr.Markdown("Server Side")
    gr.Markdown(
        "Once the server receives the encrypted data, it can process and compute the output without ever decrypting the data just as it would on clear data.\n\n"
    )

    run_fhe_btn = gr.Button("Run the FHE evaluation")
    error_box5 = gr.Textbox(label="Error ❌", visible=False)
    fhe_execution_time_box = gr.Textbox(label="Total FHE Execution Time:", visible=True)
    run_fhe_btn.click(
        run_fhe_fn,
        inputs=[user_id_box],
        outputs=[fhe_execution_time_box, error_box5],
    )

    # ------------------------- Step 4 -------------------------
    gr.Markdown("\n")
    gr.Markdown("## Step 4: Decrypt the data")
    gr.Markdown("\n")
    gr.Markdown("Client Side")
    gr.Markdown(
        "### Get the encrypted data from the Server Side"
    )

    error_box6 = gr.Textbox(label="Error ❌", visible=False)

    # Step 4.1: Data transmission
    with gr.Row():
        with gr.Column(scale=4):
            get_output_btn = gr.Button("Get data")
        with gr.Column(scale=1):
            srv_resp_retrieve_data_box = gr.Checkbox(label="Data Received", show_label=False)

    get_output_btn.click(
        get_output_fn,
        inputs=[user_id_box, out],
        outputs=[srv_resp_retrieve_data_box, error_box6],
    )

    # Step 4.2: Decryption
    gr.Markdown("### Decrypt the output")
    decrypt_btn = gr.Button("Decrypt the output using the private secret key")
    error_box7 = gr.Textbox(label="Error ❌", visible=False)
    decrypt_box = gr.Textbox(label="Decrypted Output:")
    label = gr.Label()

    decrypt_btn.click(
        decrypt_fn,
        inputs=[user_id_box, out],
        outputs=[decrypt_box, error_box7, label],
    )

    # ------------------------- End -------------------------

demo.launch()