romanbredehoft-zama
committed on
Commit
•
c119738
1
Parent(s):
9e637e6
Adding deployment files and updating app
Browse files- .gitignore +16 -1
- app.py +130 -138
- data/clean_data.csv +0 -0
- deployment_files/client.zip +3 -0
- deployment_files/server.zip +3 -0
- deployment_files/versions.json +1 -0
- development/client_server_interface.py +77 -0
- development/development.py +67 -0
- development/model.py +130 -0
- development/pre_processing.py +122 -0
- requirements.txt +3 -1
- server.py +104 -0
- settings.py +16 -1
.gitignore
CHANGED
@@ -2,4 +2,19 @@
|
|
2 |
.venv
|
3 |
|
4 |
# Python cache
|
5 |
-
__pycache__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
.venv
|
3 |
|
4 |
# Python cache
|
5 |
+
__pycache__
|
6 |
+
|
7 |
+
# VS Code
|
8 |
+
.vscode
|
9 |
+
|
10 |
+
# Concrete Python artifacts
|
11 |
+
.artifacts
|
12 |
+
|
13 |
+
# Client-server related files
|
14 |
+
.fhe_keys
|
15 |
+
client_files
|
16 |
+
server_files
|
17 |
+
|
18 |
+
# Experiments
|
19 |
+
experiments
|
20 |
+
mlir
|
app.py
CHANGED
@@ -15,9 +15,13 @@ from settings import (
|
|
15 |
FHE_KEYS,
|
16 |
CLIENT_FILES,
|
17 |
SERVER_FILES,
|
|
|
|
|
|
|
|
|
18 |
)
|
19 |
|
20 |
-
|
21 |
|
22 |
|
23 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
@@ -43,38 +47,33 @@ def shorten_bytes_object(bytes_object, limit=500):
|
|
43 |
return bytes_object[shift : limit + shift].hex()
|
44 |
|
45 |
|
46 |
-
def get_client(
|
47 |
"""Get the client API.
|
48 |
|
49 |
Args:
|
50 |
-
|
51 |
-
|
52 |
|
53 |
Returns:
|
54 |
FHEModelClient: The client API.
|
55 |
"""
|
56 |
-
|
57 |
-
# return FHEModelClient(
|
58 |
-
# FILTERS_PATH / f"{filter_name}/deployment",
|
59 |
-
# filter_name,
|
60 |
-
# key_dir=FHE_KEYS / f"{filter_name}_{user_id}",
|
61 |
-
# )
|
62 |
|
63 |
-
return
|
64 |
|
65 |
|
66 |
-
def get_client_file_path(name,
|
67 |
"""Get the correct temporary file path for the client.
|
68 |
|
69 |
Args:
|
70 |
-
name (str): The desired file name.
|
71 |
-
|
72 |
-
|
73 |
|
74 |
Returns:
|
75 |
pathlib.Path: The file path.
|
76 |
"""
|
77 |
-
return CLIENT_FILES / f"{name}_{
|
78 |
|
79 |
|
80 |
def clean_temporary_files(n_keys=20):
|
@@ -109,24 +108,18 @@ def clean_temporary_files(n_keys=20):
|
|
109 |
file.unlink()
|
110 |
|
111 |
|
112 |
-
def keygen(
|
113 |
"""Generate the private key associated to a filter.
|
114 |
|
115 |
Args:
|
116 |
-
|
117 |
-
|
118 |
-
Returns:
|
119 |
-
(user_id, True) (Tuple[int, bool]): The current user's ID and a boolean used for visual display.
|
120 |
-
|
121 |
"""
|
122 |
# Clean temporary files
|
123 |
clean_temporary_files()
|
124 |
|
125 |
-
#
|
126 |
-
|
127 |
-
|
128 |
-
# Retrieve the client API
|
129 |
-
client = get_client(user_id, filter_name)
|
130 |
|
131 |
# Generate a private key
|
132 |
client.generate_private_and_evaluation_keys(force=True)
|
@@ -138,78 +131,27 @@ def keygen(filter_name):
|
|
138 |
|
139 |
# Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
|
140 |
# buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
141 |
-
evaluation_key_path = get_client_file_path("evaluation_key",
|
142 |
|
143 |
with evaluation_key_path.open("wb") as evaluation_key_file:
|
144 |
evaluation_key_file.write(evaluation_key)
|
145 |
|
146 |
-
return (user_id, True)
|
147 |
-
|
148 |
-
|
149 |
-
def encrypt(user_id, input_image, filter_name):
|
150 |
-
"""Encrypt the given image for a specific user and filter.
|
151 |
-
|
152 |
-
Args:
|
153 |
-
user_id (int): The current user's ID.
|
154 |
-
input_image (numpy.ndarray): The image to encrypt.
|
155 |
-
filter_name (str): The current filter to consider.
|
156 |
-
|
157 |
-
Returns:
|
158 |
-
(input_image, encrypted_image_short) (Tuple[bytes]): The encrypted image and one of its
|
159 |
-
representation.
|
160 |
-
|
161 |
-
"""
|
162 |
-
user_id = keygen
|
163 |
-
|
164 |
-
if user_id == "":
|
165 |
-
raise gr.Error("Please generate the private key first.")
|
166 |
-
|
167 |
-
if input_image is None:
|
168 |
-
raise gr.Error("Please choose an image first.")
|
169 |
|
170 |
-
|
171 |
-
client = get_client(user_id, filter_name)
|
172 |
-
|
173 |
-
# Pre-process, encrypt and serialize the image
|
174 |
-
encrypted_image = client.encrypt_serialize(input_image)
|
175 |
-
|
176 |
-
# Save encrypted_image to bytes in a file, since too large to pass through regular Gradio
|
177 |
-
# buttons, https://github.com/gradio-app/gradio/issues/1877
|
178 |
-
encrypted_image_path = get_client_file_path("encrypted_image", user_id, filter_name)
|
179 |
-
|
180 |
-
with encrypted_image_path.open("wb") as encrypted_image_file:
|
181 |
-
encrypted_image_file.write(encrypted_image)
|
182 |
-
|
183 |
-
# Create a truncated version of the encrypted image for display
|
184 |
-
encrypted_image_short = shorten_bytes_object(encrypted_image)
|
185 |
-
|
186 |
-
send_input()
|
187 |
-
|
188 |
-
return (input_image, encrypted_image_short)
|
189 |
-
|
190 |
-
|
191 |
-
def send_input(user_id, filter_name):
|
192 |
"""Send the encrypted input image as well as the evaluation key to the server.
|
193 |
|
194 |
Args:
|
195 |
-
|
196 |
-
|
197 |
"""
|
198 |
-
# Get the evaluation key
|
199 |
-
evaluation_key_path = get_client_file_path("evaluation_key",
|
200 |
-
|
201 |
-
if user_id == "" or not evaluation_key_path.is_file():
|
202 |
-
raise gr.Error("Please generate the private key first.")
|
203 |
-
|
204 |
-
encrypted_input_path = get_client_file_path("encrypted_image", user_id, filter_name)
|
205 |
-
|
206 |
-
if not encrypted_input_path.is_file():
|
207 |
-
raise gr.Error("Please generate the private key and then encrypt an image first.")
|
208 |
|
209 |
# Define the data and files to post
|
210 |
data = {
|
211 |
-
"
|
212 |
-
"
|
213 |
}
|
214 |
|
215 |
files = [
|
@@ -227,19 +169,64 @@ def send_input(user_id, filter_name):
|
|
227 |
return response.ok
|
228 |
|
229 |
|
230 |
-
def
|
231 |
-
"""
|
232 |
|
233 |
Args:
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
236 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
data = {
|
238 |
-
"
|
239 |
-
"filter": filter_name,
|
240 |
}
|
241 |
|
242 |
-
# Trigger the FHE execution on the encrypted
|
243 |
url = SERVER_URL + "run_fhe"
|
244 |
with requests.post(
|
245 |
url=url,
|
@@ -248,23 +235,21 @@ def run_fhe(user_id, filter_name):
|
|
248 |
if response.ok:
|
249 |
return response.json()
|
250 |
else:
|
251 |
-
raise gr.Error("Please wait for the
|
252 |
|
253 |
|
254 |
-
def get_output(
|
255 |
-
"""Retrieve the encrypted output
|
256 |
|
257 |
Args:
|
258 |
-
|
259 |
-
filter_name (str): The current filter to consider.
|
260 |
|
261 |
Returns:
|
262 |
-
|
263 |
|
264 |
"""
|
265 |
data = {
|
266 |
-
"
|
267 |
-
"filter": filter_name,
|
268 |
}
|
269 |
|
270 |
# Retrieve the encrypted output image
|
@@ -278,54 +263,54 @@ def get_output(user_id, filter_name):
|
|
278 |
|
279 |
# Save the encrypted output to bytes in a file as it is too large to pass through regular
|
280 |
# Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
281 |
-
|
|
|
282 |
|
283 |
with encrypted_output_path.open("wb") as encrypted_output_file:
|
284 |
encrypted_output_file.write(encrypted_output)
|
285 |
|
286 |
# TODO
|
287 |
-
# Decrypt the
|
288 |
-
#
|
289 |
|
290 |
-
# return
|
291 |
|
292 |
return None
|
293 |
else:
|
294 |
raise gr.Error("Please wait for the FHE execution to be completed.")
|
295 |
|
296 |
|
297 |
-
def decrypt_output(
|
298 |
"""Decrypt the result.
|
299 |
|
300 |
Args:
|
301 |
-
|
302 |
-
|
303 |
|
304 |
Returns:
|
305 |
-
(
|
306 |
-
well as two booleans used for resetting Gradio checkboxes
|
307 |
|
308 |
"""
|
309 |
-
if user_id == "":
|
310 |
-
raise gr.Error("Please generate the private key first.")
|
311 |
-
|
312 |
# Get the encrypted output path
|
313 |
-
encrypted_output_path = get_client_file_path("encrypted_output",
|
314 |
|
315 |
if not encrypted_output_path.is_file():
|
316 |
raise gr.Error("Please run the FHE execution first.")
|
317 |
|
318 |
# Load the encrypted output as bytes
|
319 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
320 |
-
|
321 |
|
322 |
# Retrieve the client API
|
323 |
-
client = get_client(
|
324 |
|
325 |
# Deserialize, decrypt and post-process the encrypted output
|
326 |
-
|
327 |
|
328 |
-
|
|
|
|
|
|
|
329 |
|
330 |
|
331 |
demo = gr.Blocks()
|
@@ -344,7 +329,7 @@ with demo:
|
|
344 |
gr.Markdown("### Step 1: Infos. ")
|
345 |
with gr.Row():
|
346 |
with gr.Column():
|
347 |
-
gr.Markdown("###
|
348 |
# TODO : change infos
|
349 |
choice_1 = gr.Dropdown(choices=["Yes, No"], label="Choose", interactive=True)
|
350 |
slide_1 = gr.Slider(2, 20, value=4, label="Count", info="Choose between 2 and 20")
|
@@ -363,25 +348,25 @@ with demo:
|
|
363 |
gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
|
364 |
with gr.Row():
|
365 |
with gr.Column():
|
366 |
-
gr.Markdown("###
|
367 |
-
|
368 |
-
|
369 |
label="Keys representation:", max_lines=2, interactive=False
|
370 |
)
|
371 |
-
|
372 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
373 |
)
|
374 |
|
375 |
-
|
376 |
|
377 |
|
378 |
with gr.Column():
|
379 |
gr.Markdown("### Bank ")
|
380 |
-
|
381 |
-
|
382 |
label="Keys representation:", max_lines=2, interactive=False
|
383 |
)
|
384 |
-
|
385 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
386 |
)
|
387 |
|
@@ -390,15 +375,15 @@ with demo:
|
|
390 |
|
391 |
with gr.Column():
|
392 |
gr.Markdown("### Third Party ")
|
393 |
-
|
394 |
keys_3 = gr.Textbox(
|
395 |
label="Keys representation:", max_lines=2, interactive=False
|
396 |
)
|
397 |
-
|
398 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
399 |
)
|
400 |
|
401 |
-
|
402 |
|
403 |
gr.Markdown("## Server side")
|
404 |
gr.Markdown(
|
@@ -428,7 +413,7 @@ with demo:
|
|
428 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
429 |
|
430 |
encrypted_output_representation = gr.Textbox(
|
431 |
-
label="
|
432 |
)
|
433 |
|
434 |
gr.Markdown("### Step 8: Decrypt the output.")
|
@@ -438,15 +423,22 @@ with demo:
|
|
438 |
label="Credit card approval decision: ", max_lines=1, interactive=False
|
439 |
)
|
440 |
|
441 |
-
#
|
442 |
-
#
|
443 |
-
#
|
444 |
-
# inputs=[filter_name],
|
445 |
-
# outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
# )
|
447 |
|
448 |
# # Button to encrypt inputs on the client side
|
449 |
-
#
|
450 |
# encrypt,
|
451 |
# inputs=[user_id, input_image, filter_name],
|
452 |
# outputs=[original_image, encrypted_input],
|
|
|
15 |
FHE_KEYS,
|
16 |
CLIENT_FILES,
|
17 |
SERVER_FILES,
|
18 |
+
DEPLOYMENT_PATH,
|
19 |
+
INITIAL_INPUT_SHAPE,
|
20 |
+
INPUT_INDEXES,
|
21 |
+
START_POSITIONS,
|
22 |
)
|
23 |
|
24 |
+
from development.client_server_interface import MultiInputsFHEModelClient
|
25 |
|
26 |
|
27 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
|
|
47 |
return bytes_object[shift : limit + shift].hex()
|
48 |
|
49 |
|
50 |
+
def get_client(client_id, client_type):
|
51 |
"""Get the client API.
|
52 |
|
53 |
Args:
|
54 |
+
client_id (int): The client ID to consider.
|
55 |
+
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
56 |
|
57 |
Returns:
|
58 |
FHEModelClient: The client API.
|
59 |
"""
|
60 |
+
key_dir = FHE_KEYS / f"{client_type}_{client_id}"
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir)
|
63 |
|
64 |
|
65 |
+
def get_client_file_path(name, client_id, client_type):
|
66 |
"""Get the correct temporary file path for the client.
|
67 |
|
68 |
Args:
|
69 |
+
name (str): The desired file name (either 'evaluation_key' or 'encrypted_inputs').
|
70 |
+
client_id (int): The client ID to consider.
|
71 |
+
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
72 |
|
73 |
Returns:
|
74 |
pathlib.Path: The file path.
|
75 |
"""
|
76 |
+
return CLIENT_FILES / f"{name}_{client_type}_{client_id}"
|
77 |
|
78 |
|
79 |
def clean_temporary_files(n_keys=20):
|
|
|
108 |
file.unlink()
|
109 |
|
110 |
|
111 |
+
def keygen(client_id, client_type):
|
112 |
"""Generate the private key associated to a filter.
|
113 |
|
114 |
Args:
|
115 |
+
client_id (int): The client ID to consider.
|
116 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
|
|
|
|
|
|
117 |
"""
|
118 |
# Clean temporary files
|
119 |
clean_temporary_files()
|
120 |
|
121 |
+
# Retrieve the client instance
|
122 |
+
client = get_client(client_id, client_type)
|
|
|
|
|
|
|
123 |
|
124 |
# Generate a private key
|
125 |
client.generate_private_and_evaluation_keys(force=True)
|
|
|
131 |
|
132 |
# Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
|
133 |
# buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
134 |
+
evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
|
135 |
|
136 |
with evaluation_key_path.open("wb") as evaluation_key_file:
|
137 |
evaluation_key_file.write(evaluation_key)
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
+
def send_input(client_id, client_type):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
"""Send the encrypted input image as well as the evaluation key to the server.
|
142 |
|
143 |
Args:
|
144 |
+
client_id (int): The client ID to consider.
|
145 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
146 |
"""
|
147 |
+
# Get the paths to the evaluation key and encrypted inputs
|
148 |
+
evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
|
149 |
+
encrypted_input_path = get_client_file_path("encrypted_inputs", client_id, client_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
# Define the data and files to post
|
152 |
data = {
|
153 |
+
"client_id": client_id,
|
154 |
+
"client_type": client_type,
|
155 |
}
|
156 |
|
157 |
files = [
|
|
|
169 |
return response.ok
|
170 |
|
171 |
|
172 |
+
def keygen_encrypt_send(inputs, client_type):
|
173 |
+
"""Encrypt the given inputs for a specific client.
|
174 |
|
175 |
Args:
|
176 |
+
inputs (numpy.ndarray): The inputs to encrypt.
|
177 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
|
181 |
"""
|
182 |
+
# Create an ID for the current client to consider
|
183 |
+
client_id = numpy.random.randint(0, 2**32)
|
184 |
+
|
185 |
+
keygen(client_id, client_type)
|
186 |
+
|
187 |
+
# Retrieve the client instance
|
188 |
+
client = get_client(client_id, client_type)
|
189 |
+
|
190 |
+
# TODO : pre-process the data first
|
191 |
+
|
192 |
+
# Quantize, encrypt and serialize the inputs
|
193 |
+
encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
|
194 |
+
inputs,
|
195 |
+
input_index=INPUT_INDEXES[client_type],
|
196 |
+
initial_input_shape=INITIAL_INPUT_SHAPE,
|
197 |
+
start_position=START_POSITIONS[client_type],
|
198 |
+
)
|
199 |
+
|
200 |
+
# Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
|
201 |
+
# buttons, https://github.com/gradio-app/gradio/issues/1877
|
202 |
+
encrypted_inputs_path = get_client_file_path("encrypted_inputs", client_id, client_type)
|
203 |
+
|
204 |
+
with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
|
205 |
+
encrypted_inputs_file.write(encrypted_inputs)
|
206 |
+
|
207 |
+
# Create a truncated version of the encrypted inputs for display
|
208 |
+
encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
|
209 |
+
|
210 |
+
send_input(client_id, client_type)
|
211 |
+
|
212 |
+
# TODO: also return private key representation if possible
|
213 |
+
return encrypted_inputs_short
|
214 |
+
|
215 |
+
|
216 |
+
def run_fhe(client_id):
|
217 |
+
"""Run the model on the encrypted inputs previously sent using FHE.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
client_id (int): The client ID to consider.
|
221 |
+
"""
|
222 |
+
|
223 |
+
# TODO : add a warning for users to send all client types' inputs
|
224 |
+
|
225 |
data = {
|
226 |
+
"client_id": client_id,
|
|
|
227 |
}
|
228 |
|
229 |
+
# Trigger the FHE execution on the encrypted inputs previously sent
|
230 |
url = SERVER_URL + "run_fhe"
|
231 |
with requests.post(
|
232 |
url=url,
|
|
|
235 |
if response.ok:
|
236 |
return response.json()
|
237 |
else:
|
238 |
+
raise gr.Error("Please wait for the inputs to be sent to the server.")
|
239 |
|
240 |
|
241 |
+
def get_output(client_id):
|
242 |
+
"""Retrieve the encrypted output.
|
243 |
|
244 |
Args:
|
245 |
+
client_id (int): The client ID to consider.
|
|
|
246 |
|
247 |
Returns:
|
248 |
+
output_encrypted_representation (numpy.ndarray): A representation of the encrypted output.
|
249 |
|
250 |
"""
|
251 |
data = {
|
252 |
+
"client_id": client_id,
|
|
|
253 |
}
|
254 |
|
255 |
# Retrieve the encrypted output image
|
|
|
263 |
|
264 |
# Save the encrypted output to bytes in a file as it is too large to pass through regular
|
265 |
# Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
266 |
+
# TODO : check if output to user is relevant
|
267 |
+
encrypted_output_path = get_client_file_path("encrypted_output", client_id, "user")
|
268 |
|
269 |
with encrypted_output_path.open("wb") as encrypted_output_file:
|
270 |
encrypted_output_file.write(encrypted_output)
|
271 |
|
272 |
# TODO
|
273 |
+
# Decrypt the output using a different (wrong) key for display
|
274 |
+
# output_encrypted_representation = decrypt_output_with_wrong_key(encrypted_output, client_type)
|
275 |
|
276 |
+
# return output_encrypted_representation
|
277 |
|
278 |
return None
|
279 |
else:
|
280 |
raise gr.Error("Please wait for the FHE execution to be completed.")
|
281 |
|
282 |
|
283 |
+
def decrypt_output(client_id, client_type):
|
284 |
"""Decrypt the result.
|
285 |
|
286 |
Args:
|
287 |
+
client_id (int): The client ID to consider.
|
288 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
289 |
|
290 |
Returns:
|
291 |
+
output(numpy.ndarray): The decrypted output
|
|
|
292 |
|
293 |
"""
|
|
|
|
|
|
|
294 |
# Get the encrypted output path
|
295 |
+
encrypted_output_path = get_client_file_path("encrypted_output", client_id, client_type)
|
296 |
|
297 |
if not encrypted_output_path.is_file():
|
298 |
raise gr.Error("Please run the FHE execution first.")
|
299 |
|
300 |
# Load the encrypted output as bytes
|
301 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
302 |
+
encrypted_output_proba = encrypted_output_file.read()
|
303 |
|
304 |
# Retrieve the client API
|
305 |
+
client = get_client(client_id, client_type)
|
306 |
|
307 |
# Deserialize, decrypt and post-process the encrypted output
|
308 |
+
output_proba = client.deserialize_decrypt_post_process(encrypted_output_proba)
|
309 |
|
310 |
+
# Determine the predicted class
|
311 |
+
output = numpy.argmax(output_proba, axis=1)
|
312 |
+
|
313 |
+
return output
|
314 |
|
315 |
|
316 |
demo = gr.Blocks()
|
|
|
329 |
gr.Markdown("### Step 1: Infos. ")
|
330 |
with gr.Row():
|
331 |
with gr.Column():
|
332 |
+
gr.Markdown("### User")
|
333 |
# TODO : change infos
|
334 |
choice_1 = gr.Dropdown(choices=["Yes, No"], label="Choose", interactive=True)
|
335 |
slide_1 = gr.Slider(2, 20, value=4, label="Count", info="Choose between 2 and 20")
|
|
|
348 |
gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
|
349 |
with gr.Row():
|
350 |
with gr.Column():
|
351 |
+
gr.Markdown("### User")
|
352 |
+
encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
|
353 |
+
keys_user = gr.Textbox(
|
354 |
label="Keys representation:", max_lines=2, interactive=False
|
355 |
)
|
356 |
+
encrypted_input_user = gr.Textbox(
|
357 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
358 |
)
|
359 |
|
360 |
+
user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
361 |
|
362 |
|
363 |
with gr.Column():
|
364 |
gr.Markdown("### Bank ")
|
365 |
+
encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
|
366 |
+
keys_bank = gr.Textbox(
|
367 |
label="Keys representation:", max_lines=2, interactive=False
|
368 |
)
|
369 |
+
encrypted_input_bank = gr.Textbox(
|
370 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
371 |
)
|
372 |
|
|
|
375 |
|
376 |
with gr.Column():
|
377 |
gr.Markdown("### Third Party ")
|
378 |
+
encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
|
379 |
keys_3 = gr.Textbox(
|
380 |
label="Keys representation:", max_lines=2, interactive=False
|
381 |
)
|
382 |
+
encrypted_input__third_party = gr.Textbox(
|
383 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
384 |
)
|
385 |
|
386 |
+
third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
387 |
|
388 |
gr.Markdown("## Server side")
|
389 |
gr.Markdown(
|
|
|
413 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
414 |
|
415 |
encrypted_output_representation = gr.Textbox(
|
416 |
+
label="Encrypted output representation: ", max_lines=1, interactive=False
|
417 |
)
|
418 |
|
419 |
gr.Markdown("### Step 8: Decrypt the output.")
|
|
|
423 |
label="Credit card approval decision: ", max_lines=1, interactive=False
|
424 |
)
|
425 |
|
426 |
+
# Button to encrypt inputs on the client side
|
427 |
+
# encrypt_button_user.click(
|
428 |
+
# encrypt,
|
429 |
+
# inputs=[user_id, input_image, filter_name],
|
430 |
+
# outputs=[original_image, encrypted_input],
|
431 |
+
# )
|
432 |
+
|
433 |
+
# # Button to encrypt inputs on the client side
|
434 |
+
# encrypt_button_bank.click(
|
435 |
+
# encrypt,
|
436 |
+
# inputs=[user_id, input_image, filter_name],
|
437 |
+
# outputs=[original_image, encrypted_input],
|
438 |
# )
|
439 |
|
440 |
# # Button to encrypt inputs on the client side
|
441 |
+
# encrypt_button_third_party.click(
|
442 |
# encrypt,
|
443 |
# inputs=[user_id, input_image, filter_name],
|
444 |
# outputs=[original_image, encrypted_input],
|
data/clean_data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
deployment_files/client.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b42d1dff3521c2e7462994c6eafb072bf004108d27c838e690a6702d775c0b5
|
3 |
+
size 35673
|
deployment_files/server.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75b15663431ff4f3788b380c100ea87c1bf97959234aeefb51ae734bed7514c4
|
3 |
+
size 10953
|
deployment_files/versions.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"concrete-python": "2.5.0rc1", "concrete-ml": "1.3.0", "python": "3.10.11"}
|
development/client_server_interface.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import copy
|
3 |
+
|
4 |
+
from concrete.fhe import Value, EvaluationKeys
|
5 |
+
|
6 |
+
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
|
7 |
+
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
|
8 |
+
|
9 |
+
|
10 |
+
class MultiInputsFHEModelDev(FHEModelDev):
|
11 |
+
|
12 |
+
def __init__(self, *arg, **kwargs):
|
13 |
+
|
14 |
+
super().__init__(*arg, **kwargs)
|
15 |
+
|
16 |
+
model = copy.copy(self.model)
|
17 |
+
model.__class__ = ConcreteXGBClassifier
|
18 |
+
self.model = model
|
19 |
+
|
20 |
+
|
21 |
+
class MultiInputsFHEModelClient(FHEModelClient):
|
22 |
+
|
23 |
+
def __init__(self, *args, nb_inputs=1, **kwargs):
|
24 |
+
self.nb_inputs = nb_inputs
|
25 |
+
|
26 |
+
super().__init__(*args, **kwargs)
|
27 |
+
|
28 |
+
def quantize_encrypt_serialize_multi_inputs(self, x: numpy.ndarray, input_index, initial_input_shape, start_position) -> bytes:
|
29 |
+
|
30 |
+
x_padded = numpy.zeros(initial_input_shape)
|
31 |
+
|
32 |
+
end = start_position + x.shape[1]
|
33 |
+
x_padded[:, start_position:end] = x
|
34 |
+
|
35 |
+
q_x_padded = self.model.quantize_input(x_padded)
|
36 |
+
|
37 |
+
q_x = q_x_padded[:, start_position:end]
|
38 |
+
|
39 |
+
q_x_padded = [None for _ in range(self.nb_inputs)]
|
40 |
+
q_x_padded[input_index] = q_x
|
41 |
+
|
42 |
+
# Encrypt the values
|
43 |
+
q_x_enc = self.client.encrypt(*q_x_padded)
|
44 |
+
|
45 |
+
# Serialize the encrypted values to be sent to the server
|
46 |
+
q_x_enc_ser = q_x_enc[input_index].serialize()
|
47 |
+
return q_x_enc_ser
|
48 |
+
|
49 |
+
|
50 |
+
class MultiInputsFHEModelServer(FHEModelServer):
|
51 |
+
|
52 |
+
def run(
|
53 |
+
self,
|
54 |
+
*serialized_encrypted_quantized_data: bytes,
|
55 |
+
serialized_evaluation_keys: bytes,
|
56 |
+
) -> bytes:
|
57 |
+
"""Run the model on the server over encrypted data.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
serialized_encrypted_quantized_data (bytes): the encrypted, quantized
|
61 |
+
and serialized data
|
62 |
+
serialized_evaluation_keys (bytes): the serialized evaluation keys
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
bytes: the result of the model
|
66 |
+
"""
|
67 |
+
assert self.server is not None, "Model has not been loaded."
|
68 |
+
|
69 |
+
deserialized_encrypted_quantized_data = tuple(Value.deserialize(data) for data in serialized_encrypted_quantized_data)
|
70 |
+
|
71 |
+
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
|
72 |
+
|
73 |
+
result = self.server.run(
|
74 |
+
*deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
|
75 |
+
)
|
76 |
+
serialized_result = result.serialize()
|
77 |
+
return serialized_result
|
development/development.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"A script to generate all development files necessary for the project."
|
2 |
+
|
3 |
+
import shutil
|
4 |
+
import numpy
|
5 |
+
import pandas
|
6 |
+
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from imblearn.over_sampling import SMOTE
|
9 |
+
|
10 |
+
from ..settings import DEPLOYMENT_PATH, RANDOM_STATE
|
11 |
+
from client_server_interface import MultiInputsFHEModelDev
|
12 |
+
from model import MultiInputXGBClassifier
|
13 |
+
from development.pre_processing import pre_process_data
|
14 |
+
|
15 |
+
|
16 |
+
print("Load and pre-process the data")
|
17 |
+
|
18 |
+
data = pandas.read_csv("data/clean_data.csv", encoding="utf-8")
|
19 |
+
|
20 |
+
# Make median annual salary similar to France (2023): from 157500 to 22050
|
21 |
+
data["Total_income"] = data["Total_income"] * 0.14
|
22 |
+
|
23 |
+
# Remove ID feature
|
24 |
+
data.drop("ID", axis=1, inplace=True)
|
25 |
+
|
26 |
+
# Feature engineer the data
|
27 |
+
pre_processed_data, training_bins = pre_process_data(data)
|
28 |
+
|
29 |
+
# Define input and target data
|
30 |
+
y = pre_processed_data.pop("Target")
|
31 |
+
x = pre_processed_data
|
32 |
+
|
33 |
+
# The initial data-set is very imbalanced: use SMOTE to get better results
|
34 |
+
x, y = SMOTE().fit_resample(x, y)
|
35 |
+
|
36 |
+
# Retrieve the training data
|
37 |
+
X_train, _, y_train, _ = train_test_split(
|
38 |
+
x, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
|
39 |
+
)
|
40 |
+
|
41 |
+
# Convert the Pandas data frames into Numpy arrays
|
42 |
+
X_train_np = X_train.to_numpy()
|
43 |
+
y_train_np = y_train.to_numpy()
|
44 |
+
|
45 |
+
|
46 |
+
print("Train and compile the model")
|
47 |
+
|
48 |
+
model = MultiInputXGBClassifier(max_depth=3, n_estimators=40)
|
49 |
+
|
50 |
+
model.fit(X_train_np, y_train_np)
|
51 |
+
|
52 |
+
multi_inputs_train = numpy.array_split(X_train_np, 3, axis=1)
|
53 |
+
|
54 |
+
model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
55 |
+
|
56 |
+
# Delete the deployment folder and its content if it already exists
|
57 |
+
if DEPLOYMENT_PATH.is_dir():
|
58 |
+
shutil.rmtree(DEPLOYMENT_PATH)
|
59 |
+
|
60 |
+
|
61 |
+
print("Save deployment files")
|
62 |
+
|
63 |
+
# Save the files needed for deployment
|
64 |
+
fhe_dev = MultiInputsFHEModelDev(model, DEPLOYMENT_PATH)
|
65 |
+
fhe_dev.save()
|
66 |
+
|
67 |
+
print("Done !")
|
development/model.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
from typing import Optional, Sequence, Union
|
3 |
+
|
4 |
+
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
|
5 |
+
|
6 |
+
from concrete.ml.common.check_inputs import check_array_and_assert
|
7 |
+
from concrete.ml.common.utils import (
|
8 |
+
generate_proxy_function,
|
9 |
+
manage_parameters_for_pbs_errors,
|
10 |
+
check_there_is_no_p_error_options_in_configuration
|
11 |
+
)
|
12 |
+
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
|
13 |
+
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
|
14 |
+
|
15 |
+
|
16 |
+
class MultiInputXGBClassifier(ConcreteXGBClassifier):
    """XGBoost classifier whose compiled circuit takes its features as several inputs.

    The feature columns are split, in order, across the inputs: the first input holds the
    first columns, the second input the following ones, and so on. The inputs are
    concatenated along axis 1 before being given to the tree inference function.
    """

    def quantize_input(
        self, *X: numpy.ndarray
    ) -> Union[numpy.ndarray, Sequence[numpy.ndarray]]:
        """Quantize each input column with the matching fitted input quantizer.

        Args:
            *X (numpy.ndarray): The 2D inputs to quantize. Together, their columns must
                match the model's input quantizers, in order.

        Returns:
            The quantized int64 array if a single input was given, else a tuple with one
            quantized int64 array per input.
        """
        self.check_model_is_fitted()
        assert sum(values.shape[1] for values in X) == len(self.input_quantizers)

        # Index of the quantizer associated to the first column of the current input
        base_j = 0
        q_inputs = []
        for i, values in enumerate(X):
            q_input = numpy.zeros_like(values, dtype=numpy.int64)

            for j in range(values.shape[1]):
                quantizer_index = base_j + j
                q_input[:, j] = self.input_quantizers[quantizer_index].quant(values[:, j])

            assert q_input.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"

            q_inputs.append(q_input)
            base_j += values.shape[1]

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into a circuit taking one argument per given input.

        Args:
            *inputs: The representative input-sets, one per circuit input. Each one is
                converted to a numpy array before being quantized.
            configuration (Optional[Configuration]): Forwarded to the compiler. It must not
                contain any p_error option.
            artifacts (Optional[DebugArtifacts]): Forwarded to the compiler.
            show_mlir (bool): Forwarded to the compiler.
            p_error (Optional[float]): Forwarded to the compiler once resolved together with
                `global_p_error`.
            global_p_error (Optional[float]): See `p_error`.
            verbose (bool): Forwarded to the compiler.
            inputs_encryption_status (Optional[Sequence[str]]): The encryption status of each
                input, in order. If None, every input is considered "encrypted".

        Returns:
            Circuit: The compiled circuit.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(values) for values in inputs)

        # Without an explicit status, consider every input as encrypted. The default used to be
        # forwarded as None, which made `_get_module_to_compile` fail on `len(None)`
        if inputs_encryption_status is None:
            inputs_encryption_status = ("encrypted",) * len(inputs_as_array)

        # p_error or global_p_error should not be set in both the configuration and direct arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs
        quantized_inputs = self.quantize_input(*inputs_as_array)

        # Generate the compilation input-set with proper dimensions
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        # Retrieve the compiler instance
        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
            jit=False,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)

        return self.fhe_circuit

    def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]:
        """Build the compiler instance for the multi-input tree inference.

        Args:
            inputs_encryption_status (Sequence[str]): The encryption status of each input, in
                order. It must not be None, as its length sets the number of circuit inputs.

        Returns:
            Compiler: The compiler wrapping the multi-input inference function.
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        # Make the inference accept several inputs by concatenating them along the feature
        # axis. The wrapper is tagged so that compiling several times does not stack wrappers:
        # `compile` resets `_is_compiled` before calling this method, so that flag cannot be
        # used to detect an already wrapped function
        if not getattr(self._tree_inference, "_concatenates_inputs", False):
            xgb_inference = self._tree_inference

            def multi_input_inference(*args):
                return xgb_inference(numpy.concatenate(args, axis=1))

            multi_input_inference._concatenates_inputs = True
            self._tree_inference = multi_input_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        # Associate each argument of the proxy function to its encryption status
        inputs_encryption_statuses = dict(
            zip(function_arg_names.values(), inputs_encryption_status)
        )

        # Create the compiler instance
        compiler = Compiler(
            _tree_inference_proxy,
            inputs_encryption_statuses,
        )

        return compiler
|
development/pre_processing.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas
|
2 |
+
from copy import deepcopy
|
3 |
+
|
4 |
+
|
5 |
+
def convert_dummy(df, feature):
    """Replace the column `feature` by its one-hot encoded (dummy) columns.

    The column is dropped from `df` in place. The returned data-frame is a new object made of
    the remaining columns followed by the dummy columns, which are prefixed with `feature`.
    """
    dummies = pandas.get_dummies(df[feature], prefix=feature)
    df.drop(columns=[feature], inplace=True)
    return df.join(dummies)
|
11 |
+
|
12 |
+
|
13 |
+
def get_category(df, col, labels, qcut=False, binsnum=None, bins=None, retbins=False):
    """Replace the numerical column `col` by its binned, labelled version.

    The column is dropped from `df` in place and the binned column is appended to the
    returned data-frame.

    Args:
        df (pandas.DataFrame): The data-frame holding the column.
        col (str): The name of the column to bin.
        labels (list): The labels of the bins, in increasing order.
        qcut (bool): If True and `binsnum` is given, use quantile-based bins. Else, use
            `bins` if given, or `binsnum` equal-width bins.
        binsnum (Optional[int]): The number of bins to compute.
        bins: Explicit bin edges, typically the ones returned by a previous call.
        retbins (bool): If True, also return the bin edges.

    Returns:
        The new data-frame, or a (data-frame, bin edges) tuple if `retbins` is True.
    """
    assert binsnum is not None or bins is not None

    if qcut and binsnum is not None:
        # Quantile cut
        binned, bin_edges = pandas.qcut(df[col], q=binsnum, labels=labels, retbins=True)
    else:
        input_bins = bins if bins is not None else binsnum

        # Equal-length cut, or cut along the given edges. Explicit edges include their lowest
        # value: edges returned by the quantile cut start at the data's minimum, which would
        # otherwise fall outside of the first bin and become NaN
        binned, bin_edges = pandas.cut(
            df[col],
            bins=input_bins,
            labels=labels,
            retbins=True,
            include_lowest=bins is not None,
        )

    df.drop(col, axis=1, inplace=True)

    binned = pandas.DataFrame(binned)
    df = df.join(binned[col])

    if retbins:
        return df, bin_edges

    return df
|
31 |
+
|
32 |
+
|
33 |
+
def pre_process_data(input_data, bins=None, columns=None):
    """Turn the raw data into the integer, one-hot encoded features.

    Without `bins`, the bin edges of the numerical columns are computed from the data and
    returned along with the processed data. With `bins`, these edges are re-used and the
    processed data is re-indexed on `columns`, missing columns being filled with 0.

    Args:
        input_data (pandas.DataFrame): The raw data. It is copied, not modified.
        bins (Optional[dict]): The bin edges to re-use, under the keys "bin_edges_income",
            "bin_edges_age" and "bin_edges_years_employed".
        columns: The expected output columns. Required if `bins` is given.

    Returns:
        A (processed data, bin edges) tuple if `bins` is None, else the processed data.
    """
    required_bin_keys = ("bin_edges_income", "bin_edges_age", "bin_edges_years_employed")
    assert bins is None or (
        all(key in bins for key in required_bin_keys) and columns is not None
    )

    training_bins = {}

    # Work on copies so that the caller's objects are left untouched
    input_data = deepcopy(input_data)
    if bins is not None:
        bins = deepcopy(bins)

    input_data.loc[input_data["Num_children"] >= 2, "Num_children"] = "2_or_more"
    input_data = convert_dummy(input_data, "Num_children")

    # Bin the numerical columns, in this exact order as it sets the order of the output columns
    five_levels = ["lowest", "low", "medium", "high", "highest"]
    binning_steps = (
        ("Total_income", "bin_edges_income", ["low", "medium", "high"], {"qcut": True, "binsnum": 3}),
        ("Age", "bin_edges_age", five_levels, {"binsnum": 5}),
        ("Years_employed", "bin_edges_years_employed", five_levels, {"binsnum": 5}),
    )
    for column, edges_key, labels, training_kwargs in binning_steps:
        if bins is None:
            # Compute the bins from the data and keep their edges
            input_data, training_bins[edges_key] = get_category(
                input_data, column, labels, retbins=True, **training_kwargs
            )
        else:
            input_data = get_category(input_data, column, labels, bins=bins[edges_key])

        input_data = convert_dummy(input_data, column)

    input_data.loc[input_data["Num_family"] >= 3, "Num_family"] = "3_or_more"
    input_data = convert_dummy(input_data, "Num_family")

    # Pensioners and students are merged with state servants
    merged_income_types = input_data["Income_type"].isin(["Pensioner", "Student"])
    input_data.loc[merged_income_types, "Income_type"] = "State servant"
    input_data = convert_dummy(input_data, "Income_type")

    # Group the occupations into three families
    occupation_groups = (
        (
            "Labor_work",
            [
                "Cleaning staff",
                "Cooking staff",
                "Drivers",
                "Laborers",
                "Low-skill Laborers",
                "Security staff",
                "Waiters/barmen staff",
            ],
        ),
        (
            "Office_work",
            [
                "Accountants",
                "Core staff",
                "HR staff",
                "Medicine staff",
                "Private service staff",
                "Realty agents",
                "Sales staff",
                "Secretaries",
            ],
        ),
        ("High_tech_work", ["Managers", "High skill tech staff", "IT staff"]),
    )
    for group_name, occupations in occupation_groups:
        in_group = input_data["Occupation_type"].isin(occupations)
        input_data.loc[in_group, "Occupation_type"] = group_name

    input_data = convert_dummy(input_data, "Occupation_type")

    input_data = convert_dummy(input_data, "Housing_type")

    has_academic_degree = input_data["Education_type"] == "Academic degree"
    input_data.loc[has_academic_degree, "Education_type"] = "Higher education"
    input_data = convert_dummy(input_data, "Education_type")

    input_data = convert_dummy(input_data, "Family_status")

    input_data = input_data.astype("int")

    # Bin edges were computed: return them along the data
    if training_bins:
        return input_data, training_bins

    # Else, align the data on the expected columns
    return input_data.reindex(columns=columns, fill_value=0)
|
requirements.txt
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
# concrete-ml==1.3.0
|
2 |
-
gradio==3.40.1
|
|
|
|
|
|
1 |
# concrete-ml==1.3.0
|
2 |
+
gradio==3.40.1
|
3 |
+
concrete-ml==1.3.0
|
4 |
+
imblearn==0.0
|
server.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Server that will listen for GET and POST requests from the client."""
|
2 |
+
|
3 |
+
import time
|
4 |
+
from typing import List
|
5 |
+
from fastapi import FastAPI, File, Form, UploadFile
|
6 |
+
from fastapi.responses import JSONResponse, Response
|
7 |
+
|
8 |
+
from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
|
9 |
+
from development.client_server_interface import MultiInputsFHEModelServer
|
10 |
+
|
11 |
+
# Load the FHE server object once and for all
|
12 |
+
# Single server object, built when this module is imported and used by the `run_fhe` route.
# NOTE(review): this commit adds server.zip directly under deployment_files/ (DEPLOYMENT_PATH),
# not under a "deployment" sub-directory - confirm this is the path the server expects.
FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH / "deployment")
|
13 |
+
|
14 |
+
def get_server_file_path(name, client_id, client_type):
    """Build the path of a file stored on the server side.

    Args:
        name (str): The kind of file (this module uses 'evaluation_key', 'encrypted_inputs'
            and 'encrypted_output').
        client_id (str): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        pathlib.Path: The file path, located in the server files directory.
    """
    file_name = f"{name}_{client_type}_{client_id}"
    return SERVER_FILES / file_name
|
26 |
+
|
27 |
+
|
28 |
+
# Initialize an instance of FastAPI
|
29 |
+
# Application object on which the routes are registered (app.py launches it with
# `uvicorn server:app`)
app = FastAPI()
|
30 |
+
|
31 |
+
# Define the default route
|
32 |
+
@app.get("/")
def root():
    """Return the welcome message of the server."""
    welcome = {"message": "Welcome to Credit Card Approval Prediction server!"}
    return welcome
|
35 |
+
|
36 |
+
|
37 |
+
@app.post("/send_input")
def send_input(
    client_id: str = Form(),
    client_type: str = Form(),
    files: List[UploadFile] = File(),
):
    """Store a client's encrypted inputs and evaluation key on the server.

    The first uploaded file is stored as the encrypted inputs and the second one as the
    evaluation key, both under names built from the client's ID and type.
    """
    inputs_destination = get_server_file_path("encrypted_inputs", client_id, client_type)
    key_destination = get_server_file_path("evaluation_key", client_id, client_type)

    # Both files are opened before anything is read from the uploads
    with inputs_destination.open("wb") as inputs_file:
        with key_destination.open("wb") as key_file:
            inputs_file.write(files[0].file.read())
            key_file.write(files[1].file.read())
|
54 |
+
|
55 |
+
|
56 |
+
@app.post("/run_fhe")
def run_fhe(
    client_id: str = Form(),
):
    """Run the model on the stored encrypted inputs and store the encrypted output.

    The evaluation key stored for the 'user' client type is used, along with the encrypted
    inputs stored for each client type, in the order given by CLIENT_TYPES.

    Returns:
        JSONResponse: The execution time, in seconds, rounded to two decimals.
    """
    # Load the evaluation key sent by the user
    evaluation_key = get_server_file_path("evaluation_key", client_id, "user").read_bytes()

    # Load the encrypted inputs of every client type
    encrypted_inputs = []
    for client_type in CLIENT_TYPES:
        inputs_path = get_server_file_path("encrypted_inputs", client_id, client_type)
        encrypted_inputs.append(inputs_path.read_bytes())

    # Time the FHE execution
    start = time.time()
    encrypted_output = FHE_SERVER.run(*encrypted_inputs, serialized_evaluation_keys=evaluation_key)
    fhe_execution_time = round(time.time() - start, 2)

    # NOTE(review): `client_type` still holds the last value taken in the loop above, so the
    # output is stored under the last entry of CLIENT_TYPES - confirm this is the client type
    # used when calling the `get_output` route
    output_path = get_server_file_path("encrypted_output", client_id, client_type)
    output_path.write_bytes(encrypted_output)

    return JSONResponse(content=fhe_execution_time)
|
89 |
+
|
90 |
+
|
91 |
+
@app.post("/get_output")
def get_output(
    client_id: str = Form(),
    client_type: str = Form(),
):
    """Send back the encrypted output stored for the given client ID and type."""
    output_path = get_server_file_path("encrypted_output", client_id, client_type)

    # The output is returned as raw bytes
    return Response(output_path.read_bytes())
|
settings.py
CHANGED
@@ -6,7 +6,7 @@ from pathlib import Path
|
|
6 |
REPO_DIR = Path(__file__).parent
|
7 |
|
8 |
# This repository's main necessary directories
|
9 |
-
|
10 |
FHE_KEYS = REPO_DIR / ".fhe_keys"
|
11 |
CLIENT_FILES = REPO_DIR / "client_files"
|
12 |
SERVER_FILES = REPO_DIR / "server_files"
|
@@ -19,3 +19,18 @@ SERVER_FILES.mkdir(exist_ok=True)
|
|
19 |
# Store the server's URL
|
20 |
SERVER_URL = "http://localhost:8000/"
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
REPO_DIR = Path(__file__).parent
|
7 |
|
8 |
# This repository's main necessary directories
|
9 |
+
DEPLOYMENT_PATH = REPO_DIR / "deployment_files"
|
10 |
FHE_KEYS = REPO_DIR / ".fhe_keys"
|
11 |
CLIENT_FILES = REPO_DIR / "client_files"
|
12 |
SERVER_FILES = REPO_DIR / "server_files"
|
|
|
19 |
# Store the server's URL
|
20 |
SERVER_URL = "http://localhost:8000/"
|
21 |
|
22 |
+
# Presumably the seed shared by the development scripts - confirm against its users
RANDOM_STATE = 0

# Presumably (number of samples, total number of features) of a full input - confirm
INITIAL_INPUT_SHAPE = (1, 49)

# The types of clients, in the order of the model's inputs
CLIENT_TYPES = ["user", "bank", "third_party"]

# Index of the model's input associated to each client type
INPUT_INDEXES = {client_type: index for index, client_type in enumerate(CLIENT_TYPES)}

# Position at which each client type's input starts
START_POSITIONS = {
    # First position: start from 0
    "user": 0,
    # Second position: start from len(input_user)
    "bank": 17,
    # Third position: start from len(input_user) + len(input_bank)
    "third_party": 33,
}
|