qqubb commited on
Commit
f775b6f
1 Parent(s): 444e347

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +13 -19
  2. server.py +2 -2
  3. utils.py +0 -110
app.py CHANGED
@@ -16,11 +16,7 @@ from utils import (
16
  INPUT_BROWSER_LIMIT,
17
  KEYS_DIR,
18
  SERVER_URL,
19
- TARGET_COLUMNS,
20
- TRAINING_FILENAME,
21
  clean_directory,
22
- load_data,
23
- pretty_print,
24
  )
25
 
26
  import requests
@@ -149,7 +145,7 @@ def collect_input(passenger_class, is_male, age, company, fare, embark_point):
149
  (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
150
  ]
151
  }
152
- print(input_dict)
153
  return input_dict
154
 
155
  def clear_predict_survival(input_dict):
@@ -166,9 +162,9 @@ def concrete_predict_survival(input_dict):
166
  pred = concrete_clf.predict_proba(df)[0]
167
  return {"Perishes": float(pred[0]), "Survives": float(pred[1])}
168
 
169
- print("\nclear_test ", clear_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}))
170
 
171
- print("encrypted_test", concrete_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}),"\n")
172
 
173
 
174
  def key_gen_fn() -> Dict:
@@ -230,9 +226,6 @@ def encrypt_fn(user_inputs: np.ndarray, user_id: str) -> None:
230
  # Retrieve the client API
231
  client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
232
  client.load()
233
-
234
- # user_inputs = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
235
- # quant_user_symptoms = client.model.quantize_input(user_inputs)
236
 
237
  user_inputs_df = pd.DataFrame.from_dict(user_inputs)
238
  user_inputs_df = encode_age(user_inputs_df)
@@ -268,7 +261,7 @@ def send_input_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
268
  error_box4: gr.update(
269
  visible=True,
270
  value="⚠️ Please check your connectivity \n"
271
- "⚠️ Ensure that the symptoms have been submitted and the evaluation "
272
  "key has been generated before sending the data to the server.",
273
  )
274
  }
@@ -333,7 +326,7 @@ def run_fhe_fn(user_id: str) -> Dict:
333
  error_box5: gr.update(
334
  visible=True,
335
  value="⚠️ Please check your connectivity \n"
336
- "⚠️ Ensure that the symptoms have been submitted, the evaluation "
337
  "key has been generated and the server received the data "
338
  "before processing the data.",
339
  ),
@@ -379,7 +372,7 @@ def send_input_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
379
  error_box4: gr.update(
380
  visible=True,
381
  value="⚠️ Please check your connectivity \n"
382
- "⚠️ Ensure that the symptoms have been submitted and the evaluation "
383
  "key has been generated before sending the data to the server.",
384
  )
385
  }
@@ -534,17 +527,18 @@ def decrypt_fn(user_id: str, user_inputs: np.ndarray) -> Dict:
534
  with gr.Blocks() as demo:
535
 
536
  # Step 1.1: Provide inputs
 
537
  with gr.Row():
538
  inp = [
539
- gr.Dropdown(["first", "second", "third"], type="index"),
540
- gr.Checkbox(label="is_male"),
541
- gr.Slider(0, 80, value=25),
542
  gr.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
543
- gr.Number(value=20),
544
- gr.Radio(["S", "C", "Q"], type="index"),
545
  ]
546
  out = gr.JSON()
547
- btn = gr.Button("Run")
548
  btn.click(fn=collect_input, inputs=inp, outputs=out)
549
 
550
  # Step 2.1: Key generation
 
16
  INPUT_BROWSER_LIMIT,
17
  KEYS_DIR,
18
  SERVER_URL,
 
 
19
  clean_directory,
 
 
20
  )
21
 
22
  import requests
 
145
  (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
146
  ]
147
  }
148
+ # print(input_dict)
149
  return input_dict
150
 
151
  def clear_predict_survival(input_dict):
 
162
  pred = concrete_clf.predict_proba(df)[0]
163
  return {"Perishes": float(pred[0]), "Survives": float(pred[1])}
164
 
165
+ # print("\nclear_test ", clear_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}))
166
 
167
+ # print("encrypted_test", concrete_predict_survival({'Pclass': [1], 'Sex': [0], 'Age': [25], 'Fare': [20.0], 'Embarked': [2], 'Company': [1]}),"\n")
168
 
169
 
170
  def key_gen_fn() -> Dict:
 
226
  # Retrieve the client API
227
  client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
228
  client.load()
 
 
 
229
 
230
  user_inputs_df = pd.DataFrame.from_dict(user_inputs)
231
  user_inputs_df = encode_age(user_inputs_df)
 
261
  error_box4: gr.update(
262
  visible=True,
263
  value="⚠️ Please check your connectivity \n"
264
+ "⚠️ Ensure that the inputs have been submitted and the evaluation "
265
  "key has been generated before sending the data to the server.",
266
  )
267
  }
 
326
  error_box5: gr.update(
327
  visible=True,
328
  value="⚠️ Please check your connectivity \n"
329
+ "⚠️ Ensure that the inputs have been submitted, the evaluation "
330
  "key has been generated and the server received the data "
331
  "before processing the data.",
332
  ),
 
372
  error_box4: gr.update(
373
  visible=True,
374
  value="⚠️ Please check your connectivity \n"
375
+ "⚠️ Ensure that the inputs have been submitted and the evaluation "
376
  "key has been generated before sending the data to the server.",
377
  )
378
  }
 
527
  with gr.Blocks() as demo:
528
 
529
  # Step 1.1: Provide inputs
