gregoiregllt commited on
Commit
86b66ac
·
1 Parent(s): 5195d8d
Files changed (6) hide show
  1. .gitignore +4 -1
  2. app.py +201 -1
  3. dev.py +52 -0
  4. requirements.txt +1 -0
  5. symptoms_categories.py +197 -0
  6. utils.py +144 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
1
  venv/
2
 
3
- __pycache__/
 
 
 
 
1
  venv/
2
 
3
+ __pycache__/
4
+
5
+ client.zip
6
+ server.zip
app.py CHANGED
@@ -1,7 +1,207 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, Form, UploadFile
2
+ from fastapi.responses import JSONResponse, Response
3
+ from concrete.ml.deployment import FHEModelServer
4
+ import numpy as np
5
+
6
+ from concrete.ml.deployment import FHEModelClient
7
+ import subprocess
8
+ from pathlib import Path
9
+
10
+
11
+ from utils import (
12
+ CLIENT_DIR,
13
+ CURRENT_DIR,
14
+ DEPLOYMENT_DIR,
15
+ SERVER_DIR,
16
+ INPUT_BROWSER_LIMIT,
17
+ KEYS_DIR,
18
+ SERVER_URL,
19
+ TARGET_COLUMNS,
20
+ TRAINING_FILENAME,
21
+ clean_directory,
22
+ get_disease_name,
23
+ load_data,
24
+ pretty_print,
25
+ )
26
+
27
+
28
+ import time
29
+ from typing import List
30
+
31
+ # Load the FHE server
32
+ FHE_SERVER = FHEModelServer(DEPLOYMENT_DIR)
33
+
34
 
35
  app = FastAPI()
36
 
37
  @app.get("/")
38
  def greet_json():
39
  return {"Hello": "World!"}
