# encrypted_credit_scoring / development.py
# Last commit: "Add second model for optional explainability step" (74c0c8e), romanbredehoft-zama
# File size: 5.12 kB
"""Train and compile the model."""
import shutil
import numpy
import pandas
import pickle
from settings import (
APPROVAL_DEPLOYMENT_PATH,
EXPLAIN_DEPLOYMENT_PATH,
DATA_PATH,
APPROVAL_INPUT_SLICES,
EXPLAIN_INPUT_SLICES,
PRE_PROCESSOR_USER_PATH,
PRE_PROCESSOR_BANK_PATH,
PRE_PROCESSOR_THIRD_PARTY_PATH,
USER_COLUMNS,
BANK_COLUMNS,
APPROVAL_THIRD_PARTY_COLUMNS,
EXPLAIN_THIRD_PARTY_COLUMNS,
)
from utils.client_server_interface import MultiInputsFHEModelDev
from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor
from utils.pre_processing import get_pre_processors
def get_multi_inputs(data, is_approval):
    """Split the input data into one array per party, using fixed column slices.

    Args:
        data (numpy.ndarray): The input data to split. It is indexed along its second axis.
        is_approval (bool): Whether to use the slices of the 'approval' model. If False, the
            slices of the 'explain' model are used instead.

    Returns:
        (Tuple[numpy.ndarray]): The user, bank and third-party inputs, in that order.
    """
    # Both models share the same party keys, only the slices behind them differ
    input_slices = APPROVAL_INPUT_SLICES if is_approval else EXPLAIN_INPUT_SLICES

    return tuple(data[:, input_slices[party]] for party in ("user", "bank", "third_party"))
print("Load and pre-process the data")

# Read the whole data-set from the CSV file
data = pandas.read_csv(DATA_PATH, encoding="utf-8")

# Separate the features from the "Target" column, which is kept as a single-column data-frame
data_x = data.copy()
data_y = data_x.pop("Target").copy().to_frame()

# Select the columns owned by each party (user, bank and third party, in that order)
data_user, data_bank, data_third_party = (
    data_x[party_columns].copy()
    for party_columns in (USER_COLUMNS, BANK_COLUMNS, APPROVAL_THIRD_PARTY_COLUMNS)
)

# Fit one pre-processor per party on its own columns and transform them. The fitted
# pre-processors are re-used further down in this script
pre_processor_user, pre_processor_bank, pre_processor_third_party = get_pre_processors()
preprocessed_data_user = pre_processor_user.fit_transform(data_user)
preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank)
preprocessed_data_third_party = pre_processor_third_party.fit_transform(data_third_party)

# Join the pre-processed features column-wise, in the user, bank, third-party order
preprocessed_data_x = numpy.concatenate(
    [preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party],
    axis=1,
)
print("\nTrain and compile the model")

# 'fit_benchmark' returns two objects: the model used for the rest of this script and a second
# one that is only stored here
# NOTE(review): presumably the second one is the equivalent scikit-learn model - confirm
model_approval, sklearn_model_approval = MultiInputDecisionTreeClassifier().fit_benchmark(
    preprocessed_data_x, data_y
)

# Compile the model on the training data split per party, with all three inputs encrypted
multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=True)
model_approval.compile(
    *multi_inputs_train,
    inputs_encryption_status=["encrypted", "encrypted", "encrypted"],
)
print("\nSave deployment files")

# Start from a clean state: remove any existing deployment folder along with its content
if APPROVAL_DEPLOYMENT_PATH.is_dir():
    shutil.rmtree(APPROVAL_DEPLOYMENT_PATH)

# Write the files needed for deployment ('via_mlir' enables cross-platform deployment)
fhe_model_dev_approval = MultiInputsFHEModelDev(APPROVAL_DEPLOYMENT_PATH, model_approval)
fhe_model_dev_approval.save(via_mlir=True)

# Pickle each fitted pre-processor into its own file. All three files are opened before
# anything is written
with (
    PRE_PROCESSOR_USER_PATH.open("wb") as file_user,
    PRE_PROCESSOR_BANK_PATH.open("wb") as file_bank,
    PRE_PROCESSOR_THIRD_PARTY_PATH.open("wb") as file_third_party,
):
    pickle.dump(pre_processor_user, file_user)
    pickle.dump(pre_processor_bank, file_bank)
    pickle.dump(pre_processor_third_party, file_third_party)
print("\nLoad, train, compile and save files for the 'explain' model")

# Start again from the loaded data: "Years_employed" is now the value to predict, while
# "Target" is taken out of the features and only used for the filtering below
data_x = data.copy()
data_y = data_x.pop("Years_employed").copy().to_frame()
target_values = data_x.pop("Target").copy()

# Keep only the data points with a target value of 1 (credit card has been approved)
approved_mask = target_values == 1
data_x_approved, data_y_approved = data_x[approved_mask], data_y[approved_mask]

# Select the columns owned by each party, this time with the 'explain' third-party columns
data_user, data_bank, data_third_party = (
    data_x_approved[party_columns].copy()
    for party_columns in (USER_COLUMNS, BANK_COLUMNS, EXPLAIN_THIRD_PARTY_COLUMNS)
)

# Re-use the user and bank pre-processors fitted earlier in this script (transform only, no
# re-fitting). The third-party columns are converted to a NumPy array as is
# NOTE(review): presumably because the third-party pre-processor was fitted on the 'approval'
# columns and cannot be applied to the 'explain' ones - confirm
preprocessed_data_user = pre_processor_user.transform(data_user)
preprocessed_data_bank = pre_processor_bank.transform(data_bank)
preprocessed_data_third_party = data_third_party.to_numpy()

# Join the features column-wise, in the user, bank, third-party order
preprocessed_data_x = numpy.concatenate(
    [preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party],
    axis=1,
)

# Train the regressor on the approved data points only
model_explain, sklearn_model_explain = MultiInputDecisionTreeRegressor().fit_benchmark(
    preprocessed_data_x, data_y_approved
)

# Compile the model on the training data split per party, with all three inputs encrypted
multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=False)
model_explain.compile(
    *multi_inputs_train,
    inputs_encryption_status=["encrypted", "encrypted", "encrypted"],
)

# Start from a clean state: remove any existing deployment folder along with its content
if EXPLAIN_DEPLOYMENT_PATH.is_dir():
    shutil.rmtree(EXPLAIN_DEPLOYMENT_PATH)

# Write the files needed for deployment ('via_mlir' enables cross-platform deployment)
fhe_model_dev_explain = MultiInputsFHEModelDev(EXPLAIN_DEPLOYMENT_PATH, model_explain)
fhe_model_dev_explain.save(via_mlir=True)

print("\nDone !")