"""Backend functions used in the app.""" import os import shutil import gradio as gr import numpy import requests import pickle import pandas from itertools import chain from settings import ( SERVER_URL, FHE_KEYS, CLIENT_FILES, SERVER_FILES, DEPLOYMENT_PATH, PROCESSED_INPUT_SHAPE, INPUT_INDEXES, INPUT_SLICES, PRE_PROCESSOR_USER_PATH, PRE_PROCESSOR_BANK_PATH, PRE_PROCESSOR_CS_AGENCY_PATH, CLIENT_TYPES, USER_COLUMNS, BANK_COLUMNS, CS_AGENCY_COLUMNS, YEARS_EMPLOYED_BINS, YEARS_EMPLOYED_BIN_NAME_TO_INDEX, ) from utils.client_server_interface import MultiInputsFHEModelClient # Define the messages associated to the predictions APPROVED_MESSAGE = "Credit card is likely to be approved ✅" DENIED_MESSAGE = "Credit card is likely to be denied ❌" # Load pre-processor instances with ( PRE_PROCESSOR_USER_PATH.open('rb') as file_user, PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank, PRE_PROCESSOR_CS_AGENCY_PATH.open('rb') as file_cs_agency, ): PRE_PROCESSOR_USER = pickle.load(file_user) PRE_PROCESSOR_BANK = pickle.load(file_bank) PRE_PROCESSOR_CS_AGENCY = pickle.load(file_cs_agency) def shorten_bytes_object(bytes_object, limit=500): """Shorten the input bytes object to a given length. Encrypted data is too large for displaying it in the browser using Gradio. This function provides a shorten representation of it. Args: bytes_object (bytes): The input to shorten limit (int): The length to consider. Default to 500. Returns: str: Hexadecimal string shorten representation of the input byte object. """ # Define a shift for better display shift = 100 return bytes_object[shift : limit + shift].hex() def clean_temporary_files(n_keys=20): """Clean older keys and encrypted files. A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this limit is reached, the oldest files are deleted. Args: n_keys (int): The maximum number of keys and associated files to be stored. Default to 20. """ # Get the oldest key files in the key directory key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime) # If more than n_keys keys are found, remove the oldest client_ids = [] if len(key_dirs) > n_keys: n_keys_to_delete = len(key_dirs) - n_keys for key_dir in key_dirs[:n_keys_to_delete]: client_ids.append(key_dir.name) shutil.rmtree(key_dir) # Delete all files related to the IDs whose keys were deleted for directory in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()): for client_id in client_ids: if client_id in directory.name: shutil.rmtree(directory) def _get_client(client_id): """Get the client instance. Args: client_id (int): The client ID to consider. Returns: FHEModelClient: The client instance. """ key_dir = FHE_KEYS / f"{client_id}" return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES)) def _get_client_file_path(name, client_id, client_type=None): """Get the file path for the client. Args: name (str): The desired file name (either 'evaluation_key', 'encrypted_inputs' or 'encrypted_outputs'). client_id (int): The client ID to consider. client_type (Optional[str]): The type of user to consider (either 'user', 'bank', 'cs_agency' or None). Default to None, which is used for evaluation key and output. Returns: pathlib.Path: The file path. """ client_type_suffix = "" if client_type is not None: client_type_suffix = f"_{client_type}" dir_path = CLIENT_FILES / f"{client_id}" dir_path.mkdir(exist_ok=True) return dir_path / f"{name}{client_type_suffix}" def _send_to_server(client_id, client_type, file_name): """Send the encrypted inputs or the evaluation key to the server. Args: client_id (int): The client ID to consider. client_type (Optional[str]): The type of client to consider (either 'user', 'bank', 'cs_agency' or None). file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs'). """ # Get the paths to the encrypted inputs encrypted_file_path = _get_client_file_path(file_name, client_id, client_type) # Define the data and files to post data = { "client_id": client_id, "client_type": client_type, "file_name": file_name, } files = [ ("files", open(encrypted_file_path, "rb")), ] # Send the encrypted inputs or evaluation key to the server url = SERVER_URL + "send_file" with requests.post( url=url, data=data, files=files, ) as response: return response.ok def keygen_send(): """Generate the private and evaluation key, and send the evaluation key to the server. Returns: client_id (str): The current client ID to consider. """ # Clean temporary files clean_temporary_files() # Create an ID for the current client to consider client_id = numpy.random.randint(0, 2**32) # Retrieve the client instance client = _get_client(client_id) # Generate the private and evaluation keys client.generate_private_and_evaluation_keys(force=True) # Retrieve the serialized evaluation key evaluation_key = client.get_serialized_evaluation_keys() file_name = "evaluation_key" # Save evaluation key as bytes in a file as it is too large to pass through regular Gradio # buttons (see https://github.com/gradio-app/gradio/issues/1877) evaluation_key_path = _get_client_file_path(file_name, client_id) with evaluation_key_path.open("wb") as evaluation_key_file: evaluation_key_file.write(evaluation_key) # Send the evaluation key to the server _send_to_server(client_id, None, file_name) # Create a truncated version of the evaluation key for display evaluation_key_short = shorten_bytes_object(evaluation_key) return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent ✅") def _encrypt_send(client_id, inputs, client_type): """Encrypt the given inputs for a specific client and send it to the server. Args: client_id (str): The current client ID to consider. inputs (numpy.ndarray): The inputs to encrypt. client_type (str): The type of client to consider (either 'user', 'bank' or 'cs_agency'). Returns: encrypted_inputs_short (str): A short representation of the encrypted input to send in hex. """ if client_id == "": raise gr.Error("Please generate the keys first.") # Retrieve the client instance client = _get_client(client_id) # Quantize, encrypt and serialize the inputs encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs( inputs, input_index=INPUT_INDEXES[client_type], processed_input_shape=PROCESSED_INPUT_SHAPE, input_slice=INPUT_SLICES[client_type], ) file_name = "encrypted_inputs" # Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio # buttons, https://github.com/gradio-app/gradio/issues/1877 encrypted_inputs_path = _get_client_file_path(file_name, client_id, client_type) with encrypted_inputs_path.open("wb") as encrypted_inputs_file: encrypted_inputs_file.write(encrypted_inputs) # Create a truncated version of the encrypted inputs for display encrypted_inputs_short = shorten_bytes_object(encrypted_inputs) _send_to_server(client_id, client_type, file_name) return encrypted_inputs_short def pre_process_encrypt_send_user(client_id, *inputs): """Pre-process, encrypt and send the user inputs for a specific client to the server. Args: client_id (str): The current client ID to consider. *inputs (Tuple[numpy.ndarray]): The inputs to pre-process. Returns: (str): A short representation of the encrypted input to send in hex. """ bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \ family_status, occupation_type, housing_type = inputs # Retrieve boolean values own_car = "Car" in bool_inputs own_property = "Property" in bool_inputs mobile_phone = "Mobile phone" in bool_inputs user_inputs = pandas.DataFrame({ "Own_car": [own_car], "Own_property": [own_property], "Mobile_phone": [mobile_phone], "Num_children": [num_children], "Household_size": [household_size], "Total_income": [total_income], "Age": [age], "Income_type": [income_type], "Education_type": [education_type], "Family_status": [family_status], "Occupation_type": [occupation_type], "Housing_type": [housing_type], }) user_inputs = user_inputs.reindex(USER_COLUMNS, axis=1) preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs) return _encrypt_send(client_id, preprocessed_user_inputs, "user") def pre_process_encrypt_send_bank(client_id, *inputs): """Pre-process, encrypt and send the bank inputs for a specific client to the server. Args: client_id (str): The current client ID to consider. *inputs (Tuple[numpy.ndarray]): The inputs to pre-process. Returns: (str): A short representation of the encrypted input to send in hex. """ account_age = inputs[0] bank_inputs = pandas.DataFrame({ "Account_age": [account_age], }) bank_inputs = bank_inputs.reindex(BANK_COLUMNS, axis=1) preprocessed_bank_inputs = PRE_PROCESSOR_BANK.transform(bank_inputs) return _encrypt_send(client_id, preprocessed_bank_inputs, "bank") def pre_process_encrypt_send_cs_agency(client_id, *inputs): """Pre-process, encrypt and send the credit scoring agency inputs for a specific client to the server. Args: client_id (str): The current client ID to consider. *inputs (Tuple[numpy.ndarray]): The inputs to pre-process. Returns: (str): A short representation of the encrypted input to send in hex. """ years_employed_bin, employed = inputs years_employed = YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed_bin] is_employed = employed == "Yes" cs_agency_inputs = pandas.DataFrame({ "Years_employed": [years_employed], "Employed": [is_employed], }) cs_agency_inputs = cs_agency_inputs.reindex(CS_AGENCY_COLUMNS, axis=1) preprocessed_cs_agency_inputs = PRE_PROCESSOR_CS_AGENCY.transform(cs_agency_inputs) return _encrypt_send(client_id, preprocessed_cs_agency_inputs, "cs_agency") def run_fhe(client_id): """Run the model on the encrypted inputs previously sent using FHE. Args: client_id (str): The current client ID to consider. """ if client_id == "": raise gr.Error("Please generate the keys first.") data = { "client_id": client_id, } # Trigger the FHE execution on the encrypted inputs previously sent url = SERVER_URL + "run_fhe" with requests.post( url=url, data=data, ) as response: if response.ok: return response.json() else: raise gr.Error("Please send the inputs from all three parties to the server first.") def get_output_and_decrypt(client_id): """Retrieve the encrypted output. Args: client_id (str): The current client ID to consider. Returns: (Tuple[str, bytes]): The output message based on the decrypted prediction as well as a byte short representation of the encrypted output. """ if client_id == "": raise gr.Error("Please generate the keys first.") data = { "client_id": client_id, } # Retrieve the encrypted output url = SERVER_URL + "get_output" with requests.post( url=url, data=data, ) as response: if response.ok: encrypted_output_proba = response.content # Create a truncated version of the encrypted inputs for display encrypted_output_short = shorten_bytes_object(encrypted_output_proba) # Retrieve the client API client = _get_client(client_id) # Deserialize, decrypt and post-process the encrypted output output_proba = client.deserialize_decrypt_dequantize(encrypted_output_proba) # Determine the predicted class output = numpy.argmax(output_proba, axis=1).squeeze() return ( APPROVED_MESSAGE if output == 1 else DENIED_MESSAGE, encrypted_output_short, ) else: raise gr.Error("Please run the FHE execution first and wait for it to be completed.") def explain_encrypt_run_decrypt(client_id, prediction_output, *inputs): """Pre-process and encrypt the inputs, run the prediction in FHE and decrypt the output. Args: client_id (str): The current client ID to consider. prediction_output (str): The initial prediction output. This parameter is only used to throw an error in case the prediction was positive. *inputs (Tuple[numpy.ndarray]): The inputs to consider. Returns: (str): A message indicating the number of additional years of employment that could be required in order to increase the chance of credit card approval. """ if "approved" in prediction_output: raise gr.Error( "Explaining the prediction can only be done if the credit card is likely to be denied." ) # Retrieve the credit scoring agency inputs years_employed, employed = inputs # Years_employed is divided into several ordered bins. Here, we retrieve the index representing # the bin from the input bin_index = YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed] # If the bin is not the last (representing the most years of employment), we run the model in # FHE for each bins "older" or equal to the given bin, in order. Then, we retrieve the first # bin that changes the model's prediction to "approval" and display it to the user. if bin_index != len(YEARS_EMPLOYED_BINS) - 1: # Loop over the bins starting with "older" or equal to the given bin for years_employed_bin in YEARS_EMPLOYED_BINS[bin_index:]: # Send the new encrypted input pre_process_encrypt_send_cs_agency(client_id, years_employed_bin, employed) # Run the model in FHE run_fhe(client_id) # Retrieve the new prediction output_prediction = get_output_and_decrypt(client_id) # If the bin made the model predict an approval, share it to the user if "approved" in output_prediction[0]: # If the approval was made using the given input, that means the user most likely # tried the bin suggested in a previous explainability run. In that case, we # confirm that the credit card is likely to be approved if years_employed_bin == years_employed: return APPROVED_MESSAGE # Else, that means the users is looking for some explainability. We therefore # suggest to try the obtained bin return ( DENIED_MESSAGE + f" However, having at least {years_employed_bin} years of " "employment would increase your chance of having your credit card approved." ) # In case no bins made the model predict an approval, explain why return ( DENIED_MESSAGE + " Unfortunately, increasing the number of years of employment up to " f"{YEARS_EMPLOYED_BINS[-1]} years does not seem to be enough to get an approval based " "on the given inputs. Other inputs like the income or the account's age might have " "bigger impact in this particular case." ) # In case the user tried the "oldest" bin (but still got denied), explain why return ( DENIED_MESSAGE + " Unfortunately, you already have the maximum amount of years of " f"employment ({years_employed} years). Other inputs like the income or the account's age " "might have a bigger impact in this particular case." )