romanbredehoft-zama
commited on
Commit
β’
74c0c8e
1
Parent(s):
b0303a0
Add second model for optional explainability step
Browse files- app.py +69 -40
- backend.py +171 -36
- deployment_files/{client.zip β approval_model/client.zip} +2 -2
- deployment_files/{server.zip β approval_model/server.zip} +2 -2
- deployment_files/{versions.json β approval_model/versions.json} +0 -0
- deployment_files/explain_model/client.zip +3 -0
- deployment_files/explain_model/server.zip +3 -0
- deployment_files/explain_model/versions.json +1 -0
- development.py +78 -17
- server.py +2 -2
- settings.py +16 -4
- utils/client_server_interface.py +30 -11
- utils/model.py +5 -43
app.py
CHANGED
@@ -26,6 +26,7 @@ from backend import (
|
|
26 |
run_fhe,
|
27 |
get_output,
|
28 |
decrypt_output,
|
|
|
29 |
)
|
30 |
|
31 |
|
@@ -60,6 +61,12 @@ with demo:
|
|
60 |
)
|
61 |
client_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
gr.Markdown("## Step 2: Fill in some information.")
|
64 |
gr.Markdown(
|
65 |
"""
|
@@ -125,6 +132,31 @@ with demo:
|
|
125 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
126 |
)
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
gr.Markdown("# Server side")
|
129 |
gr.Markdown(
|
130 |
"""
|
@@ -142,6 +174,9 @@ with demo:
|
|
142 |
label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
|
143 |
)
|
144 |
|
|
|
|
|
|
|
145 |
gr.Markdown("# Client side")
|
146 |
gr.Markdown(
|
147 |
"""
|
@@ -161,6 +196,13 @@ with demo:
|
|
161 |
label="Encrypted output representation: ", max_lines=2, interactive=False
|
162 |
)
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
gr.Markdown("## Step 6: Decrypt the output.")
|
165 |
gr.Markdown(
|
166 |
"""
|
@@ -173,52 +215,39 @@ with demo:
|
|
173 |
label="Prediction", max_lines=1, interactive=False
|
174 |
)
|
175 |
|
176 |
-
# Button
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
# Button to pre-process, generate the key, encrypt and send the user inputs from the client
|
183 |
-
# side to the server
|
184 |
-
encrypt_button_user.click(
|
185 |
-
pre_process_encrypt_send_user,
|
186 |
-
inputs=[client_id, bool_inputs, num_children, household_size, total_income, age, \
|
187 |
-
income_type, education_type, family_status, occupation_type, housing_type],
|
188 |
-
outputs=[encrypted_input_user],
|
189 |
)
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
197 |
)
|
198 |
-
|
199 |
-
|
200 |
-
# client side to the server
|
201 |
-
encrypt_button_third_party.click(
|
202 |
-
pre_process_encrypt_send_third_party,
|
203 |
-
inputs=[client_id, employed, years_employed],
|
204 |
-
outputs=[encrypted_input_third_party],
|
205 |
)
|
206 |
-
|
207 |
-
|
208 |
-
execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
|
209 |
-
|
210 |
-
# Button to send the encodings to the server using post method
|
211 |
-
get_output_button.click(
|
212 |
-
get_output,
|
213 |
-
inputs=[client_id],
|
214 |
-
outputs=[encrypted_output_representation],
|
215 |
)
|
216 |
|
217 |
-
# Button to
|
218 |
-
|
219 |
-
|
220 |
-
inputs=[client_id
|
221 |
-
|
|
|
|
|
222 |
)
|
223 |
|
224 |
gr.Markdown(
|
|
|
26 |
run_fhe,
|
27 |
get_output,
|
28 |
decrypt_output,
|
29 |
+
years_employed_encrypt_run_decrypt,
|
30 |
)
|
31 |
|
32 |
|
|
|
61 |
)
|
62 |
client_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
63 |
|
64 |
+
# Button generate the keys
|
65 |
+
keygen_button.click(
|
66 |
+
keygen_send,
|
67 |
+
outputs=[client_id, evaluation_key, keygen_button],
|
68 |
+
)
|
69 |
+
|
70 |
gr.Markdown("## Step 2: Fill in some information.")
|
71 |
gr.Markdown(
|
72 |
"""
|
|
|
132 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
133 |
)
|
134 |
|
135 |
+
# Button to pre-process, generate the key, encrypt and send the user inputs from the client
|
136 |
+
# side to the server
|
137 |
+
encrypt_button_user.click(
|
138 |
+
pre_process_encrypt_send_user,
|
139 |
+
inputs=[client_id, bool_inputs, num_children, household_size, total_income, age, \
|
140 |
+
income_type, education_type, family_status, occupation_type, housing_type],
|
141 |
+
outputs=[encrypted_input_user],
|
142 |
+
)
|
143 |
+
|
144 |
+
# Button to pre-process, generate the key, encrypt and send the bank inputs from the client
|
145 |
+
# side to the server
|
146 |
+
encrypt_button_bank.click(
|
147 |
+
pre_process_encrypt_send_bank,
|
148 |
+
inputs=[client_id, account_age],
|
149 |
+
outputs=[encrypted_input_bank],
|
150 |
+
)
|
151 |
+
|
152 |
+
# Button to pre-process, generate the key, encrypt and send the third party inputs from the
|
153 |
+
# client side to the server
|
154 |
+
encrypt_button_third_party.click(
|
155 |
+
pre_process_encrypt_send_third_party,
|
156 |
+
inputs=[client_id, employed, years_employed],
|
157 |
+
outputs=[encrypted_input_third_party],
|
158 |
+
)
|
159 |
+
|
160 |
gr.Markdown("# Server side")
|
161 |
gr.Markdown(
|
162 |
"""
|
|
|
174 |
label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
|
175 |
)
|
176 |
|
177 |
+
# Button to send the encodings to the server using post method
|
178 |
+
execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
|
179 |
+
|
180 |
gr.Markdown("# Client side")
|
181 |
gr.Markdown(
|
182 |
"""
|
|
|
196 |
label="Encrypted output representation: ", max_lines=2, interactive=False
|
197 |
)
|
198 |
|
199 |
+
# Button to send the encodings to the server using post method
|
200 |
+
get_output_button.click(
|
201 |
+
get_output,
|
202 |
+
inputs=[client_id],
|
203 |
+
outputs=[encrypted_output_representation],
|
204 |
+
)
|
205 |
+
|
206 |
gr.Markdown("## Step 6: Decrypt the output.")
|
207 |
gr.Markdown(
|
208 |
"""
|
|
|
215 |
label="Prediction", max_lines=1, interactive=False
|
216 |
)
|
217 |
|
218 |
+
# Button to decrypt the output
|
219 |
+
decrypt_button.click(
|
220 |
+
decrypt_output,
|
221 |
+
inputs=[client_id],
|
222 |
+
outputs=[prediction_output],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
)
|
224 |
|
225 |
+
gr.Markdown("## Step 7 (optional): Explain the prediction.")
|
226 |
+
gr.Markdown(
|
227 |
+
"""
|
228 |
+
In case the credit card is likely to be denied, the user can run a second model in order to
|
229 |
+
Explain the prediction better. More specifically, this new model indicates the number of
|
230 |
+
additional years of employment that could be required in order to increase the chance of
|
231 |
+
credit card approval.
|
232 |
+
All of the above steps are combined into a single button for simplicity. The following
|
233 |
+
button therefore encrypts the same inputs (except the years of employment) from all three
|
234 |
+
parties, runs the new prediction in FHE and decrypts the output.
|
235 |
+
"""
|
236 |
)
|
237 |
+
years_employed_prediction_button = gr.Button(
|
238 |
+
"Encrypt the inputs, compute in FHE and decrypt the output."
|
|
|
|
|
|
|
|
|
|
|
239 |
)
|
240 |
+
years_employed_prediction = gr.Textbox(
|
241 |
+
label="Additional years of employed required.", max_lines=1, interactive=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
)
|
243 |
|
244 |
+
# Button to explain the prediction
|
245 |
+
years_employed_prediction_button.click(
|
246 |
+
years_employed_encrypt_run_decrypt,
|
247 |
+
inputs=[client_id, prediction_output, bool_inputs, num_children, household_size, \
|
248 |
+
total_income, age, income_type, education_type, family_status, occupation_type, \
|
249 |
+
housing_type, account_age, employed, years_employed],
|
250 |
+
outputs=[years_employed_prediction],
|
251 |
)
|
252 |
|
253 |
gr.Markdown(
|
backend.py
CHANGED
@@ -14,20 +14,26 @@ from settings import (
|
|
14 |
FHE_KEYS,
|
15 |
CLIENT_FILES,
|
16 |
SERVER_FILES,
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
INPUT_INDEXES,
|
20 |
-
|
|
|
21 |
PRE_PROCESSOR_USER_PATH,
|
22 |
PRE_PROCESSOR_BANK_PATH,
|
23 |
PRE_PROCESSOR_THIRD_PARTY_PATH,
|
24 |
CLIENT_TYPES,
|
25 |
USER_COLUMNS,
|
26 |
BANK_COLUMNS,
|
27 |
-
|
28 |
)
|
29 |
|
30 |
-
from utils.client_server_interface import MultiInputsFHEModelClient
|
|
|
|
|
|
|
31 |
|
32 |
# Load pre-processor instances
|
33 |
with (
|
@@ -87,18 +93,22 @@ def clean_temporary_files(n_keys=20):
|
|
87 |
shutil.rmtree(directory)
|
88 |
|
89 |
|
90 |
-
def _get_client(client_id):
|
91 |
"""Get the client instance.
|
92 |
|
93 |
Args:
|
94 |
client_id (int): The client ID to consider.
|
|
|
|
|
95 |
|
96 |
Returns:
|
97 |
FHEModelClient: The client instance.
|
98 |
"""
|
99 |
-
|
|
|
|
|
100 |
|
101 |
-
return MultiInputsFHEModelClient(
|
102 |
|
103 |
|
104 |
def _get_client_file_path(name, client_id, client_type=None):
|
@@ -196,7 +206,7 @@ def keygen_send():
|
|
196 |
return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent β
")
|
197 |
|
198 |
|
199 |
-
def _encrypt_send(client_id, inputs, client_type):
|
200 |
"""Encrypt the given inputs for a specific client and send it to the server.
|
201 |
|
202 |
Args:
|
@@ -205,8 +215,7 @@ def _encrypt_send(client_id, inputs, client_type):
|
|
205 |
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
206 |
|
207 |
Returns:
|
208 |
-
|
209 |
-
and a byte short representation of the encrypted input to send.
|
210 |
"""
|
211 |
if client_id == "":
|
212 |
raise gr.Error("Please generate the keys first.")
|
@@ -218,8 +227,8 @@ def _encrypt_send(client_id, inputs, client_type):
|
|
218 |
encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
|
219 |
inputs,
|
220 |
input_index=INPUT_INDEXES[client_type],
|
221 |
-
processed_input_shape=
|
222 |
-
input_slice=
|
223 |
)
|
224 |
|
225 |
file_name = "encrypted_inputs"
|
@@ -239,16 +248,14 @@ def _encrypt_send(client_id, inputs, client_type):
|
|
239 |
return encrypted_inputs_short
|
240 |
|
241 |
|
242 |
-
def
|
243 |
-
"""Pre-process
|
244 |
|
245 |
Args:
|
246 |
-
client_id (str): The current client ID to consider.
|
247 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
248 |
|
249 |
Returns:
|
250 |
-
(
|
251 |
-
the encrypted input to send.
|
252 |
"""
|
253 |
bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
|
254 |
family_status, occupation_type, housing_type = inputs
|
@@ -277,19 +284,32 @@ def pre_process_encrypt_send_user(client_id, *inputs):
|
|
277 |
|
278 |
preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
|
279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
return _encrypt_send(client_id, preprocessed_user_inputs, "user")
|
281 |
|
282 |
|
283 |
-
def
|
284 |
-
"""Pre-process
|
285 |
|
286 |
Args:
|
287 |
-
client_id (str): The current client ID to consider.
|
288 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
289 |
|
290 |
Returns:
|
291 |
-
(
|
292 |
-
the encrypted input to send.
|
293 |
"""
|
294 |
account_age = inputs[0]
|
295 |
|
@@ -301,32 +321,65 @@ def pre_process_encrypt_send_bank(client_id, *inputs):
|
|
301 |
|
302 |
preprocessed_bank_inputs = PRE_PROCESSOR_BANK.transform(bank_inputs)
|
303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
|
305 |
|
306 |
|
307 |
-
def
|
308 |
-
"""Pre-process
|
309 |
|
310 |
Args:
|
311 |
-
client_id (str): The current client ID to consider.
|
312 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
313 |
|
314 |
Returns:
|
315 |
-
(
|
316 |
-
the encrypted input to send.
|
317 |
"""
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
is_employed = employed == "Yes"
|
|
|
321 |
|
322 |
-
third_party_inputs = pandas.DataFrame(
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
-
|
|
|
|
|
328 |
|
329 |
-
|
|
|
|
|
|
|
330 |
|
331 |
return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
|
332 |
|
@@ -430,4 +483,86 @@ def decrypt_output(client_id):
|
|
430 |
# Determine the predicted class
|
431 |
output = numpy.argmax(output_proba, axis=1).squeeze()
|
432 |
|
433 |
-
return "Credit card is likely to be approved β
" if output == 1 else "Credit card is likely to be denied β"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
FHE_KEYS,
|
15 |
CLIENT_FILES,
|
16 |
SERVER_FILES,
|
17 |
+
APPROVAL_DEPLOYMENT_PATH,
|
18 |
+
EXPLAIN_DEPLOYMENT_PATH,
|
19 |
+
APPROVAL_PROCESSED_INPUT_SHAPE,
|
20 |
+
EXPLAIN_PROCESSED_INPUT_SHAPE,
|
21 |
INPUT_INDEXES,
|
22 |
+
APPROVAL_INPUT_SLICES,
|
23 |
+
EXPLAIN_INPUT_SLICES,
|
24 |
PRE_PROCESSOR_USER_PATH,
|
25 |
PRE_PROCESSOR_BANK_PATH,
|
26 |
PRE_PROCESSOR_THIRD_PARTY_PATH,
|
27 |
CLIENT_TYPES,
|
28 |
USER_COLUMNS,
|
29 |
BANK_COLUMNS,
|
30 |
+
APPROVAL_THIRD_PARTY_COLUMNS,
|
31 |
)
|
32 |
|
33 |
+
from utils.client_server_interface import MultiInputsFHEModelClient, MultiInputsFHEModelServer
|
34 |
+
|
35 |
+
# Load the server used for explaining the prediction
|
36 |
+
EXPLAIN_FHE_SERVER = MultiInputsFHEModelServer(EXPLAIN_DEPLOYMENT_PATH)
|
37 |
|
38 |
# Load pre-processor instances
|
39 |
with (
|
|
|
93 |
shutil.rmtree(directory)
|
94 |
|
95 |
|
96 |
+
def _get_client(client_id, is_approval=True):
|
97 |
"""Get the client instance.
|
98 |
|
99 |
Args:
|
100 |
client_id (int): The client ID to consider.
|
101 |
+
is_approval (bool): If client is representing the 'approval' model (else, it is
|
102 |
+
representing the 'explain' model). Default to True.
|
103 |
|
104 |
Returns:
|
105 |
FHEModelClient: The client instance.
|
106 |
"""
|
107 |
+
key_suffix = "approval" if is_approval else "explain"
|
108 |
+
key_dir = FHE_KEYS / f"{client_id}_{key_suffix}"
|
109 |
+
client_dir = APPROVAL_DEPLOYMENT_PATH if is_approval else EXPLAIN_DEPLOYMENT_PATH
|
110 |
|
111 |
+
return MultiInputsFHEModelClient(client_dir, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
|
112 |
|
113 |
|
114 |
def _get_client_file_path(name, client_id, client_type=None):
|
|
|
206 |
return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent β
")
|
207 |
|
208 |
|
209 |
+
def _encrypt_send(client_id, inputs, client_type, app_mode=True):
|
210 |
"""Encrypt the given inputs for a specific client and send it to the server.
|
211 |
|
212 |
Args:
|
|
|
215 |
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
216 |
|
217 |
Returns:
|
218 |
+
encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
|
|
|
219 |
"""
|
220 |
if client_id == "":
|
221 |
raise gr.Error("Please generate the keys first.")
|
|
|
227 |
encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
|
228 |
inputs,
|
229 |
input_index=INPUT_INDEXES[client_type],
|
230 |
+
processed_input_shape=APPROVAL_PROCESSED_INPUT_SHAPE,
|
231 |
+
input_slice=APPROVAL_INPUT_SLICES[client_type],
|
232 |
)
|
233 |
|
234 |
file_name = "encrypted_inputs"
|
|
|
248 |
return encrypted_inputs_short
|
249 |
|
250 |
|
251 |
+
def _pre_process_user(*inputs):
|
252 |
+
"""Pre-process the user inputs.
|
253 |
|
254 |
Args:
|
|
|
255 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
256 |
|
257 |
Returns:
|
258 |
+
(numpy.ndarray): The pre-processed inputs.
|
|
|
259 |
"""
|
260 |
bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
|
261 |
family_status, occupation_type, housing_type = inputs
|
|
|
284 |
|
285 |
preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
|
286 |
|
287 |
+
return preprocessed_user_inputs
|
288 |
+
|
289 |
+
|
290 |
+
def pre_process_encrypt_send_user(client_id, *inputs):
|
291 |
+
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
client_id (str): The current client ID to consider.
|
295 |
+
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
(str): A short representation of the encrypted input to send in hex.
|
299 |
+
"""
|
300 |
+
preprocessed_user_inputs = _pre_process_user(*inputs)
|
301 |
+
|
302 |
return _encrypt_send(client_id, preprocessed_user_inputs, "user")
|
303 |
|
304 |
|
305 |
+
def _pre_process_bank(*inputs):
|
306 |
+
"""Pre-process the bank inputs.
|
307 |
|
308 |
Args:
|
|
|
309 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
310 |
|
311 |
Returns:
|
312 |
+
(numpy.ndarray): The pre-processed inputs.
|
|
|
313 |
"""
|
314 |
account_age = inputs[0]
|
315 |
|
|
|
321 |
|
322 |
preprocessed_bank_inputs = PRE_PROCESSOR_BANK.transform(bank_inputs)
|
323 |
|
324 |
+
return preprocessed_bank_inputs
|
325 |
+
|
326 |
+
|
327 |
+
def pre_process_encrypt_send_bank(client_id, *inputs):
|
328 |
+
"""Pre-process, encrypt and send the bank inputs for a specific client to the server.
|
329 |
+
|
330 |
+
Args:
|
331 |
+
client_id (str): The current client ID to consider.
|
332 |
+
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
(str): A short representation of the encrypted input to send in hex.
|
336 |
+
"""
|
337 |
+
preprocessed_bank_inputs = _pre_process_bank(*inputs)
|
338 |
+
|
339 |
return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
|
340 |
|
341 |
|
342 |
+
def _pre_process_third_party(*inputs):
|
343 |
+
"""Pre-process the third party inputs.
|
344 |
|
345 |
Args:
|
|
|
346 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
347 |
|
348 |
Returns:
|
349 |
+
(numpy.ndarray): The pre-processed inputs.
|
|
|
350 |
"""
|
351 |
+
third_party_data = {}
|
352 |
+
if len(inputs) == 1:
|
353 |
+
employed = inputs[0]
|
354 |
+
else:
|
355 |
+
employed, years_employed = inputs
|
356 |
+
third_party_data["Years_employed"] = [years_employed]
|
357 |
|
358 |
is_employed = employed == "Yes"
|
359 |
+
third_party_data["Employed"] = [is_employed]
|
360 |
|
361 |
+
third_party_inputs = pandas.DataFrame(third_party_data)
|
362 |
+
|
363 |
+
if len(inputs) == 1:
|
364 |
+
preprocessed_third_party_inputs = third_party_inputs.to_numpy()
|
365 |
+
else:
|
366 |
+
third_party_inputs = third_party_inputs.reindex(APPROVAL_THIRD_PARTY_COLUMNS, axis=1)
|
367 |
+
preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
|
368 |
+
|
369 |
+
return preprocessed_third_party_inputs
|
370 |
+
|
371 |
+
|
372 |
+
def pre_process_encrypt_send_third_party(client_id, *inputs):
|
373 |
+
"""Pre-process, encrypt and send the third party inputs for a specific client to the server.
|
374 |
|
375 |
+
Args:
|
376 |
+
client_id (str): The current client ID to consider.
|
377 |
+
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
378 |
|
379 |
+
Returns:
|
380 |
+
(str): A short representation of the encrypted input to send in hex.
|
381 |
+
"""
|
382 |
+
preprocessed_third_party_inputs = _pre_process_third_party(*inputs)
|
383 |
|
384 |
return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
|
385 |
|
|
|
483 |
# Determine the predicted class
|
484 |
output = numpy.argmax(output_proba, axis=1).squeeze()
|
485 |
|
486 |
+
return "Credit card is likely to be approved β
" if output == 1 else "Credit card is likely to be denied β"
|
487 |
+
|
488 |
+
|
489 |
+
def years_employed_encrypt_run_decrypt(client_id, prediction_output, *inputs):
|
490 |
+
"""Pre-process and encrypt the inputs, run the prediction in FHE and decrypt the output.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
client_id (str): The current client ID to consider.
|
494 |
+
prediction_output (str): The initial prediction output. This parameter is only used to
|
495 |
+
throw an error in case the prediction was positive.
|
496 |
+
*inputs (Tuple[numpy.ndarray]): The inputs to consider.
|
497 |
+
|
498 |
+
Returns:
|
499 |
+
(str): A message indicating the number of additional years of employment that could be
|
500 |
+
required in order to increase the chance of
|
501 |
+
credit card approval.
|
502 |
+
"""
|
503 |
+
|
504 |
+
if "approved" in prediction_output:
|
505 |
+
raise gr.Error(
|
506 |
+
"Explaining the prediction can only be done if the credit card is likely to be denied."
|
507 |
+
)
|
508 |
+
|
509 |
+
# Retrieve the client instance
|
510 |
+
client = _get_client(client_id, is_approval=False)
|
511 |
+
|
512 |
+
# Generate the private and evaluation keys
|
513 |
+
client.generate_private_and_evaluation_keys(force=False)
|
514 |
+
|
515 |
+
# Retrieve the serialized evaluation key
|
516 |
+
evaluation_key = client.get_serialized_evaluation_keys()
|
517 |
+
|
518 |
+
bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
|
519 |
+
family_status, occupation_type, housing_type, account_age, employed, years_employed = inputs
|
520 |
+
|
521 |
+
preprocessed_user_inputs = _pre_process_user(
|
522 |
+
bool_inputs, num_children, household_size, total_income, age, income_type, education_type,
|
523 |
+
family_status, occupation_type, housing_type,
|
524 |
+
)
|
525 |
+
preprocessed_bank_inputs = _pre_process_bank(account_age)
|
526 |
+
preprocessed_third_party_inputs = _pre_process_third_party(employed)
|
527 |
+
|
528 |
+
preprocessed_inputs = [
|
529 |
+
preprocessed_user_inputs,
|
530 |
+
preprocessed_bank_inputs,
|
531 |
+
preprocessed_third_party_inputs
|
532 |
+
]
|
533 |
+
|
534 |
+
# Quantize, encrypt and serialize the inputs
|
535 |
+
encrypted_inputs = []
|
536 |
+
for client_type, preprocessed_input in zip(CLIENT_TYPES, preprocessed_inputs):
|
537 |
+
encrypted_input = client.quantize_encrypt_serialize_multi_inputs(
|
538 |
+
preprocessed_input,
|
539 |
+
input_index=INPUT_INDEXES[client_type],
|
540 |
+
processed_input_shape=EXPLAIN_PROCESSED_INPUT_SHAPE,
|
541 |
+
input_slice=EXPLAIN_INPUT_SLICES[client_type],
|
542 |
+
)
|
543 |
+
encrypted_inputs.append(encrypted_input)
|
544 |
+
|
545 |
+
# Run the FHE computation
|
546 |
+
encrypted_output = EXPLAIN_FHE_SERVER.run(
|
547 |
+
*encrypted_inputs,
|
548 |
+
serialized_evaluation_keys=evaluation_key
|
549 |
+
)
|
550 |
+
|
551 |
+
# Decrypt the output
|
552 |
+
output_prediction = client.deserialize_decrypt_dequantize(encrypted_output)
|
553 |
+
|
554 |
+
# Get the difference with the initial 'years of employment' input
|
555 |
+
years_employed_diff = int(numpy.ceil(output_prediction.squeeze() - years_employed))
|
556 |
+
|
557 |
+
if years_employed_diff > 0:
|
558 |
+
return (
|
559 |
+
f"Having at least {years_employed_diff} more years of employment would increase "
|
560 |
+
"your chance of having your credit card approved."
|
561 |
+
)
|
562 |
+
|
563 |
+
return (
|
564 |
+
"The number of years of employment you provided seems to be enough. The negative prediction "
|
565 |
+
"might come from other inputs."
|
566 |
+
)
|
567 |
+
|
568 |
+
|
deployment_files/{client.zip β approval_model/client.zip}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2ceb4a6e07cd13471c8c8c963d9e4de52d5af624e81775ebeb2421e29b9ba8c
|
3 |
+
size 28667
|
deployment_files/{server.zip β approval_model/server.zip}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e724012427c90fdc8df14360942909e5fa0accc8b27584880baab2a91533e78
|
3 |
+
size 1729
|
deployment_files/{versions.json β approval_model/versions.json}
RENAMED
File without changes
|
deployment_files/explain_model/client.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:506276661b4612d664d59f0d90aac1b5c09f942a850ec189aa16204d54433b27
|
3 |
+
size 27714
|
deployment_files/explain_model/server.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:596ae66c7effd9733a8780088984d4fc08479d67c11586ee5787111329cb353f
|
3 |
+
size 2035
|
deployment_files/explain_model/versions.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"concrete-python": "2.5.0rc1", "concrete-ml": "1.3.0", "python": "3.10.11"}
|
development.py
CHANGED
@@ -6,28 +6,49 @@ import pandas
|
|
6 |
import pickle
|
7 |
|
8 |
from settings import (
|
9 |
-
|
|
|
10 |
DATA_PATH,
|
11 |
-
|
|
|
12 |
PRE_PROCESSOR_USER_PATH,
|
13 |
PRE_PROCESSOR_BANK_PATH,
|
14 |
PRE_PROCESSOR_THIRD_PARTY_PATH,
|
15 |
USER_COLUMNS,
|
16 |
BANK_COLUMNS,
|
17 |
-
|
|
|
18 |
)
|
19 |
from utils.client_server_interface import MultiInputsFHEModelDev
|
20 |
-
from utils.model import MultiInputDecisionTreeClassifier
|
21 |
from utils.pre_processing import get_pre_processors
|
22 |
|
23 |
|
24 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
return (
|
26 |
-
data[:,
|
27 |
-
data[:,
|
28 |
-
data[:,
|
29 |
)
|
30 |
|
|
|
31 |
print("Load and pre-process the data")
|
32 |
|
33 |
# Load the data
|
@@ -40,7 +61,7 @@ data_y = data_x.pop("Target").copy().to_frame()
|
|
40 |
# Get data from all parties
|
41 |
data_user = data_x[USER_COLUMNS].copy()
|
42 |
data_bank = data_x[BANK_COLUMNS].copy()
|
43 |
-
data_third_party = data_x[
|
44 |
|
45 |
# Feature engineer the data
|
46 |
pre_processor_user, pre_processor_bank, pre_processor_third_party = get_pre_processors()
|
@@ -54,23 +75,23 @@ preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_da
|
|
54 |
|
55 |
print("\nTrain and compile the model")
|
56 |
|
57 |
-
|
58 |
|
59 |
-
|
60 |
|
61 |
-
multi_inputs_train =
|
62 |
|
63 |
-
|
64 |
|
65 |
print("\nSave deployment files")
|
66 |
|
67 |
# Delete the deployment folder and its content if it already exists
|
68 |
-
if
|
69 |
-
shutil.rmtree(
|
70 |
|
71 |
# Save files needed for deployment (and enable cross-platform deployment)
|
72 |
-
|
73 |
-
|
74 |
|
75 |
# Save pre-processors
|
76 |
with (
|
@@ -82,4 +103,44 @@ with (
|
|
82 |
pickle.dump(pre_processor_bank, file_bank)
|
83 |
pickle.dump(pre_processor_third_party, file_third_party)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
print("\nDone !")
|
|
|
6 |
import pickle
|
7 |
|
8 |
from settings import (
|
9 |
+
APPROVAL_DEPLOYMENT_PATH,
|
10 |
+
EXPLAIN_DEPLOYMENT_PATH,
|
11 |
DATA_PATH,
|
12 |
+
APPROVAL_INPUT_SLICES,
|
13 |
+
EXPLAIN_INPUT_SLICES,
|
14 |
PRE_PROCESSOR_USER_PATH,
|
15 |
PRE_PROCESSOR_BANK_PATH,
|
16 |
PRE_PROCESSOR_THIRD_PARTY_PATH,
|
17 |
USER_COLUMNS,
|
18 |
BANK_COLUMNS,
|
19 |
+
APPROVAL_THIRD_PARTY_COLUMNS,
|
20 |
+
EXPLAIN_THIRD_PARTY_COLUMNS,
|
21 |
)
|
22 |
from utils.client_server_interface import MultiInputsFHEModelDev
|
23 |
+
from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor
|
24 |
from utils.pre_processing import get_pre_processors
|
25 |
|
26 |
|
27 |
+
def get_multi_inputs(data, is_approval):
|
28 |
+
"""Get inputs for all three parties from the input data, using fixed slices.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
data (numpy.ndarray): The input data to consider.
|
32 |
+
is_approval (bool): If the data should be used for the 'approval' model (else, otherwise for
|
33 |
+
the 'explain' model).
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
(Tuple[numpy.ndarray]): The inputs for all three parties.
|
37 |
+
"""
|
38 |
+
if is_approval:
|
39 |
+
return (
|
40 |
+
data[:, APPROVAL_INPUT_SLICES["user"]],
|
41 |
+
data[:, APPROVAL_INPUT_SLICES["bank"]],
|
42 |
+
data[:, APPROVAL_INPUT_SLICES["third_party"]]
|
43 |
+
)
|
44 |
+
|
45 |
return (
|
46 |
+
data[:, EXPLAIN_INPUT_SLICES["user"]],
|
47 |
+
data[:, EXPLAIN_INPUT_SLICES["bank"]],
|
48 |
+
data[:, EXPLAIN_INPUT_SLICES["third_party"]]
|
49 |
)
|
50 |
|
51 |
+
|
52 |
print("Load and pre-process the data")
|
53 |
|
54 |
# Load the data
|
|
|
61 |
# Get data from all parties
|
62 |
data_user = data_x[USER_COLUMNS].copy()
|
63 |
data_bank = data_x[BANK_COLUMNS].copy()
|
64 |
+
data_third_party = data_x[APPROVAL_THIRD_PARTY_COLUMNS].copy()
|
65 |
|
66 |
# Feature engineer the data
|
67 |
pre_processor_user, pre_processor_bank, pre_processor_third_party = get_pre_processors()
|
|
|
75 |
|
76 |
print("\nTrain and compile the model")
|
77 |
|
78 |
+
model_approval = MultiInputDecisionTreeClassifier()
|
79 |
|
80 |
+
model_approval, sklearn_model_approval = model_approval.fit_benchmark(preprocessed_data_x, data_y)
|
81 |
|
82 |
+
multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=True)
|
83 |
|
84 |
+
model_approval.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
85 |
|
86 |
print("\nSave deployment files")
|
87 |
|
88 |
# Delete the deployment folder and its content if it already exists
|
89 |
+
if APPROVAL_DEPLOYMENT_PATH.is_dir():
|
90 |
+
shutil.rmtree(APPROVAL_DEPLOYMENT_PATH)
|
91 |
|
92 |
# Save files needed for deployment (and enable cross-platform deployment)
|
93 |
+
fhe_model_dev_approval = MultiInputsFHEModelDev(APPROVAL_DEPLOYMENT_PATH, model_approval)
|
94 |
+
fhe_model_dev_approval.save(via_mlir=True)
|
95 |
|
96 |
# Save pre-processors
|
97 |
with (
|
|
|
103 |
pickle.dump(pre_processor_bank, file_bank)
|
104 |
pickle.dump(pre_processor_third_party, file_third_party)
|
105 |
|
106 |
+
|
107 |
+
print("\nLoad, train, compile and save files for the 'explain' model")
|
108 |
+
|
109 |
+
# Define input and target data
|
110 |
+
data_x = data.copy()
|
111 |
+
data_y = data_x.pop("Years_employed").copy().to_frame()
|
112 |
+
target_values = data_x.pop("Target").copy()
|
113 |
+
|
114 |
+
# Get all data points whose target value is True (credit card has been approved)
|
115 |
+
approved_mask = target_values == 1
|
116 |
+
data_x_approved = data_x[approved_mask]
|
117 |
+
data_y_approved = data_y[approved_mask]
|
118 |
+
|
119 |
+
# Get data from all parties
|
120 |
+
data_user = data_x_approved[USER_COLUMNS].copy()
|
121 |
+
data_bank = data_x_approved[BANK_COLUMNS].copy()
|
122 |
+
data_third_party = data_x_approved[EXPLAIN_THIRD_PARTY_COLUMNS].copy()
|
123 |
+
|
124 |
+
preprocessed_data_user = pre_processor_user.transform(data_user)
|
125 |
+
preprocessed_data_bank = pre_processor_bank.transform(data_bank)
|
126 |
+
preprocessed_data_third_party = data_third_party.to_numpy()
|
127 |
+
|
128 |
+
preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1)
|
129 |
+
|
130 |
+
model_explain = MultiInputDecisionTreeRegressor()
|
131 |
+
|
132 |
+
model_explain, sklearn_model_explain = model_explain.fit_benchmark(preprocessed_data_x, data_y_approved)
|
133 |
+
|
134 |
+
multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=False)
|
135 |
+
|
136 |
+
model_explain.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
137 |
+
|
138 |
+
# Delete the deployment folder and its content if it already exists
|
139 |
+
if EXPLAIN_DEPLOYMENT_PATH.is_dir():
|
140 |
+
shutil.rmtree(EXPLAIN_DEPLOYMENT_PATH)
|
141 |
+
|
142 |
+
# Save files needed for deployment (and enable cross-platform deployment)
|
143 |
+
fhe_model_dev_explain = MultiInputsFHEModelDev(EXPLAIN_DEPLOYMENT_PATH, model_explain)
|
144 |
+
fhe_model_dev_explain.save(via_mlir=True)
|
145 |
+
|
146 |
print("\nDone !")
|
server.py
CHANGED
@@ -5,11 +5,11 @@ from typing import List, Optional
|
|
5 |
from fastapi import FastAPI, File, Form, UploadFile
|
6 |
from fastapi.responses import JSONResponse, Response
|
7 |
|
8 |
-
from settings import
|
9 |
from utils.client_server_interface import MultiInputsFHEModelServer
|
10 |
|
11 |
# Load the server
|
12 |
-
FHE_SERVER = MultiInputsFHEModelServer(
|
13 |
|
14 |
|
15 |
def _get_server_file_path(name, client_id, client_type=None):
|
|
|
5 |
from fastapi import FastAPI, File, Form, UploadFile
|
6 |
from fastapi.responses import JSONResponse, Response
|
7 |
|
8 |
+
from settings import APPROVAL_DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
|
9 |
from utils.client_server_interface import MultiInputsFHEModelServer
|
10 |
|
11 |
# Load the server
|
12 |
+
FHE_SERVER = MultiInputsFHEModelServer(APPROVAL_DEPLOYMENT_PATH)
|
13 |
|
14 |
|
15 |
def _get_server_file_path(name, client_id, client_type=None):
|
settings.py
CHANGED
@@ -6,12 +6,16 @@ import pandas
|
|
6 |
# The directory of this project
|
7 |
REPO_DIR = Path(__file__).parent
|
8 |
|
9 |
-
#
|
10 |
DEPLOYMENT_PATH = REPO_DIR / "deployment_files"
|
11 |
FHE_KEYS = REPO_DIR / ".fhe_keys"
|
12 |
CLIENT_FILES = REPO_DIR / "client_files"
|
13 |
SERVER_FILES = REPO_DIR / "server_files"
|
14 |
|
|
|
|
|
|
|
|
|
15 |
# Path targeting pre-processor saved files
|
16 |
PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
|
17 |
PRE_PROCESSOR_BANK_PATH = DEPLOYMENT_PATH / 'pre_processor_bank.pkl'
|
@@ -29,7 +33,8 @@ SERVER_URL = "http://localhost:8000/"
|
|
29 |
DATA_PATH = "data/data.csv"
|
30 |
|
31 |
# Development settings
|
32 |
-
|
|
|
33 |
|
34 |
CLIENT_TYPES = ["user", "bank", "third_party"]
|
35 |
INPUT_INDEXES = {
|
@@ -37,19 +42,26 @@ INPUT_INDEXES = {
|
|
37 |
"bank": 1,
|
38 |
"third_party": 2,
|
39 |
}
|
40 |
-
|
41 |
"user": slice(0, 36), # First position: start from 0
|
42 |
"bank": slice(36, 37), # Second position: start from n_feature_user
|
43 |
"third_party": slice(37, 39), # Third position: start from n_feature_user + n_feature_bank
|
44 |
}
|
|
|
|
|
|
|
|
|
|
|
45 |
|
|
|
46 |
USER_COLUMNS = [
|
47 |
'Own_car', 'Own_property', 'Mobile_phone', 'Num_children', 'Household_size',
|
48 |
'Total_income', 'Age', 'Income_type', 'Education_type', 'Family_status', 'Housing_type',
|
49 |
'Occupation_type',
|
50 |
]
|
51 |
BANK_COLUMNS = ["Account_age"]
|
52 |
-
|
|
|
53 |
|
54 |
_data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
55 |
|
|
|
6 |
# The directory of this project
|
7 |
REPO_DIR = Path(__file__).parent
|
8 |
|
9 |
+
# Main necessary directories
|
10 |
DEPLOYMENT_PATH = REPO_DIR / "deployment_files"
|
11 |
FHE_KEYS = REPO_DIR / ".fhe_keys"
|
12 |
CLIENT_FILES = REPO_DIR / "client_files"
|
13 |
SERVER_FILES = REPO_DIR / "server_files"
|
14 |
|
15 |
+
# ALl deployment directories
|
16 |
+
APPROVAL_DEPLOYMENT_PATH = DEPLOYMENT_PATH / "approval_model"
|
17 |
+
EXPLAIN_DEPLOYMENT_PATH = DEPLOYMENT_PATH / "explain_model"
|
18 |
+
|
19 |
# Path targeting pre-processor saved files
|
20 |
PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
|
21 |
PRE_PROCESSOR_BANK_PATH = DEPLOYMENT_PATH / 'pre_processor_bank.pkl'
|
|
|
33 |
DATA_PATH = "data/data.csv"
|
34 |
|
35 |
# Development settings
|
36 |
+
APPROVAL_PROCESSED_INPUT_SHAPE = (1, 39)
|
37 |
+
EXPLAIN_PROCESSED_INPUT_SHAPE = (1, 38)
|
38 |
|
39 |
CLIENT_TYPES = ["user", "bank", "third_party"]
|
40 |
INPUT_INDEXES = {
|
|
|
42 |
"bank": 1,
|
43 |
"third_party": 2,
|
44 |
}
|
45 |
+
APPROVAL_INPUT_SLICES = {
|
46 |
"user": slice(0, 36), # First position: start from 0
|
47 |
"bank": slice(36, 37), # Second position: start from n_feature_user
|
48 |
"third_party": slice(37, 39), # Third position: start from n_feature_user + n_feature_bank
|
49 |
}
|
50 |
+
EXPLAIN_INPUT_SLICES = {
|
51 |
+
"user": slice(0, 36), # First position: start from 0
|
52 |
+
"bank": slice(36, 37), # Second position: start from n_feature_user
|
53 |
+
"third_party": slice(37, 38), # Third position: start from n_feature_user + n_feature_bank
|
54 |
+
}
|
55 |
|
56 |
+
# Fix column order for pre-processing steps
|
57 |
USER_COLUMNS = [
|
58 |
'Own_car', 'Own_property', 'Mobile_phone', 'Num_children', 'Household_size',
|
59 |
'Total_income', 'Age', 'Income_type', 'Education_type', 'Family_status', 'Housing_type',
|
60 |
'Occupation_type',
|
61 |
]
|
62 |
BANK_COLUMNS = ["Account_age"]
|
63 |
+
APPROVAL_THIRD_PARTY_COLUMNS = ["Years_employed", "Employed"]
|
64 |
+
EXPLAIN_THIRD_PARTY_COLUMNS = ["Employed"]
|
65 |
|
66 |
_data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
67 |
|
utils/client_server_interface.py
CHANGED
@@ -3,10 +3,11 @@
|
|
3 |
import numpy
|
4 |
import copy
|
5 |
|
6 |
-
from
|
7 |
|
|
|
8 |
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
|
9 |
-
from concrete.ml.sklearn import
|
10 |
|
11 |
|
12 |
class MultiInputsFHEModelDev(FHEModelDev):
|
@@ -15,8 +16,9 @@ class MultiInputsFHEModelDev(FHEModelDev):
|
|
15 |
|
16 |
super().__init__(*arg, **kwargs)
|
17 |
|
|
|
18 |
model = copy.copy(self.model)
|
19 |
-
model.__class__ =
|
20 |
self.model = model
|
21 |
|
22 |
|
@@ -30,10 +32,27 @@ class MultiInputsFHEModelClient(FHEModelClient):
|
|
30 |
def quantize_encrypt_serialize_multi_inputs(
|
31 |
self,
|
32 |
x: numpy.ndarray,
|
33 |
-
input_index,
|
34 |
-
processed_input_shape,
|
35 |
-
input_slice
|
36 |
) -> bytes:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
x_padded = numpy.zeros(processed_input_shape)
|
39 |
|
@@ -58,15 +77,15 @@ class MultiInputsFHEModelServer(FHEModelServer):
|
|
58 |
|
59 |
def run(
|
60 |
self,
|
61 |
-
*serialized_encrypted_quantized_data: bytes,
|
62 |
serialized_evaluation_keys: bytes,
|
63 |
) -> bytes:
|
64 |
-
"""Run the model on the server over encrypted data.
|
65 |
|
66 |
Args:
|
67 |
-
serialized_encrypted_quantized_data (bytes):
|
68 |
-
and serialized data
|
69 |
-
serialized_evaluation_keys (bytes):
|
70 |
|
71 |
Returns:
|
72 |
bytes: the result of the model
|
|
|
3 |
import numpy
|
4 |
import copy
|
5 |
|
6 |
+
from typing import Tuple
|
7 |
|
8 |
+
from concrete.fhe import Value, EvaluationKeys
|
9 |
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
|
10 |
+
from concrete.ml.sklearn import DecisionTreeClassifier
|
11 |
|
12 |
|
13 |
class MultiInputsFHEModelDev(FHEModelDev):
|
|
|
16 |
|
17 |
super().__init__(*arg, **kwargs)
|
18 |
|
19 |
+
# Workaround that enables loading a modified version of a DecisionTreeClassifier model
|
20 |
model = copy.copy(self.model)
|
21 |
+
model.__class__ = DecisionTreeClassifier
|
22 |
self.model = model
|
23 |
|
24 |
|
|
|
32 |
def quantize_encrypt_serialize_multi_inputs(
|
33 |
self,
|
34 |
x: numpy.ndarray,
|
35 |
+
input_index: int,
|
36 |
+
processed_input_shape: Tuple[int],
|
37 |
+
input_slice: slice,
|
38 |
) -> bytes:
|
39 |
+
"""Quantize, encrypt and serialize inputs for a multi-party model.
|
40 |
+
|
41 |
+
In the following, the 'quantize_input' method called is the one defined in Concrete ML's
|
42 |
+
built-in models. Since they don't natively handle inputs for multi-party models, we need
|
43 |
+
to use padding along indexing and slicing so that inputs from a specific party are correctly
|
44 |
+
associated with input quantizers.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
x (numpy.ndarray): The input to consider. Here, the input should only represent a
|
48 |
+
single party.
|
49 |
+
input_index (int): The index representing the type of model (0: "user", 1: "bank",
|
50 |
+
2: "third_party")
|
51 |
+
processed_input_shape (Tuple[int]): The total input shape (all parties combined) after
|
52 |
+
pre-processing.
|
53 |
+
input_slice (slice): The slices to consider for the given party.
|
54 |
+
|
55 |
+
"""
|
56 |
|
57 |
x_padded = numpy.zeros(processed_input_shape)
|
58 |
|
|
|
77 |
|
78 |
def run(
|
79 |
self,
|
80 |
+
*serialized_encrypted_quantized_data: Tuple[bytes],
|
81 |
serialized_evaluation_keys: bytes,
|
82 |
) -> bytes:
|
83 |
+
"""Run the model on the server over encrypted data for a multi-party model.
|
84 |
|
85 |
Args:
|
86 |
+
serialized_encrypted_quantized_data (Tuple[bytes]): The encrypted, quantized
|
87 |
+
and serialized data for a multi-party model.
|
88 |
+
serialized_evaluation_keys (bytes): The serialized evaluation key.
|
89 |
|
90 |
Returns:
|
91 |
bytes: the result of the model
|
utils/model.py
CHANGED
@@ -13,7 +13,7 @@ from concrete.ml.common.utils import (
|
|
13 |
check_there_is_no_p_error_options_in_configuration
|
14 |
)
|
15 |
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
|
16 |
-
from concrete.ml.sklearn import DecisionTreeClassifier
|
17 |
|
18 |
class MultiInputModel:
|
19 |
|
@@ -131,46 +131,8 @@ class MultiInputModel:
|
|
131 |
|
132 |
return compiler
|
133 |
|
134 |
-
def predict_multi_inputs(self, *multi_inputs, simulate=True):
|
135 |
-
"""Run the inference with multiple inputs, with simulation or in FHE."""
|
136 |
-
assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs)
|
137 |
-
|
138 |
-
if not simulate:
|
139 |
-
self.fhe_circuit.keygen()
|
140 |
-
|
141 |
-
y_preds = []
|
142 |
-
execution_times = []
|
143 |
-
for inputs in zip(*multi_inputs):
|
144 |
-
inputs = tuple(numpy.expand_dims(input, axis=0) for input in inputs)
|
145 |
-
|
146 |
-
q_inputs = self.quantize_input(*inputs)
|
147 |
-
|
148 |
-
if simulate:
|
149 |
-
q_y_proba = self.fhe_circuit.simulate(*q_inputs)
|
150 |
-
else:
|
151 |
-
q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs)
|
152 |
-
|
153 |
-
start = time.time()
|
154 |
-
q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc)
|
155 |
-
end = time.time() - start
|
156 |
-
|
157 |
-
execution_times.append(end)
|
158 |
-
|
159 |
-
q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc)
|
160 |
-
|
161 |
-
y_proba = self.dequantize_output(q_y_proba)
|
162 |
-
|
163 |
-
y_proba = self.post_processing(y_proba)
|
164 |
-
|
165 |
-
y_pred = numpy.argmax(y_proba, axis=1)
|
166 |
-
|
167 |
-
y_preds.append(y_pred)
|
168 |
-
|
169 |
-
if not simulate:
|
170 |
-
print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
|
171 |
-
|
172 |
-
return numpy.array(y_preds)
|
173 |
-
|
174 |
-
|
175 |
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
|
176 |
-
pass
|
|
|
|
|
|
|
|
13 |
check_there_is_no_p_error_options_in_configuration
|
14 |
)
|
15 |
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
|
16 |
+
from concrete.ml.sklearn import DecisionTreeClassifier, DecisionTreeRegressor
|
17 |
|
18 |
class MultiInputModel:
|
19 |
|
|
|
131 |
|
132 |
return compiler
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
|
135 |
+
pass
|
136 |
+
|
137 |
+
class MultiInputDecisionTreeRegressor(MultiInputModel, DecisionTreeRegressor):
|
138 |
+
pass
|