romanbredehoft-zama
commited on
Commit
•
bf71bfa
1
Parent(s):
8d5cb63
Add explicit error messages
Browse files- app.py +5 -3
- backend.py +67 -55
app.py
CHANGED
@@ -47,8 +47,10 @@ with demo:
|
|
47 |
gr.Markdown("## Client side")
|
48 |
|
49 |
gr.Markdown("### Step 1: Generate the keys. ")
|
50 |
-
# TODO: Re-initialize the key message once generated and sent
|
51 |
keygen_button = gr.Button("Generate the keys and send evaluation key to the server.")
|
|
|
|
|
|
|
52 |
client_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
53 |
|
54 |
gr.Markdown("### Step 2: Infos. ")
|
@@ -133,7 +135,7 @@ with demo:
|
|
133 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
134 |
|
135 |
encrypted_output_representation = gr.Textbox(
|
136 |
-
label="Encrypted output representation: ", max_lines=
|
137 |
)
|
138 |
|
139 |
gr.Markdown("### Step 6: Decrypt the output.")
|
@@ -146,7 +148,7 @@ with demo:
|
|
146 |
# Button generate the keys
|
147 |
keygen_button.click(
|
148 |
keygen_send,
|
149 |
-
outputs=[client_id, keygen_button],
|
150 |
)
|
151 |
|
152 |
# Button to pre-process, generate the key, encrypt and send the user inputs from the client
|
|
|
47 |
gr.Markdown("## Client side")
|
48 |
|
49 |
gr.Markdown("### Step 1: Generate the keys. ")
|
|
|
50 |
keygen_button = gr.Button("Generate the keys and send evaluation key to the server.")
|
51 |
+
evaluation_key = gr.Textbox(
|
52 |
+
label="Evaluation key representation:", max_lines=2, interactive=False
|
53 |
+
)
|
54 |
client_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
55 |
|
56 |
gr.Markdown("### Step 2: Infos. ")
|
|
|
135 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
136 |
|
137 |
encrypted_output_representation = gr.Textbox(
|
138 |
+
label="Encrypted output representation: ", max_lines=2, interactive=False
|
139 |
)
|
140 |
|
141 |
gr.Markdown("### Step 6: Decrypt the output.")
|
|
|
148 |
# Button generate the keys
|
149 |
keygen_button.click(
|
150 |
keygen_send,
|
151 |
+
outputs=[client_id, evaluation_key, keygen_button],
|
152 |
)
|
153 |
|
154 |
# Button to pre-process, generate the key, encrypt and send the user inputs from the client
|
backend.py
CHANGED
@@ -53,7 +53,7 @@ def shorten_bytes_object(bytes_object, limit=500):
|
|
53 |
|
54 |
|
55 |
def clean_temporary_files(n_keys=20):
|
56 |
-
"""Clean keys and encrypted
|
57 |
|
58 |
A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
|
59 |
limit is reached, the oldest files are deleted.
|
@@ -73,16 +73,15 @@ def clean_temporary_files(n_keys=20):
|
|
73 |
user_ids.append(key_dir.name)
|
74 |
shutil.rmtree(key_dir)
|
75 |
|
76 |
-
# Get all the encrypted
|
77 |
client_files = CLIENT_FILES.iterdir()
|
78 |
server_files = SERVER_FILES.iterdir()
|
79 |
|
80 |
-
# Delete all files related to the
|
81 |
-
for
|
82 |
for user_id in user_ids:
|
83 |
-
if user_id in
|
84 |
-
|
85 |
-
client_server_file.unlink()
|
86 |
|
87 |
|
88 |
def _get_client(client_id):
|
@@ -122,11 +121,44 @@ def _get_client_file_path(name, client_id, client_type=None):
|
|
122 |
return dir_path / f"{name}{client_type_suffix}"
|
123 |
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
def keygen_send():
|
126 |
"""Generate the private and evaluation key, and send the evaluation key to the server.
|
127 |
|
128 |
Returns:
|
129 |
-
client_id (
|
130 |
"""
|
131 |
# Clean temporary files
|
132 |
clean_temporary_files()
|
@@ -154,48 +186,18 @@ def keygen_send():
|
|
154 |
|
155 |
# Send the evaluation key to the server
|
156 |
_send_to_server(client_id, None, file_name)
|
157 |
-
|
158 |
-
return client_id, gr.update(value="Keys are generated and sent ✅")
|
159 |
-
|
160 |
-
|
161 |
-
def _send_to_server(client_id, client_type, file_name):
|
162 |
-
"""Send the encrypted inputs or the evaluation key to the server.
|
163 |
-
|
164 |
-
Args:
|
165 |
-
client_id (int): The client ID to consider.
|
166 |
-
client_type (Optional[str]): The type of client to consider (either 'user', 'bank', 'third_party' or
|
167 |
-
None).
|
168 |
-
file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').
|
169 |
-
"""
|
170 |
-
# Get the paths to the encrypted inputs
|
171 |
-
encrypted_file_path = _get_client_file_path(file_name, client_id, client_type)
|
172 |
|
173 |
-
#
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
"file_name": file_name,
|
178 |
-
}
|
179 |
-
|
180 |
-
files = [
|
181 |
-
("files", open(encrypted_file_path, "rb")),
|
182 |
-
]
|
183 |
-
|
184 |
-
# Send the encrypted inputs or evaluation key to the server
|
185 |
-
url = SERVER_URL + "send_file"
|
186 |
-
with requests.post(
|
187 |
-
url=url,
|
188 |
-
data=data,
|
189 |
-
files=files,
|
190 |
-
) as response:
|
191 |
-
return response.ok
|
192 |
|
193 |
|
194 |
def _encrypt_send(client_id, inputs, client_type):
|
195 |
"""Encrypt the given inputs for a specific client and send it to the server.
|
196 |
|
197 |
Args:
|
198 |
-
client_id (
|
199 |
inputs (numpy.ndarray): The inputs to encrypt.
|
200 |
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
201 |
|
@@ -203,6 +205,8 @@ def _encrypt_send(client_id, inputs, client_type):
|
|
203 |
client_id, encrypted_inputs_short (int, bytes): Integer ID representing the current client
|
204 |
and a byte short representation of the encrypted input to send.
|
205 |
"""
|
|
|
|
|
206 |
|
207 |
# Retrieve the client instance
|
208 |
client = _get_client(client_id)
|
@@ -236,7 +240,7 @@ def pre_process_encrypt_send_user(client_id, *inputs):
|
|
236 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
237 |
|
238 |
Args:
|
239 |
-
client_id (
|
240 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
241 |
|
242 |
Returns:
|
@@ -284,7 +288,7 @@ def pre_process_encrypt_send_bank(client_id, *inputs):
|
|
284 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
285 |
|
286 |
Args:
|
287 |
-
client_id (
|
288 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
289 |
|
290 |
Returns:
|
@@ -300,7 +304,7 @@ def pre_process_encrypt_send_third_party(client_id, *inputs):
|
|
300 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
301 |
|
302 |
Args:
|
303 |
-
client_id (
|
304 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
305 |
|
306 |
Returns:
|
@@ -326,10 +330,11 @@ def run_fhe(client_id):
|
|
326 |
"""Run the model on the encrypted inputs previously sent using FHE.
|
327 |
|
328 |
Args:
|
329 |
-
client_id (
|
330 |
"""
|
331 |
|
332 |
-
|
|
|
333 |
|
334 |
data = {
|
335 |
"client_id": client_id,
|
@@ -344,18 +349,22 @@ def run_fhe(client_id):
|
|
344 |
if response.ok:
|
345 |
return response.json()
|
346 |
else:
|
347 |
-
raise gr.Error("Please
|
348 |
|
349 |
|
350 |
def get_output(client_id):
|
351 |
"""Retrieve the encrypted output.
|
352 |
|
353 |
Args:
|
354 |
-
client_id (
|
355 |
|
356 |
Returns:
|
357 |
encrypted_output_short (bytes): A byte short representation of the encrypted output.
|
358 |
"""
|
|
|
|
|
|
|
|
|
359 |
data = {
|
360 |
"client_id": client_id,
|
361 |
}
|
@@ -381,24 +390,27 @@ def get_output(client_id):
|
|
381 |
|
382 |
return encrypted_output_short
|
383 |
else:
|
384 |
-
raise gr.Error("Please
|
385 |
|
386 |
|
387 |
def decrypt_output(client_id):
|
388 |
"""Decrypt the result.
|
389 |
|
390 |
Args:
|
391 |
-
client_id (
|
392 |
|
393 |
Returns:
|
394 |
output(numpy.ndarray): The decrypted output
|
395 |
-
|
396 |
"""
|
|
|
|
|
|
|
|
|
397 |
# Get the encrypted output path
|
398 |
encrypted_output_path = _get_client_file_path("encrypted_output", client_id)
|
399 |
|
400 |
if not encrypted_output_path.is_file():
|
401 |
-
raise gr.Error("Please
|
402 |
|
403 |
# Load the encrypted output as bytes
|
404 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
@@ -412,5 +424,5 @@ def decrypt_output(client_id):
|
|
412 |
|
413 |
# Determine the predicted class
|
414 |
output = numpy.argmax(output_proba, axis=1)
|
415 |
-
|
416 |
return output
|
|
|
53 |
|
54 |
|
55 |
def clean_temporary_files(n_keys=20):
|
56 |
+
"""Clean older keys and encrypted files.
|
57 |
|
58 |
A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
|
59 |
limit is reached, the oldest files are deleted.
|
|
|
73 |
user_ids.append(key_dir.name)
|
74 |
shutil.rmtree(key_dir)
|
75 |
|
76 |
+
# Get all the encrypted files in the temporary folders
|
77 |
client_files = CLIENT_FILES.iterdir()
|
78 |
server_files = SERVER_FILES.iterdir()
|
79 |
|
80 |
+
# Delete all files related to the IDs whose keys were deleted
|
81 |
+
for user_dir in chain(client_files, server_files):
|
82 |
for user_id in user_ids:
|
83 |
+
if user_id in user_dir.name:
|
84 |
+
shutil.rmtree(user_dir)
|
|
|
85 |
|
86 |
|
87 |
def _get_client(client_id):
|
|
|
121 |
return dir_path / f"{name}{client_type_suffix}"
|
122 |
|
123 |
|
124 |
+
def _send_to_server(client_id, client_type, file_name):
|
125 |
+
"""Send the encrypted inputs or the evaluation key to the server.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
client_id (int): The client ID to consider.
|
129 |
+
client_type (Optional[str]): The type of client to consider (either 'user', 'bank', 'third_party' or
|
130 |
+
None).
|
131 |
+
file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').
|
132 |
+
"""
|
133 |
+
# Get the paths to the encrypted inputs
|
134 |
+
encrypted_file_path = _get_client_file_path(file_name, client_id, client_type)
|
135 |
+
|
136 |
+
# Define the data and files to post
|
137 |
+
data = {
|
138 |
+
"client_id": client_id,
|
139 |
+
"client_type": client_type,
|
140 |
+
"file_name": file_name,
|
141 |
+
}
|
142 |
+
|
143 |
+
files = [
|
144 |
+
("files", open(encrypted_file_path, "rb")),
|
145 |
+
]
|
146 |
+
|
147 |
+
# Send the encrypted inputs or evaluation key to the server
|
148 |
+
url = SERVER_URL + "send_file"
|
149 |
+
with requests.post(
|
150 |
+
url=url,
|
151 |
+
data=data,
|
152 |
+
files=files,
|
153 |
+
) as response:
|
154 |
+
return response.ok
|
155 |
+
|
156 |
+
|
157 |
def keygen_send():
|
158 |
"""Generate the private and evaluation key, and send the evaluation key to the server.
|
159 |
|
160 |
Returns:
|
161 |
+
client_id (str): The current client ID to consider.
|
162 |
"""
|
163 |
# Clean temporary files
|
164 |
clean_temporary_files()
|
|
|
186 |
|
187 |
# Send the evaluation key to the server
|
188 |
_send_to_server(client_id, None, file_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
# Create a truncated version of the evaluation key for display
|
191 |
+
evaluation_key_short = shorten_bytes_object(evaluation_key)
|
192 |
+
|
193 |
+
return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent ✅")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
|
196 |
def _encrypt_send(client_id, inputs, client_type):
|
197 |
"""Encrypt the given inputs for a specific client and send it to the server.
|
198 |
|
199 |
Args:
|
200 |
+
client_id (str): The current client ID to consider.
|
201 |
inputs (numpy.ndarray): The inputs to encrypt.
|
202 |
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
203 |
|
|
|
205 |
client_id, encrypted_inputs_short (int, bytes): Integer ID representing the current client
|
206 |
and a byte short representation of the encrypted input to send.
|
207 |
"""
|
208 |
+
if client_id == "":
|
209 |
+
raise gr.Error("Please generate the keys first.")
|
210 |
|
211 |
# Retrieve the client instance
|
212 |
client = _get_client(client_id)
|
|
|
240 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
241 |
|
242 |
Args:
|
243 |
+
client_id (str): The current client ID to consider.
|
244 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
245 |
|
246 |
Returns:
|
|
|
288 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
289 |
|
290 |
Args:
|
291 |
+
client_id (str): The current client ID to consider.
|
292 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
293 |
|
294 |
Returns:
|
|
|
304 |
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
|
305 |
|
306 |
Args:
|
307 |
+
client_id (str): The current client ID to consider.
|
308 |
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
|
309 |
|
310 |
Returns:
|
|
|
330 |
"""Run the model on the encrypted inputs previously sent using FHE.
|
331 |
|
332 |
Args:
|
333 |
+
client_id (str): The current client ID to consider.
|
334 |
"""
|
335 |
|
336 |
+
if client_id == "":
|
337 |
+
raise gr.Error("Please generate the keys first.")
|
338 |
|
339 |
data = {
|
340 |
"client_id": client_id,
|
|
|
349 |
if response.ok:
|
350 |
return response.json()
|
351 |
else:
|
352 |
+
raise gr.Error("Please send the inputs from all three parties to the server first.")
|
353 |
|
354 |
|
355 |
def get_output(client_id):
|
356 |
"""Retrieve the encrypted output.
|
357 |
|
358 |
Args:
|
359 |
+
client_id (str): The current client ID to consider.
|
360 |
|
361 |
Returns:
|
362 |
encrypted_output_short (bytes): A byte short representation of the encrypted output.
|
363 |
"""
|
364 |
+
|
365 |
+
if client_id == "":
|
366 |
+
raise gr.Error("Please generate the keys first.")
|
367 |
+
|
368 |
data = {
|
369 |
"client_id": client_id,
|
370 |
}
|
|
|
390 |
|
391 |
return encrypted_output_short
|
392 |
else:
|
393 |
+
raise gr.Error("Please run the FHE execution first and wait for it to be completed.")
|
394 |
|
395 |
|
396 |
def decrypt_output(client_id):
|
397 |
"""Decrypt the result.
|
398 |
|
399 |
Args:
|
400 |
+
client_id (str): The current client ID to consider.
|
401 |
|
402 |
Returns:
|
403 |
output(numpy.ndarray): The decrypted output
|
|
|
404 |
"""
|
405 |
+
|
406 |
+
if client_id == "":
|
407 |
+
raise gr.Error("Please generate the keys first.")
|
408 |
+
|
409 |
# Get the encrypted output path
|
410 |
encrypted_output_path = _get_client_file_path("encrypted_output", client_id)
|
411 |
|
412 |
if not encrypted_output_path.is_file():
|
413 |
+
raise gr.Error("Please receive the outputs from the server first.")
|
414 |
|
415 |
# Load the encrypted output as bytes
|
416 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
|
|
424 |
|
425 |
# Determine the predicted class
|
426 |
output = numpy.argmax(output_proba, axis=1)
|
427 |
+
|
428 |
return output
|