40
+
41
+
42
+ def root():
43
+ """
44
+ Root endpoint of the health prediction API.
45
+
46
+ Returns:
47
+ dict: The welcome message.
48
+ """
49
+ return {"message": "Welcome to your disease prediction with FHE!"}
50
+
51
+
52
+ @app.post("/send_input")
53
+ def send_input(
54
+ user_id: str = Form(),
55
+ files: List[UploadFile] = File(),
56
+ ):
57
+ """Send the inputs to the server."""
58
+
59
+ print("\nSend the data to the server ............\n")
60
+
61
+ # Receive the Client's files (Evaluation key + Encrypted symptoms)
62
+ evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
63
+ encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"
64
+
65
+ # Save the files using the above paths
66
+ with encrypted_input_path.open("wb") as encrypted_input, evaluation_key_path.open(
67
+ "wb"
68
+ ) as evaluation_key:
69
+ encrypted_input.write(files[0].file.read())
70
+ evaluation_key.write(files[1].file.read())
71
+
72
+
73
+ @app.post("/run_fhe")
74
+ def run_fhe(
75
+ user_id: str = Form(),
76
+ ):
77
+ """Inference in FHE."""
78
+
79
+ print("\nRun in FHE in the server ............\n")
80
+ evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
81
+ encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"
82
+
83
+ # Read the files (Evaluation key + Encrypted symptoms) using the above paths
84
+ with encrypted_input_path.open("rb") as encrypted_output_file, evaluation_key_path.open(
85
+ "rb"
86
+ ) as evaluation_key_file:
87
+ encrypted_output = encrypted_output_file.read()
88
+ evaluation_key = evaluation_key_file.read()
89
+
90
+ # Run the FHE execution
91
+ start = time.time()
92
+ encrypted_output = FHE_SERVER.run(encrypted_output, evaluation_key)
93
+ assert isinstance(encrypted_output, bytes)
94
+ fhe_execution_time = round(time.time() - start, 2)
95
+
96
+ # Retrieve the encrypted output path
97
+ encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
98
+
99
+ # Write the file using the above path
100
+ with encrypted_output_path.open("wb") as f:
101
+ f.write(encrypted_output)
102
+
103
+ return JSONResponse(content=fhe_execution_time)
104
+
105
+
106
+ @app.post("/get_output")
107
+ def get_output(user_id: str = Form()):
108
+ """Retrieve the encrypted output from the server."""
109
+
110
+ print("\nGet the output from the server ............\n")
111
+
112
+ # Path where the encrypted output is saved
113
+ encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
114
+
115
+ # Read the file using the above path
116
+ with encrypted_output_path.open("rb") as f:
117
+ encrypted_output = f.read()
118
+
119
+ time.sleep(1)
120
+
121
+ # Send the encrypted output
122
+ return Response(encrypted_output)
123
+
124
+
125
+ @app.post("/generate_keys")
126
+ def generate_keys(user_symptoms: List[str]):
127
+ """
128
+ Endpoint to generate keys based on user symptoms.
129
+
130
+ Args:
131
+ user_symptoms (List[str]): The list of user symptoms.
132
+
133
+ Returns:
134
+ JSONResponse: A response containing the generated keys and user ID.
135
+ """
136
+ def is_none(obj):
137
+ return obj is None or (obj is not None and len(obj) == 0)
138
+
139
+ # Call the key generation function
140
+ clean_directory()
141
+
142
+ if is_none(user_symptoms):
143
+ return JSONResponse(
144
+ status_code=400, content={"error": "Please submit your symptoms first."}
145
+ )
146
+
147
+ # Generate a random user ID
148
+ user_id = np.random.randint(0, 2**32)
149
+ print(f"Your user ID is: {user_id}....")
150
+
151
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
152
+ client.load()
153
+
154
+ # Creates the private and evaluation keys on the client side
155
+ client.generate_private_and_evaluation_keys()
156
+
157
+ # Get the serialized evaluation keys
158
+ serialized_evaluation_keys = client.get_serialized_evaluation_keys()
159
+ assert isinstance(serialized_evaluation_keys, bytes)
160
+
161
+ # Save the evaluation key
162
+ evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
163
+ with evaluation_key_path.open("wb") as f:
164
+ f.write(serialized_evaluation_keys)
165
+
166
+ serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
167
+
168
+ return JSONResponse(
169
+ content={
170
+ "user_id": user_id,
171
+ "evaluation_key": serialized_evaluation_keys_shorten_hex,
172
+ "evaluation_key_size": f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
173
+ }
174
+ )
175
+
176
+
177
+ @app.post("/run_dev")
178
+ def run_dev_script():
179
+ """
180
+ Endpoint to execute the dev.py script to generate deployment files.
181
+
182
+ Returns:
183
+ JSONResponse: Success message or error details.
184
+ """
185
+ try:
186
+ # Define the path to dev.py
187
+ dev_script_path = Path(__file__).parent / "dev.py"
188
+
189
+ # Execute the dev.py script
190
+ result = subprocess.run(
191
+ ["python", str(dev_script_path)],
192
+ capture_output=True,
193
+ text=True,
194
+ check=True
195
+ )
196
+
197
+ # Return success message with output
198
+ return JSONResponse(
199
+ content={"message": "dev.py executed successfully!", "output": result.stdout}
200
+ )
201
+
202
+ except subprocess.CalledProcessError as e:
203
+ # Return error message in case of failure
204
+ return JSONResponse(
205
+ status_code=500,
206
+ content={"error": "Failed to execute dev.py", "details": e.stderr}
207
+ )
dev.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generating deployment files."""
2
+
3
+ import shutil
4
+
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+
9
+ from concrete.ml.sklearn import LogisticRegression as ConcreteLogisticRegression
10
+ from concrete.ml.deployment import FHEModelDev
11
+
12
+
13
+ # Data files location
14
+ TRAINING_FILE_NAME = "./data/Training_preprocessed.csv"
15
+ TESTING_FILE_NAME = "./data/Testing_preprocessed.csv"
16
+
17
+ # Load data
18
+ df_train = pd.read_csv(TRAINING_FILE_NAME)
19
+ df_test = pd.read_csv(TESTING_FILE_NAME)
20
+
21
+ # Split the data into X_train, y_train, X_test_, y_test sets
22
+ TARGET_COLUMN = ["prognosis_encoded", "prognosis"]
23
+
24
+ y_train = df_train[TARGET_COLUMN[0]].values.flatten()
25
+ y_test = df_test[TARGET_COLUMN[0]].values.flatten()
26
+
27
+ X_train = df_train.drop(TARGET_COLUMN, axis=1)
28
+ X_test = df_test.drop(TARGET_COLUMN, axis=1)
29
+
30
+ # Concrete ML model
31
+
32
+ # Models parameters
33
+ optimal_param = {"C": 0.9, "n_bits": 13, "solver": "sag", "multi_class": "auto"}
34
+
35
+ clf = ConcreteLogisticRegression(**optimal_param)
36
+
37
+ # Fit the model
38
+ clf.fit(X_train, y_train)
39
+
40
+ # Compile the model
41
+ fhe_circuit = clf.compile(X_train)
42
+
43
+ fhe_circuit.client.keygen(force=False)
44
+
45
+ path_to_model = Path("./deployment_files/").resolve()
46
+
47
+ if path_to_model.exists():
48
+ shutil.rmtree(path_to_model)
49
+
50
+ dev = FHEModelDev(path_to_model, clf)
51
+
52
+ dev.save(via_mlir=True)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  fastapi
2
  uvicorn[standard]
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ concrete-ml==1.4.0
symptoms_categories.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In this file, we roughly split up a list of symptoms, taken from "./training.csv" file, avalaible
3
+ through: "https://github.com/anujdutt9/Disease-Prediction-from-Symptoms/tree/master/dataset"
4
+ into medical categories, in order to make the UI more plesant for the users.
5
+
6
+ Each variable contains a list of symptoms sthat can be pecific to a part of the body or to a list
7
+ of similar symptoms.
8
+ """
9
+
10
+
11
+ DIGESTIVE_SYSTEM_SYMPTOMS = {
12
+ "DIGESTIVE_SYSTEM_CONCERNS": [
13
+ "stomach_pain",
14
+ "acidity",
15
+ "vomiting",
16
+ "indigestion",
17
+ "constipation",
18
+ "abdominal_pain",
19
+ "diarrhea",
20
+ "nausea",
21
+ "distention_of_abdomen",
22
+ "stomach_bleeding",
23
+ "pain_during_bowel_movements",
24
+ "passage_of_gases",
25
+ "red_spots_over_body",
26
+ "swelling_of_stomach",
27
+ "bloody_stool",
28
+ "irritation_in_anus",
29
+ "pain_in_anal_region",
30
+ "abnormal_menstruation",
31
+ ]
32
+ }
33
+
34
+ DERMATOLOGICAL_SYMPTOMS = {
35
+ "DERMATOLOGICAL_CONCERNS": [
36
+ "itching",
37
+ "skin_rash",
38
+ "pus_filled_pimples",
39
+ "blackheads",
40
+ "scurving",
41
+ "skin_peeling",
42
+ "silver_like_dusting",
43
+ "small_dents_in_nails",
44
+ "inflammatory_nails",
45
+ "blister",
46
+ "red_sore_around_nose",
47
+ "bruising",
48
+ "yellow_crust_ooze",
49
+ "dischromic_patches",
50
+ "nodal_skin_eruptions",
51
+ "toxic_look_(typhus)",
52
+ "brittle_nails",
53
+ "yellowish_skin",
54
+ ]
55
+ }
56
+
57
+ ORL_SYMPTOMS = {
58
+ "ORL_CONCERNS": [
59
+ "loss_of_smell",
60
+ "continuous_sneezing",
61
+ "runny_nose",
62
+ "patches_in_throat",
63
+ "throat_irritation",
64
+ "sinus_pressure",
65
+ "enlarged_thyroid",
66
+ "loss_of_balance",
67
+ "unsteadiness",
68
+ "dizziness",
69
+ "spinning_movements",
70
+ ]
71
+ }
72
+
73
+ THORAX_SYMPTOMS = {
74
+ "THORAX_CONCERNS": [
75
+ "breathlessness",
76
+ "chest_pain",
77
+ "cough",
78
+ "rusty_sputum",
79
+ "phlegm",
80
+ "mucoid_sputum",
81
+ "congestion",
82
+ "blood_in_sputum",
83
+ "fast_heart_rate",
84
+ ]
85
+ }
86
+
87
+ OPHTHALMOLOGICAL_SYMPTOMS = {
88
+ "OPHTHALMOLOGICAL_CONCERNS": [
89
+ "sunken_eyes",
90
+ "redness_of_eyes",
91
+ "watering_from_eyes",
92
+ "blurred_and_distorted_vision",
93
+ "pain_behind_the_eyes",
94
+ "visual_disturbances",
95
+ ]
96
+ }
97
+
98
+ VASCULAR_LYMPHATIC_SYMPTOMS = {
99
+ "VASCULAR_AND_LYMPHATIC_CONCERNS": [
100
+ "cold_hands_and_feets",
101
+ "swollen_blood_vessels",
102
+ "swollen_legs",
103
+ "swelled_lymph_nodes",
104
+ "palpitations",
105
+ "prominent_veins_on_calf",
106
+ "yellowing_of_eyes",
107
+ "puffy_face_and_eyes",
108
+ "severe_fluid_overload",
109
+ "swollen_extremeties",
110
+ ]
111
+ }
112
+
113
+ UROLOGICAL_SYMPTOMS = {
114
+ "UROLOGICAL_CONCERNS": [
115
+ "burning_micturition",
116
+ "spotting_urination",
117
+ "yellow_urine",
118
+ "bladder_discomfort",
119
+ "foul_smell_of_urine",
120
+ "continuous_feel_of_urine",
121
+ "polyuria",
122
+ "dark_urine",
123
+ ]
124
+ }
125
+
126
+ MUSCULOSKELETAL_SYMPTOMS = {
127
+ "MUSCULOSKELETAL_CONCERNS": [
128
+ "joint_pain",
129
+ "muscle_wasting",
130
+ "muscle_pain",
131
+ "muscle_weakness",
132
+ "knee_pain",
133
+ "stiff_neck",
134
+ "swelling_joints",
135
+ "movement_stiffness",
136
+ "hip_joint_pain",
137
+ "painful_walking",
138
+ "weakness_of_one_body_side",
139
+ "neck_pain",
140
+ "back_pain",
141
+ "weakness_in_limbs",
142
+ "cramps",
143
+ ]
144
+ }
145
+
146
+ GENERAL_SYMPTOMS = {
147
+ "GENERAL_CONCERNS": [
148
+ "acute_liver_failure",
149
+ "anxiety",
150
+ "restlessness",
151
+ "lethargy",
152
+ "mood_swings",
153
+ "irritability",
154
+ "lack_of_concentration",
155
+ "fatigue",
156
+ "malaise",
157
+ "weight_gain",
158
+ "increased_appetite",
159
+ "weight_loss",
160
+ "loss_of_appetite",
161
+ "excess_body_fat",
162
+ "excessive_hunger",
163
+ "ulcers_on_tongue",
164
+ "shivering",
165
+ "chills",
166
+ "irregular_sugar_level",
167
+ "high_fever",
168
+ "slurred_speech",
169
+ "sweating",
170
+ "internal_itching",
171
+ "mild_fever",
172
+ "dehydration",
173
+ "headache",
174
+ "frequent_unprotected_sexual_intercourse_with_multiple_partners",
175
+ "drying_and_tingling_lips",
176
+ "altered_sensorium",
177
+ "family_history",
178
+ "receiving_blood_transfusion",
179
+ "receiving_unsterile_injections",
180
+ "chronic_alcohol_abuse",
181
+ ]
182
+ }
183
+
184
+ SYMPTOMS_LIST = [
185
+ # Column 1
186
+ DIGESTIVE_SYSTEM_SYMPTOMS,
187
+ UROLOGICAL_SYMPTOMS,
188
+ VASCULAR_LYMPHATIC_SYMPTOMS,
189
+ # Column 2
190
+ ORL_SYMPTOMS,
191
+ DERMATOLOGICAL_SYMPTOMS,
192
+ MUSCULOSKELETAL_SYMPTOMS,
193
+ # Column 3
194
+ OPHTHALMOLOGICAL_SYMPTOMS,
195
+ THORAX_SYMPTOMS,
196
+ GENERAL_SYMPTOMS,
197
+ ]
utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+ from typing import List, Tuple, Union
5
+
6
+ import numpy
7
+ import pandas
8
+
9
+ from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
10
+
11
+ # Max Input to be displayed on the HuggingFace space brower using Gradio
12
+ # Too large inputs, slow down the server: https://github.com/gradio-app/gradio/issues/1877
13
+ INPUT_BROWSER_LIMIT = 380
14
+
15
+ # Store the server's URL
16
+ SERVER_URL = "http://localhost:8000/"
17
+
18
+ CURRENT_DIR = Path(__file__).parent
19
+ DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files"
20
+ KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
21
+ CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
22
+ SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
23
+
24
+ ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
25
+
26
+ # Columns that define the target
27
+ TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]
28
+
29
+ TRAINING_FILENAME = "./data/Training_preprocessed.csv"
30
+ TESTING_FILENAME = "./data/Testing_preprocessed.csv"
31
+
32
+ # pylint: disable=invalid-name
33
+
34
+
35
+ def pretty_print(
36
+ inputs, case_conversion=str.title, which_replace: str = "_", to_what: str = " ", delimiter=None
37
+ ):
38
+ """
39
+ Prettify and sort the input as a list of string.
40
+
41
+ Args:
42
+ inputs (Any): The inputs to be prettified.
43
+
44
+ Returns:
45
+ List: The prettified and sorted list of inputs.
46
+
47
+ """
48
+ # Flatten the list if required
49
+ pretty_list = []
50
+ for item in inputs:
51
+ if isinstance(item, list):
52
+ pretty_list.extend(item)
53
+ else:
54
+ pretty_list.append(item)
55
+
56
+ # Sort
57
+ pretty_list = sorted(list(set(pretty_list)))
58
+ # Replace
59
+ pretty_list = [item.replace(which_replace, to_what) for item in pretty_list]
60
+ pretty_list = [case_conversion(item) for item in pretty_list]
61
+ if delimiter:
62
+ pretty_list = f"{delimiter.join(pretty_list)}."
63
+
64
+ return pretty_list
65
+
66
+
67
+ def clean_directory() -> None:
68
+ """
69
+ Clear direcgtories
70
+ """
71
+ print("Cleaning...\n")
72
+ for target_dir in ALL_DIRS:
73
+ if os.path.exists(target_dir) and os.path.isdir(target_dir):
74
+ shutil.rmtree(target_dir)
75
+ target_dir.mkdir(exist_ok=True, parents=True)
76
+
77
+
78
+ def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
79
+ """Return the disease name given its encoded label.
80
+
81
+ Args:
82
+ encoded_prediction (int): The encoded prediction
83
+ file_name (str): The data file path
84
+
85
+ Returns:
86
+ str: The according disease name
87
+ """
88
+ df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS).drop_duplicates()
89
+ disease_name, _ = df[df[TARGET_COLUMNS[0]] == encoded_prediction].values.flatten()
90
+ return disease_name
91
+
92
+
93
+ def load_data() -> Union[Tuple[pandas.DataFrame, numpy.ndarray], List]:
94
+ """
95
+ Return the data
96
+
97
+ Args:
98
+ None
99
+
100
+ Return:
101
+ The train, testing set and valid symptoms.
102
+ """
103
+ # Load data
104
+ df_train = pandas.read_csv(TRAINING_FILENAME)
105
+ df_test = pandas.read_csv(TESTING_FILENAME)
106
+
107
+ # Separate the traget from the training / testing set:
108
+ # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
109
+ # TARGET_COLUMNS[1] -> "prognosis" -> contains the name of the disease
110
+
111
+ y_train = df_train[TARGET_COLUMNS[0]]
112
+ X_train = df_train.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
113
+
114
+ y_test = df_test[TARGET_COLUMNS[0]]
115
+ X_test = df_test.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
116
+
117
+ return (
118
+ (X_train, X_test),
119
+ (y_train, y_test),
120
+ X_train.columns.to_list(),
121
+ df_train[TARGET_COLUMNS[1]].unique().tolist(),
122
+ )
123
+
124
+
125
+ def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
126
+ """
127
+ Load a pre-trained serialized model
128
+
129
+ Args:
130
+ X_train (pandas.DataFrame): Training set
131
+ y_train (numpy.ndarray): Targets of the training set
132
+
133
+ Return:
134
+ The Concrete ML model and its circuit
135
+ """
136
+ # Parameters
137
+ concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
138
+ classifier = ConcreteXGBoostClassifier(**concrete_args)
139
+ # Train the model
140
+ classifier.fit(X_train, y_train)
141
+ # Compile the model
142
+ circuit = classifier.compile(X_train)
143
+
144
+ return classifier, circuit