|
"""Backend functions used in the app.""" |
|
|
|
import os |
|
import shutil |
|
import gradio as gr |
|
import numpy |
|
import requests |
|
import pickle |
|
import pandas |
|
from itertools import chain |
|
|
|
from settings import ( |
|
SERVER_URL, |
|
FHE_KEYS, |
|
CLIENT_FILES, |
|
SERVER_FILES, |
|
APPROVAL_DEPLOYMENT_PATH, |
|
EXPLAIN_DEPLOYMENT_PATH, |
|
APPROVAL_PROCESSED_INPUT_SHAPE, |
|
EXPLAIN_PROCESSED_INPUT_SHAPE, |
|
INPUT_INDEXES, |
|
APPROVAL_INPUT_SLICES, |
|
EXPLAIN_INPUT_SLICES, |
|
PRE_PROCESSOR_USER_PATH, |
|
PRE_PROCESSOR_BANK_PATH, |
|
PRE_PROCESSOR_THIRD_PARTY_PATH, |
|
CLIENT_TYPES, |
|
USER_COLUMNS, |
|
BANK_COLUMNS, |
|
APPROVAL_THIRD_PARTY_COLUMNS, |
|
) |
|
|
|
from utils.client_server_interface import MultiInputsFHEModelClient, MultiInputsFHEModelServer |
|
|
|
|
|
# Instantiate the FHE server associated to the 'explain' model. Unlike the 'approval' model,
# this one is run locally within the app process instead of behind the HTTP server.
EXPLAIN_FHE_SERVER = MultiInputsFHEModelServer(EXPLAIN_DEPLOYMENT_PATH)


# Load the pre-processors fitted for each party's inputs (user, bank, third party).
# NOTE(review): pickle.load must only be used on trusted files — these are assumed to be
# shipped with the app, not user-provided; confirm the deployment never exposes these paths.
with (
    PRE_PROCESSOR_USER_PATH.open('rb') as file_user,
    PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank,
    PRE_PROCESSOR_THIRD_PARTY_PATH.open('rb') as file_third_party,
):
    PRE_PROCESSOR_USER = pickle.load(file_user)
    PRE_PROCESSOR_BANK = pickle.load(file_bank)
    PRE_PROCESSOR_THIRD_PARTY = pickle.load(file_third_party)
|
|
|
|
|
def shorten_bytes_object(bytes_object, limit=500, shift=100):
    """Shorten the input bytes object to a given length.

    Encrypted data is too large for displaying it in the browser using Gradio. This function
    provides a shorten representation of it.

    Args:
        bytes_object (bytes): The input to shorten.
        limit (int): The maximum number of bytes to keep. Default to 500.
        shift (int): The number of leading bytes to skip before taking the extract, so that the
            displayed slice does not only contain a (possibly constant) serialization header.
            Default to 100, which preserves the historical behavior.

    Returns:
        str: Hexadecimal string shorten representation of the input byte object.
    """
    # Slice [shift, shift + limit) and render it as hexadecimal for safe browser display.
    # If the object is shorter than 'shift', the result is simply the empty string.
    return bytes_object[shift : limit + shift].hex()
|
|
|
|
|
def clean_temporary_files(n_keys=20):
    """Clean older keys and encrypted files.

    A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
    limit is reached, the oldest files are deleted.

    Args:
        n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
    """
    # List all existing key directories, oldest first (sorted by modification time)
    key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)

    # Delete the oldest key directories so that at most n_keys remain, keeping track of the
    # client IDs (the directory names) whose keys were removed
    deleted_ids = []
    excess = len(key_dirs) - n_keys
    if excess > 0:
        for old_key_dir in key_dirs[:excess]:
            deleted_ids.append(old_key_dir.name)
            shutil.rmtree(old_key_dir)

    # Remove the client and server temporary files associated to the deleted keys
    for directory in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()):
        for deleted_id in deleted_ids:
            if deleted_id in directory.name:
                shutil.rmtree(directory)
|
|
|
|
|
def _get_client(client_id, is_approval=True):
    """Get the client instance.

    Args:
        client_id (int): The client ID to consider.
        is_approval (bool): If client is representing the 'approval' model (else, it is
            representing the 'explain' model). Default to True.

    Returns:
        FHEModelClient: The client instance.
    """
    # Keys and deployment files are kept separately for the 'approval' and 'explain' models
    if is_approval:
        key_dir = FHE_KEYS / f"{client_id}_approval"
        deployment_dir = APPROVAL_DEPLOYMENT_PATH
    else:
        key_dir = FHE_KEYS / f"{client_id}_explain"
        deployment_dir = EXPLAIN_DEPLOYMENT_PATH

    return MultiInputsFHEModelClient(
        deployment_dir,
        key_dir=key_dir,
        nb_inputs=len(CLIENT_TYPES),
    )