530
+ gr.Markdown("###Titanic Survival Prediction with ML and Private Computation")
531
  with gr.Row():
532
  inp = [
533
+ gr.Dropdown(["first", "second", "third"], type="index", label="Select Passenger Class"),
534
+ gr.Checkbox(label="Male?"),
535
+ gr.Slider(0, 80, value=25, label="Age", step=1),
536
  gr.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
537
+ gr.Number(value=20, label="Fare"),
538
+ gr.Radio(["Southampton", "Cherbourg", "Queenstown"], type="index", label="Embark point:"),
539
  ]
540
  out = gr.JSON()
541
+ btn = gr.Button("Confirm inputs")
542
  btn.click(fn=collect_input, inputs=inp, outputs=out)
543
 
544
  # Step 2.1: Key generation
server.py CHANGED
@@ -16,12 +16,12 @@ app = FastAPI()
16
  @app.get("/")
17
  def root():
18
  """
19
- Root endpoint of the health prediction API.
20
 
21
  Returns:
22
  dict: The welcome message.
23
  """
24
- return {"message": "Welcome to your disease prediction with FHE!"}
25
 
26
 
27
  @app.post("/send_input")
 
16
  @app.get("/")
17
  def root():
18
  """
19
+ Root endpoint of the titanic survival prediction API.
20
 
21
  Returns:
22
  dict: The welcome message.
23
  """
24
+ return {"message": "Welcome to titanic survival prediction with FHE!"}
25
 
26
 
27
  @app.post("/send_input")
utils.py CHANGED
@@ -23,47 +23,6 @@ SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
23
 
24
  ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
25
 
26
- # Columns that define the target
27
- TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]
28
-
29
- TRAINING_FILENAME = "./data/Training_preprocessed.csv"
30
- TESTING_FILENAME = "./data/Testing_preprocessed.csv"
31
-
32
- # pylint: disable=invalid-name
33
-
34
-
35
- def pretty_print(
36
- inputs, case_conversion=str.title, which_replace: str = "_", to_what: str = " ", delimiter=None
37
- ):
38
- """
39
- Prettify and sort the input as a list of string.
40
-
41
- Args:
42
- inputs (Any): The inputs to be prettified.
43
-
44
- Returns:
45
- List: The prettified and sorted list of inputs.
46
-
47
- """
48
- # Flatten the list if required
49
- pretty_list = []
50
- for item in inputs:
51
- if isinstance(item, list):
52
- pretty_list.extend(item)
53
- else:
54
- pretty_list.append(item)
55
-
56
- # Sort
57
- pretty_list = sorted(list(set(pretty_list)))
58
- # Replace
59
- pretty_list = [item.replace(which_replace, to_what) for item in pretty_list]
60
- pretty_list = [case_conversion(item) for item in pretty_list]
61
- if delimiter:
62
- pretty_list = f"{delimiter.join(pretty_list)}."
63
-
64
- return pretty_list
65
-
66
-
67
  def clean_directory() -> None:
68
  """
69
  Clear direcgtories
@@ -73,72 +32,3 @@ def clean_directory() -> None:
73
  if os.path.exists(target_dir) and os.path.isdir(target_dir):
74
  shutil.rmtree(target_dir)
75
  target_dir.mkdir(exist_ok=True, parents=True)
76
-
77
-
78
- def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
79
- """Return the disease name given its encoded label.
80
-
81
- Args:
82
- encoded_prediction (int): The encoded prediction
83
- file_name (str): The data file path
84
-
85
- Returns:
86
- str: The according disease name
87
- """
88
- df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS).drop_duplicates()
89
- disease_name, _ = df[df[TARGET_COLUMNS[0]] == encoded_prediction].values.flatten()
90
- return disease_name
91
-
92
-
93
- def load_data() -> Union[Tuple[pandas.DataFrame, numpy.ndarray], List]:
94
- """
95
- Return the data
96
-
97
- Args:
98
- None
99
-
100
- Return:
101
- The train, testing set and valid symptoms.
102
- """
103
- # Load data
104
- df_train = pandas.read_csv(TRAINING_FILENAME)
105
- df_test = pandas.read_csv(TESTING_FILENAME)
106
-
107
- # Separate the traget from the training / testing set:
108
- # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
109
- # TARGET_COLUMNS[1] -> "prognosis" -> contains the name of the disease
110
-
111
- y_train = df_train[TARGET_COLUMNS[0]]
112
- X_train = df_train.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
113
-
114
- y_test = df_test[TARGET_COLUMNS[0]]
115
- X_test = df_test.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
116
-
117
- return (
118
- (X_train, X_test),
119
- (y_train, y_test),
120
- X_train.columns.to_list(),
121
- df_train[TARGET_COLUMNS[1]].unique().tolist(),
122
- )
123
-
124
-
125
- def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
126
- """
127
- Load a pre-trained serialized model
128
-
129
- Args:
130
- X_train (pandas.DataFrame): Training set
131
- y_train (numpy.ndarray): Targets of the training set
132
-
133
- Return:
134
- The Concrete ML model and its circuit
135
- """
136
- # Parameters
137
- concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
138
- classifier = ConcreteXGBoostClassifier(**concrete_args)
139
- # Train the model
140
- classifier.fit(X_train, y_train)
141
- # Compile the model
142
- circuit = classifier.compile(X_train)
143
-
144
- return classifier, circuit
 
23
 
24
  ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def clean_directory() -> None:
27
  """
28
  Clear direcgtories
 
32
  if os.path.exists(target_dir) and os.path.isdir(target_dir):
33
  shutil.rmtree(target_dir)
34
  target_dir.mkdir(exist_ok=True, parents=True)