romanbredehoft-zama
commited on
Commit
•
316f8e9
1
Parent(s):
b47829b
Rename third party and improve collaboration comments
Browse files- app.py +18 -12
- backend.py +17 -17
- deployment_files/model/{pre_processor_third_party.pkl → pre_processor_cs_agency.pkl} +0 -0
- deployment_files/{pre_processor_third_party.pkl → pre_processor_cs_agency.pkl} +0 -0
- development.py +9 -9
- server.py +1 -1
- settings.py +5 -5
- utils/client_server_interface.py +1 -1
- utils/pre_processing.py +2 -2
app.py
CHANGED
@@ -24,7 +24,7 @@ from backend import (
|
|
24 |
keygen_send,
|
25 |
pre_process_encrypt_send_user,
|
26 |
pre_process_encrypt_send_bank,
|
27 |
-
|
28 |
run_fhe,
|
29 |
get_output_and_decrypt,
|
30 |
explain_encrypt_run_decrypt,
|
@@ -61,12 +61,14 @@ with demo:
|
|
61 |
"""
|
62 |
)
|
63 |
|
64 |
-
gr.Markdown("# Client
|
65 |
|
66 |
gr.Markdown("## Step 1: Generate the keys.")
|
67 |
gr.Markdown(
|
68 |
"""
|
69 |
-
- The private key is
|
|
|
|
|
70 |
- The evaluation key is a public key that the server needs to process encrypted data. It is
|
71 |
therefore transmitted to the server for further processing as well.
|
72 |
"""
|
@@ -91,7 +93,7 @@ with demo:
|
|
91 |
- a user's personal information in order to evaluate his/her credit card eligibility;
|
92 |
- the user’s bank account history, which provides any type of information on the user's
|
93 |
banking information relevant to the decision (here, we consider duration of account);
|
94 |
-
- and
|
95 |
history) that could provide additional insight relevant to the decision.
|
96 |
"""
|
97 |
)
|
@@ -187,7 +189,7 @@ with demo:
|
|
187 |
|
188 |
with gr.Row():
|
189 |
with gr.Column(scale=2):
|
190 |
-
gr.Markdown("###
|
191 |
employed = gr.Radio(["Yes", "No"], label="Is the person employed ?", value="Yes")
|
192 |
years_employed = gr.Dropdown(
|
193 |
choices=YEARS_EMPLOYED_BINS,
|
@@ -197,9 +199,9 @@ with demo:
|
|
197 |
)
|
198 |
|
199 |
with gr.Column():
|
200 |
-
|
201 |
|
202 |
-
|
203 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
204 |
)
|
205 |
|
@@ -220,12 +222,12 @@ with demo:
|
|
220 |
outputs=[encrypted_input_bank],
|
221 |
)
|
222 |
|
223 |
-
# Button to pre-process, generate the key, encrypt and send the
|
224 |
# client side to the server
|
225 |
-
|
226 |
-
|
227 |
inputs=[client_id, years_employed, employed],
|
228 |
-
outputs=[
|
229 |
)
|
230 |
|
231 |
gr.Markdown("# Server side")
|
@@ -248,10 +250,14 @@ with demo:
|
|
248 |
# Button to send the encodings to the server using post method
|
249 |
execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
|
250 |
|
251 |
-
gr.Markdown("# Client
|
252 |
gr.Markdown(
|
253 |
"""
|
254 |
Once the server completed the inference, the encrypted output is returned to the user.
|
|
|
|
|
|
|
|
|
255 |
"""
|
256 |
)
|
257 |
|
|
|
24 |
keygen_send,
|
25 |
pre_process_encrypt_send_user,
|
26 |
pre_process_encrypt_send_bank,
|
27 |
+
pre_process_encrypt_send_cs_agency,
|
28 |
run_fhe,
|
29 |
get_output_and_decrypt,
|
30 |
explain_encrypt_run_decrypt,
|
|
|
61 |
"""
|
62 |
)
|
63 |
|
64 |
+
gr.Markdown("# Client, Bank and Credit Scoring Agency setup")
|
65 |
|
66 |
gr.Markdown("## Step 1: Generate the keys.")
|
67 |
gr.Markdown(
|
68 |
"""
|
69 |
+
- The private key is generated jointly by the entities that collaborate to compute the
|
70 |
+
credit score. It is used to encrypt and decrypt the data and shall never be shared with
|
71 |
+
any other party.
|
72 |
- The evaluation key is a public key that the server needs to process encrypted data. It is
|
73 |
therefore transmitted to the server for further processing as well.
|
74 |
"""
|
|
|
93 |
- a user's personal information in order to evaluate his/her credit card eligibility;
|
94 |
- the user’s bank account history, which provides any type of information on the user's
|
95 |
banking information relevant to the decision (here, we consider duration of account);
|
96 |
+
- and credit scoring agency information, which represents any other information (here, employment
|
97 |
history) that could provide additional insight relevant to the decision.
|
98 |
"""
|
99 |
)
|
|
|
189 |
|
190 |
with gr.Row():
|
191 |
with gr.Column(scale=2):
|
192 |
+
gr.Markdown("### Credit Scoring Agency ")
|
193 |
employed = gr.Radio(["Yes", "No"], label="Is the person employed ?", value="Yes")
|
194 |
years_employed = gr.Dropdown(
|
195 |
choices=YEARS_EMPLOYED_BINS,
|
|
|
199 |
)
|
200 |
|
201 |
with gr.Column():
|
202 |
+
encrypt_button_cs_agency = gr.Button("Encrypt the inputs and send to server.")
|
203 |
|
204 |
+
encrypted_input_cs_agency = gr.Textbox(
|
205 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
206 |
)
|
207 |
|
|
|
222 |
outputs=[encrypted_input_bank],
|
223 |
)
|
224 |
|
225 |
+
# Button to pre-process, generate the key, encrypt and send the credit scoring agency inputs from the
|
226 |
# client side to the server
|
227 |
+
encrypt_button_cs_agency.click(
|
228 |
+
pre_process_encrypt_send_cs_agency,
|
229 |
inputs=[client_id, years_employed, employed],
|
230 |
+
outputs=[encrypted_input_cs_agency],
|
231 |
)
|
232 |
|
233 |
gr.Markdown("# Server side")
|
|
|
250 |
# Button to send the encodings to the server using post method
|
251 |
execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
|
252 |
|
253 |
+
gr.Markdown("# Client, Bank and Credit Scoring Agency setup")
|
254 |
gr.Markdown(
|
255 |
"""
|
256 |
Once the server completed the inference, the encrypted output is returned to the user.
|
257 |
+
|
258 |
+
The three entities that provide the information to compute the credit score are the only
|
259 |
+
ones that can decrypt the result. They take part in a decryption protocol that allows to
|
260 |
+
only decrypt the full result when all three parties decrypt their share of the result.
|
261 |
"""
|
262 |
)
|
263 |
|
backend.py
CHANGED
@@ -20,11 +20,11 @@ from settings import (
|
|
20 |
INPUT_SLICES,
|
21 |
PRE_PROCESSOR_USER_PATH,
|
22 |
PRE_PROCESSOR_BANK_PATH,
|
23 |
-
|
24 |
CLIENT_TYPES,
|
25 |
USER_COLUMNS,
|
26 |
BANK_COLUMNS,
|
27 |
-
|
28 |
YEARS_EMPLOYED_BINS,
|
29 |
YEARS_EMPLOYED_BIN_NAME_TO_INDEX,
|
30 |
)
|
@@ -35,11 +35,11 @@ from utils.client_server_interface import MultiInputsFHEModelClient
|
|
35 |
with (
|
36 |
PRE_PROCESSOR_USER_PATH.open('rb') as file_user,
|
37 |
PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank,
|
38 |
-
|
39 |
):
|
40 |
PRE_PROCESSOR_USER = pickle.load(file_user)
|
41 |
PRE_PROCESSOR_BANK = pickle.load(file_bank)
|
42 |
-
|
43 |
|
44 |
|
45 |
def shorten_bytes_object(bytes_object, limit=500):
|
@@ -111,7 +111,7 @@ def _get_client_file_path(name, client_id, client_type=None):
|
|
111 |
'encrypted_outputs').
|
112 |
client_id (int): The client ID to consider.
|
113 |
client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
|
114 |
-
'
|
115 |
|
116 |
Returns:
|
117 |
pathlib.Path: The file path.
|
@@ -132,7 +132,7 @@ def _send_to_server(client_id, client_type, file_name):
|
|
132 |
Args:
|
133 |
client_id (int): The client ID to consider.
|
134 |
client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
|
135 |
-
'
|
136 |
file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').
|
137 |
"""
|
138 |
# Get the paths to the encrypted inputs
|
@@ -204,7 +204,7 @@ def _encrypt_send(client_id, inputs, client_type):
|
|
204 |
Args:
|
205 |
client_id (str): The current client ID to consider.
|
206 |
inputs (numpy.ndarray): The inputs to encrypt.
|
207 |
-
client_type (str): The type of client to consider (either 'user', 'bank' or '
|
208 |
|
209 |
Returns:
|
210 |
encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
|
@@ -303,8 +303,8 @@ def pre_process_encrypt_send_bank(client_id, *inputs):
|
|
303 |
return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
|
304 |
|
305 |
|
306 |
-
def
|
307 |
-
"""Pre-process, encrypt and send the
|
308 |
|
309 |
Args:
|
310 |
client_id (str): The current client ID to consider.
|
@@ -318,15 +318,15 @@ def pre_process_encrypt_send_third_party(client_id, *inputs):
|
|
318 |
years_employed = YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed_bin]
|
319 |
is_employed = employed == "Yes"
|
320 |
|
321 |
-
|
322 |
"Years_employed": [years_employed],
|
323 |
"Employed": [is_employed],
|
324 |
})
|
325 |
|
326 |
-
|
327 |
-
|
328 |
|
329 |
-
return _encrypt_send(client_id,
|
330 |
|
331 |
|
332 |
def run_fhe(client_id):
|
@@ -423,7 +423,7 @@ def explain_encrypt_run_decrypt(client_id, prediction_output, *inputs):
|
|
423 |
"Explaining the prediction can only be done if the credit card is likely to be denied."
|
424 |
)
|
425 |
|
426 |
-
# Retrieve the
|
427 |
years_employed, employed = inputs
|
428 |
|
429 |
# Years_employed is divided into several ordered bins. Here, we retrieve the index representing
|
@@ -441,7 +441,7 @@ def explain_encrypt_run_decrypt(client_id, prediction_output, *inputs):
|
|
441 |
for years_employed_bin in YEARS_EMPLOYED_BINS[bin_index+1:]:
|
442 |
|
443 |
# Send the new encrypted input
|
444 |
-
|
445 |
|
446 |
# Run the model in FHE
|
447 |
run_fhe(client_id)
|
@@ -452,9 +452,9 @@ def explain_encrypt_run_decrypt(client_id, prediction_output, *inputs):
|
|
452 |
is_approved = "approved" in output_prediction[0]
|
453 |
output_predictions.append(is_approved)
|
454 |
|
455 |
-
# Re-send the initial
|
456 |
# some inputs basically re-writes the associated file on the server side)
|
457 |
-
|
458 |
|
459 |
# In case the model predicted at least one approval
|
460 |
if any(output_predictions):
|
|
|
20 |
INPUT_SLICES,
|
21 |
PRE_PROCESSOR_USER_PATH,
|
22 |
PRE_PROCESSOR_BANK_PATH,
|
23 |
+
PRE_PROCESSOR_CS_AGENCY_PATH,
|
24 |
CLIENT_TYPES,
|
25 |
USER_COLUMNS,
|
26 |
BANK_COLUMNS,
|
27 |
+
CS_AGENCY_COLUMNS,
|
28 |
YEARS_EMPLOYED_BINS,
|
29 |
YEARS_EMPLOYED_BIN_NAME_TO_INDEX,
|
30 |
)
|
|
|
35 |
with (
|
36 |
PRE_PROCESSOR_USER_PATH.open('rb') as file_user,
|
37 |
PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank,
|
38 |
+
PRE_PROCESSOR_CS_AGENCY_PATH.open('rb') as file_cs_agency,
|
39 |
):
|
40 |
PRE_PROCESSOR_USER = pickle.load(file_user)
|
41 |
PRE_PROCESSOR_BANK = pickle.load(file_bank)
|
42 |
+
PRE_PROCESSOR_CS_AGENCY = pickle.load(file_cs_agency)
|
43 |
|
44 |
|
45 |
def shorten_bytes_object(bytes_object, limit=500):
|
|
|
111 |
'encrypted_outputs').
|
112 |
client_id (int): The client ID to consider.
|
113 |
client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
|
114 |
+
'cs_agency' or None). Default to None, which is used for evaluation key and output.
|
115 |
|
116 |
Returns:
|
117 |
pathlib.Path: The file path.
|
|
|
132 |
Args:
|
133 |
client_id (int): The client ID to consider.
|
134 |
client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
|
135 |
+
'cs_agency' or None).
|
136 |
file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').
|
137 |
"""
|
138 |
# Get the paths to the encrypted inputs
|
|
|
204 |
Args:
|
205 |
client_id (str): The current client ID to consider.
|
206 |
inputs (numpy.ndarray): The inputs to encrypt.
|
207 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'cs_agency').
|
208 |
|
209 |
Returns:
|
210 |
encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
|
|
|
303 |
return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
|
304 |
|
305 |
|
306 |
+
def pre_process_encrypt_send_cs_agency(client_id, *inputs):
|
307 |
+
"""Pre-process, encrypt and send the credit scoring agency inputs for a specific client to the server.
|
308 |
|
309 |
Args:
|
310 |
client_id (str): The current client ID to consider.
|
|
|
318 |
years_employed = YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed_bin]
|
319 |
is_employed = employed == "Yes"
|
320 |
|
321 |
+
cs_agency_inputs = pandas.DataFrame({
|
322 |
"Years_employed": [years_employed],
|
323 |
"Employed": [is_employed],
|
324 |
})
|
325 |
|
326 |
+
cs_agency_inputs = cs_agency_inputs.reindex(CS_AGENCY_COLUMNS, axis=1)
|
327 |
+
preprocessed_cs_agency_inputs = PRE_PROCESSOR_CS_AGENCY.transform(cs_agency_inputs)
|
328 |
|
329 |
+
return _encrypt_send(client_id, preprocessed_cs_agency_inputs, "cs_agency")
|
330 |
|
331 |
|
332 |
def run_fhe(client_id):
|
|
|
423 |
"Explaining the prediction can only be done if the credit card is likely to be denied."
|
424 |
)
|
425 |
|
426 |
+
# Retrieve the credit scoring agency inputs
|
427 |
years_employed, employed = inputs
|
428 |
|
429 |
# Years_employed is divided into several ordered bins. Here, we retrieve the index representing
|
|
|
441 |
for years_employed_bin in YEARS_EMPLOYED_BINS[bin_index+1:]:
|
442 |
|
443 |
# Send the new encrypted input
|
444 |
+
pre_process_encrypt_send_cs_agency(client_id, years_employed_bin, employed)
|
445 |
|
446 |
# Run the model in FHE
|
447 |
run_fhe(client_id)
|
|
|
452 |
is_approved = "approved" in output_prediction[0]
|
453 |
output_predictions.append(is_approved)
|
454 |
|
455 |
+
# Re-send the initial credit scoring agency inputs in order to avoid unwanted conflict (as sending
|
456 |
# some inputs basically re-writes the associated file on the server side)
|
457 |
+
pre_process_encrypt_send_cs_agency(client_id, years_employed, employed)
|
458 |
|
459 |
# In case the model predicted at least one approval
|
460 |
if any(output_predictions):
|
deployment_files/model/{pre_processor_third_party.pkl → pre_processor_cs_agency.pkl}
RENAMED
File without changes
|
deployment_files/{pre_processor_third_party.pkl → pre_processor_cs_agency.pkl}
RENAMED
File without changes
|
development.py
CHANGED
@@ -11,10 +11,10 @@ from settings import (
|
|
11 |
INPUT_SLICES,
|
12 |
PRE_PROCESSOR_USER_PATH,
|
13 |
PRE_PROCESSOR_BANK_PATH,
|
14 |
-
|
15 |
USER_COLUMNS,
|
16 |
BANK_COLUMNS,
|
17 |
-
|
18 |
)
|
19 |
from utils.client_server_interface import MultiInputsFHEModelDev
|
20 |
from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor
|
@@ -33,7 +33,7 @@ def get_multi_inputs(data):
|
|
33 |
return (
|
34 |
data[:, INPUT_SLICES["user"]],
|
35 |
data[:, INPUT_SLICES["bank"]],
|
36 |
-
data[:, INPUT_SLICES["
|
37 |
)
|
38 |
|
39 |
|
@@ -49,16 +49,16 @@ data_y = data_x.pop("Target").copy().to_frame()
|
|
49 |
# Get data from all parties
|
50 |
data_user = data_x[USER_COLUMNS].copy()
|
51 |
data_bank = data_x[BANK_COLUMNS].copy()
|
52 |
-
|
53 |
|
54 |
# Feature engineer the data
|
55 |
-
pre_processor_user, pre_processor_bank,
|
56 |
|
57 |
preprocessed_data_user = pre_processor_user.fit_transform(data_user)
|
58 |
preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank)
|
59 |
-
|
60 |
|
61 |
-
preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank,
|
62 |
|
63 |
|
64 |
print("\nTrain and compile the model")
|
@@ -85,10 +85,10 @@ fhe_model_dev.save(via_mlir=True)
|
|
85 |
with (
|
86 |
PRE_PROCESSOR_USER_PATH.open('wb') as file_user,
|
87 |
PRE_PROCESSOR_BANK_PATH.open('wb') as file_bank,
|
88 |
-
|
89 |
):
|
90 |
pickle.dump(pre_processor_user, file_user)
|
91 |
pickle.dump(pre_processor_bank, file_bank)
|
92 |
-
pickle.dump(
|
93 |
|
94 |
print("\nDone !")
|
|
|
11 |
INPUT_SLICES,
|
12 |
PRE_PROCESSOR_USER_PATH,
|
13 |
PRE_PROCESSOR_BANK_PATH,
|
14 |
+
PRE_PROCESSOR_CS_AGENCY_PATH,
|
15 |
USER_COLUMNS,
|
16 |
BANK_COLUMNS,
|
17 |
+
CS_AGENCY_COLUMNS,
|
18 |
)
|
19 |
from utils.client_server_interface import MultiInputsFHEModelDev
|
20 |
from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor
|
|
|
33 |
return (
|
34 |
data[:, INPUT_SLICES["user"]],
|
35 |
data[:, INPUT_SLICES["bank"]],
|
36 |
+
data[:, INPUT_SLICES["cs_agency"]]
|
37 |
)
|
38 |
|
39 |
|
|
|
49 |
# Get data from all parties
|
50 |
data_user = data_x[USER_COLUMNS].copy()
|
51 |
data_bank = data_x[BANK_COLUMNS].copy()
|
52 |
+
data_cs_agency = data_x[CS_AGENCY_COLUMNS].copy()
|
53 |
|
54 |
# Feature engineer the data
|
55 |
+
pre_processor_user, pre_processor_bank, pre_processor_cs_agency = get_pre_processors()
|
56 |
|
57 |
preprocessed_data_user = pre_processor_user.fit_transform(data_user)
|
58 |
preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank)
|
59 |
+
preprocessed_data_cs_agency = pre_processor_cs_agency.fit_transform(data_cs_agency)
|
60 |
|
61 |
+
preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_cs_agency), axis=1)
|
62 |
|
63 |
|
64 |
print("\nTrain and compile the model")
|
|
|
85 |
with (
|
86 |
PRE_PROCESSOR_USER_PATH.open('wb') as file_user,
|
87 |
PRE_PROCESSOR_BANK_PATH.open('wb') as file_bank,
|
88 |
+
PRE_PROCESSOR_CS_AGENCY_PATH.open('wb') as file_cs_agency,
|
89 |
):
|
90 |
pickle.dump(pre_processor_user, file_user)
|
91 |
pickle.dump(pre_processor_bank, file_bank)
|
92 |
+
pickle.dump(pre_processor_cs_agency, file_cs_agency)
|
93 |
|
94 |
print("\nDone !")
|
server.py
CHANGED
@@ -20,7 +20,7 @@ def _get_server_file_path(name, client_id, client_type=None):
|
|
20 |
'encrypted_outputs').
|
21 |
client_id (int): The client ID to consider.
|
22 |
client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
|
23 |
-
'
|
24 |
|
25 |
Returns:
|
26 |
pathlib.Path: The file path.
|
|
|
20 |
'encrypted_outputs').
|
21 |
client_id (int): The client ID to consider.
|
22 |
client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
|
23 |
+
'cs_agency' or None). Default to None, which is used for evaluation key and output.
|
24 |
|
25 |
Returns:
|
26 |
pathlib.Path: The file path.
|
settings.py
CHANGED
@@ -18,7 +18,7 @@ DEPLOYMENT_PATH = DEPLOYMENT_PATH / "model"
|
|
18 |
# Path targeting pre-processor saved files
|
19 |
PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
|
20 |
PRE_PROCESSOR_BANK_PATH = DEPLOYMENT_PATH / 'pre_processor_bank.pkl'
|
21 |
-
|
22 |
|
23 |
# Create the necessary directories
|
24 |
FHE_KEYS.mkdir(exist_ok=True)
|
@@ -34,16 +34,16 @@ DATA_PATH = "data/data.csv"
|
|
34 |
# Development settings
|
35 |
PROCESSED_INPUT_SHAPE = (1, 39)
|
36 |
|
37 |
-
CLIENT_TYPES = ["user", "bank", "
|
38 |
INPUT_INDEXES = {
|
39 |
"user": 0,
|
40 |
"bank": 1,
|
41 |
-
"
|
42 |
}
|
43 |
INPUT_SLICES = {
|
44 |
"user": slice(0, 36), # First position: start from 0
|
45 |
"bank": slice(36, 37), # Second position: start from n_feature_user
|
46 |
-
"
|
47 |
}
|
48 |
|
49 |
# Fix column order for pre-processing steps
|
@@ -53,7 +53,7 @@ USER_COLUMNS = [
|
|
53 |
'Occupation_type',
|
54 |
]
|
55 |
BANK_COLUMNS = ["Account_age"]
|
56 |
-
|
57 |
|
58 |
_data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
59 |
|
|
|
18 |
# Path targeting pre-processor saved files
|
19 |
PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
|
20 |
PRE_PROCESSOR_BANK_PATH = DEPLOYMENT_PATH / 'pre_processor_bank.pkl'
|
21 |
+
PRE_PROCESSOR_CS_AGENCY_PATH = DEPLOYMENT_PATH / 'pre_processor_cs_agency.pkl'
|
22 |
|
23 |
# Create the necessary directories
|
24 |
FHE_KEYS.mkdir(exist_ok=True)
|
|
|
34 |
# Development settings
|
35 |
PROCESSED_INPUT_SHAPE = (1, 39)
|
36 |
|
37 |
+
CLIENT_TYPES = ["user", "bank", "cs_agency"]
|
38 |
INPUT_INDEXES = {
|
39 |
"user": 0,
|
40 |
"bank": 1,
|
41 |
+
"cs_agency": 2,
|
42 |
}
|
43 |
INPUT_SLICES = {
|
44 |
"user": slice(0, 36), # First position: start from 0
|
45 |
"bank": slice(36, 37), # Second position: start from n_feature_user
|
46 |
+
"cs_agency": slice(37, 39), # Third position: start from n_feature_user + n_feature_bank
|
47 |
}
|
48 |
|
49 |
# Fix column order for pre-processing steps
|
|
|
53 |
'Occupation_type',
|
54 |
]
|
55 |
BANK_COLUMNS = ["Account_age"]
|
56 |
+
CS_AGENCY_COLUMNS = ["Years_employed", "Employed"]
|
57 |
|
58 |
_data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
59 |
|
utils/client_server_interface.py
CHANGED
@@ -47,7 +47,7 @@ class MultiInputsFHEModelClient(FHEModelClient):
|
|
47 |
x (numpy.ndarray): The input to consider. Here, the input should only represent a
|
48 |
single party.
|
49 |
input_index (int): The index representing the type of model (0: "user", 1: "bank",
|
50 |
-
2: "
|
51 |
processed_input_shape (Tuple[int]): The total input shape (all parties combined) after
|
52 |
pre-processing.
|
53 |
input_slice (slice): The slices to consider for the given party.
|
|
|
47 |
x (numpy.ndarray): The input to consider. Here, the input should only represent a
|
48 |
single party.
|
49 |
input_index (int): The index representing the type of model (0: "user", 1: "bank",
|
50 |
+
2: "cs_agency")
|
51 |
processed_input_shape (Tuple[int]): The total input shape (all parties combined) after
|
52 |
pre-processing.
|
53 |
input_slice (slice): The slices to consider for the given party.
|
utils/pre_processing.py
CHANGED
@@ -55,10 +55,10 @@ def get_pre_processors():
|
|
55 |
verbose_feature_names_out=False,
|
56 |
)
|
57 |
|
58 |
-
|
59 |
transformers=[],
|
60 |
remainder='passthrough',
|
61 |
verbose_feature_names_out=False,
|
62 |
)
|
63 |
|
64 |
-
return pre_processor_user, pre_processor_bank,
|
|
|
55 |
verbose_feature_names_out=False,
|
56 |
)
|
57 |
|
58 |
+
pre_processor_cs_agency = ColumnTransformer(
|
59 |
transformers=[],
|
60 |
remainder='passthrough',
|
61 |
verbose_feature_names_out=False,
|
62 |
)
|
63 |
|
64 |
+
return pre_processor_user, pre_processor_bank, pre_processor_cs_agency
|