|
|
|
|
|
def _get_client_file_path(name, client_id, client_type=None):
    """Get the file path for the client.

    Args:
        name (str): The desired file name (either 'evaluation_key', 'encrypted_inputs' or
            'encrypted_outputs').
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
            'third_party' or None). Default to None, which is used for evaluation key and output.

    Returns:
        pathlib.Path: The file path.
    """
    client_type_suffix = f"_{client_type}" if client_type is not None else ""

    dir_path = CLIENT_FILES / f"{client_id}"

    # Create the client's directory if it does not already exist. 'parents=True' makes sure the
    # top-level CLIENT_FILES directory is created as well on a fresh deployment, which would
    # otherwise raise a FileNotFoundError
    dir_path.mkdir(parents=True, exist_ok=True)

    return dir_path / f"{name}{client_type_suffix}"
|
|
|
|
|
def _send_to_server(client_id, client_type, file_name):
    """Send the encrypted inputs or the evaluation key to the server.

    Args:
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
            'third_party' or None).
        file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').

    Returns:
        bool: True if the server acknowledged the file, False otherwise.
    """
    encrypted_file_path = _get_client_file_path(file_name, client_id, client_type)

    data = {
        "client_id": client_id,
        "client_type": client_type,
        "file_name": file_name,
    }

    # Open the file with a context manager so that the handle is always closed, even if the
    # request fails (the previous code leaked the file descriptor)
    with encrypted_file_path.open("rb") as encrypted_file:
        files = [
            ("files", encrypted_file),
        ]

        url = SERVER_URL + "send_file"
        with requests.post(
            url=url,
            data=data,
            files=files,
        ) as response:
            return response.ok
|
|
|
|
|
def keygen_send():
    """Generate the private and evaluation key, and send the evaluation key to the server.

    Returns:
        (Tuple[int, str, Dict]): The current client ID, a short hexadecimal representation of
            the serialized evaluation key, and a Gradio update displaying a confirmation
            message.
    """
    # Make sure at most a fixed number of keys (and their associated files) are kept on disk
    clean_temporary_files()

    # Draw a random 32-bit identifier for this session
    # NOTE(review): numpy.random.randint's upper bound is exclusive; confirm 2**32 stays within
    # bounds on platforms where the default integer type is 32-bit
    client_id = numpy.random.randint(0, 2**32)

    # Build the 'approval' client associated to this new ID
    client = _get_client(client_id)

    # Generate the private and evaluation keys, overwriting any existing ones
    client.generate_private_and_evaluation_keys(force=True)

    # Serialize the evaluation key so that it can be stored and sent over HTTP
    evaluation_key = client.get_serialized_evaluation_keys()

    file_name = "evaluation_key"

    # Store the evaluation key in a file associated to this client
    evaluation_key_path = _get_client_file_path(file_name, client_id)

    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(evaluation_key)

    # Send the evaluation key to the server (the private key never leaves the client)
    _send_to_server(client_id, None, file_name)

    # The evaluation key is too large to be displayed in the browser as is
    evaluation_key_short = shorten_bytes_object(evaluation_key)

    return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent ✅")
|
|
|
|
|
def _encrypt_send(client_id, inputs, client_type, app_mode=True):
    """Encrypt the given inputs for a specific client and send it to the server.

    Args:
        client_id (str): The current client ID to consider.
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
        app_mode (bool): Currently unused — TODO confirm whether it can be removed or was meant
            to select between the 'approval' and 'explain' models. Default to True.

    Returns:
        encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
    """
    # The keys must have been generated first, as encryption requires them
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Retrieve the 'approval' client associated to this ID
    client = _get_client(client_id)

    # Quantize, encrypt and serialize the inputs at the position expected for this client type
    encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        processed_input_shape=APPROVAL_PROCESSED_INPUT_SHAPE,
        input_slice=APPROVAL_INPUT_SLICES[client_type],
    )

    file_name = "encrypted_inputs"

    # Store the encrypted inputs in a file associated to this client and client type
    input_file_path = _get_client_file_path(file_name, client_id, client_type)
    with input_file_path.open("wb") as input_file:
        input_file.write(encrypted_inputs)

    # Encrypted data is too large to be displayed in the browser as is
    short_representation = shorten_bytes_object(encrypted_inputs)

    # Send the encrypted inputs to the server
    _send_to_server(client_id, client_type, file_name)

    return short_representation
|
|
|
|
|
def _pre_process_user(*inputs):
    """Pre-process the user inputs.

    Args:
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (numpy.ndarray): The pre-processed inputs.
    """
    (
        bool_inputs,
        num_children,
        household_size,
        total_income,
        age,
        income_type,
        education_type,
        family_status,
        occupation_type,
        housing_type,
    ) = inputs

    # The checkbox group is turned into one boolean feature per possible entry
    user_inputs = pandas.DataFrame({
        "Own_car": ["Car" in bool_inputs],
        "Own_property": ["Property" in bool_inputs],
        "Mobile_phone": ["Mobile phone" in bool_inputs],
        "Num_children": [num_children],
        "Household_size": [household_size],
        "Total_income": [total_income],
        "Age": [age],
        "Income_type": [income_type],
        "Education_type": [education_type],
        "Family_status": [family_status],
        "Occupation_type": [occupation_type],
        "Housing_type": [housing_type],
    })

    # Make sure the columns are in the exact order expected by the fitted pre-processor
    user_inputs = user_inputs.reindex(USER_COLUMNS, axis=1)

    return PRE_PROCESSOR_USER.transform(user_inputs)
|
|
|
|
|
def pre_process_encrypt_send_user(client_id, *inputs):
    """Pre-process, encrypt and send the user inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (str): A short representation of the encrypted input to send in hex.
    """
    # Pre-process the raw Gradio inputs before encrypting them
    preprocessed = _pre_process_user(*inputs)
    return _encrypt_send(client_id, preprocessed, "user")
|
|
|
|
|
def _pre_process_bank(*inputs):
    """Pre-process the bank inputs.

    Args:
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (numpy.ndarray): The pre-processed inputs.
    """
    # The bank only provides a single feature: the account's age
    bank_inputs = pandas.DataFrame({"Account_age": [inputs[0]]})

    # Make sure the columns are in the exact order expected by the fitted pre-processor
    bank_inputs = bank_inputs.reindex(BANK_COLUMNS, axis=1)

    return PRE_PROCESSOR_BANK.transform(bank_inputs)
|
|
|
|
|
def pre_process_encrypt_send_bank(client_id, *inputs):
    """Pre-process, encrypt and send the bank inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (str): A short representation of the encrypted input to send in hex.
    """
    # Pre-process the raw Gradio inputs before encrypting them
    preprocessed = _pre_process_bank(*inputs)
    return _encrypt_send(client_id, preprocessed, "bank")
|
|
|
|
|
def _pre_process_third_party(*inputs): |
|
"""Pre-process the third party inputs. |
|
|
|
Args: |
|
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process. |
|
|
|
Returns: |
|
(numpy.ndarray): The pre-processed inputs. |
|
""" |
|
third_party_data = {} |
|
if len(inputs) == 1: |
|
employed = inputs[0] |
|
else: |
|
employed, years_employed = inputs |
|
third_party_data["Years_employed"] = [years_employed] |
|
|
|
is_employed = employed == "Yes" |
|
third_party_data["Employed"] = [is_employed] |
|
|
|
third_party_inputs = pandas.DataFrame(third_party_data) |
|
|
|
if len(inputs) == 1: |
|
preprocessed_third_party_inputs = third_party_inputs.to_numpy() |
|
else: |
|
third_party_inputs = third_party_inputs.reindex(APPROVAL_THIRD_PARTY_COLUMNS, axis=1) |
|
preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs) |
|
|
|
return preprocessed_third_party_inputs |
|
|
|
|
|
def pre_process_encrypt_send_third_party(client_id, *inputs):
    """Pre-process, encrypt and send the third party inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (str): A short representation of the encrypted input to send in hex.
    """
    # Pre-process the raw Gradio inputs before encrypting them
    preprocessed = _pre_process_third_party(*inputs)
    return _encrypt_send(client_id, preprocessed, "third_party")
|
|
|
|
|
def run_fhe(client_id):
    """Run the model on the encrypted inputs previously sent using FHE.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        The JSON payload of the server's response (the FHE execution result metadata).
    """
    # The keys must have been generated first, as the inputs could not have been sent otherwise
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    url = SERVER_URL + "run_fhe"
    with requests.post(url=url, data={"client_id": client_id}) as response:
        # The server can only run once it received the inputs from all three parties
        if not response.ok:
            raise gr.Error("Please send the inputs from all three parties to the server first.")
        return response.json()
|
|
|
|
|
def get_output_and_decrypt(client_id):
    """Retrieve the encrypted output.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        (Tuple[str, bytes]): The output message based on the decrypted prediction as well as
            a byte short representation of the encrypted output.
    """
    # The keys must have been generated first, as decryption requires the private key
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    url = SERVER_URL + "get_output"
    with requests.post(url=url, data={"client_id": client_id}) as response:
        # The output is only available once the FHE execution completed on the server
        if not response.ok:
            raise gr.Error("Please run the FHE execution first and wait for it to be completed.")

        encrypted_output_proba = response.content

        # Encrypted data is too large to be displayed in the browser as is
        encrypted_output_short = shorten_bytes_object(encrypted_output_proba)

        # Decrypt and de-quantize the probabilities using the client's private key
        client = _get_client(client_id)
        output_proba = client.deserialize_decrypt_dequantize(encrypted_output_proba)

        # The predicted class is the one with the highest probability
        prediction = numpy.argmax(output_proba, axis=1).squeeze()

        message = (
            "Credit card is likely to be approved ✅"
            if prediction == 1
            else "Credit card is likely to be denied ❌"
        )
        return message, encrypted_output_short
|
|
|
|
|
def years_employed_encrypt_run_decrypt(client_id, prediction_output, *inputs):
    """Pre-process and encrypt the inputs, run the prediction in FHE and decrypt the output.

    Args:
        client_id (str): The current client ID to consider.
        prediction_output (str): The initial prediction output. This parameter is only used to
            throw an error in case the prediction was positive.
        *inputs (Tuple[numpy.ndarray]): The inputs to consider.

    Returns:
        (str): A message indicating the number of additional years of employment that could be
            required in order to increase the chance of credit card approval.
    """
    # Explaining the prediction is only relevant when the initial outcome was a denial
    if "approved" in prediction_output:
        raise gr.Error(
            "Explaining the prediction can only be done if the credit card is likely to be denied."
        )

    # Retrieve the client associated to the 'explain' model
    client = _get_client(client_id, is_approval=False)

    # Generate the keys only if they do not already exist for this client ('force=False')
    client.generate_private_and_evaluation_keys(force=False)

    # Serialize the evaluation key, which is needed to run the model in FHE
    evaluation_key = client.get_serialized_evaluation_keys()

    # Unpack all parties' raw inputs in the order provided by the app
    bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
        family_status, occupation_type, housing_type, account_age, employed, years_employed = inputs

    # Pre-process each party's inputs; note that the third party only provides the employment
    # status here, while 'years_employed' is the value the model's prediction is compared against
    preprocessed_user_inputs = _pre_process_user(
        bool_inputs, num_children, household_size, total_income, age, income_type, education_type,
        family_status, occupation_type, housing_type,
    )
    preprocessed_bank_inputs = _pre_process_bank(account_age)
    preprocessed_third_party_inputs = _pre_process_third_party(employed)

    preprocessed_inputs = [
        preprocessed_user_inputs,
        preprocessed_bank_inputs,
        preprocessed_third_party_inputs
    ]

    # Quantize, encrypt and serialize each party's pre-processed inputs at its expected position
    encrypted_inputs = []
    for client_type, preprocessed_input in zip(CLIENT_TYPES, preprocessed_inputs):
        encrypted_input = client.quantize_encrypt_serialize_multi_inputs(
            preprocessed_input,
            input_index=INPUT_INDEXES[client_type],
            processed_input_shape=EXPLAIN_PROCESSED_INPUT_SHAPE,
            input_slice=EXPLAIN_INPUT_SLICES[client_type],
        )
        encrypted_inputs.append(encrypted_input)

    # Run the 'explain' model on the encrypted inputs in FHE, using the local server instance
    # (no HTTP round-trip here, unlike the 'approval' flow)
    encrypted_output = EXPLAIN_FHE_SERVER.run(
        *encrypted_inputs,
        serialized_evaluation_keys=evaluation_key
    )

    # Decrypt and de-quantize the model's output using the client's private key
    output_prediction = client.deserialize_decrypt_dequantize(encrypted_output)

    # Additional years of employment suggested by the model, rounded up to a whole number
    years_employed_diff = int(numpy.ceil(output_prediction.squeeze() - years_employed))

    if years_employed_diff > 0:
        return (
            f"Having at least {years_employed_diff} more years of employment would increase "
            "your chance of having your credit card approved."
        )

    return (
        "The number of years of employment you provided seems to be enough. The negative prediction "
        "might come from other inputs."
    )
|
|
|
|