romanbredehoft-zama commited on
Commit
a241bb3
·
1 Parent(s): 993f2a6

Impose correct column order in pre-processing

Browse files
backend.py CHANGED
@@ -21,6 +21,8 @@ from settings import (
21
  PRE_PROCESSOR_USER_PATH,
22
  PRE_PROCESSOR_THIRD_PARTY_PATH,
23
  CLIENT_TYPES,
 
 
24
  )
25
 
26
  from utils.client_server_interface import MultiInputsFHEModelClient
@@ -270,6 +272,8 @@ def pre_process_encrypt_send_user(client_id, *inputs):
270
  "Housing_type": [housing_type],
271
  })
272
 
 
 
273
  preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
274
 
275
  return _encrypt_send(client_id, preprocessed_user_inputs, "user")
@@ -311,6 +315,8 @@ def pre_process_encrypt_send_third_party(client_id, *inputs):
311
  "Years_employed": [years_salaried],
312
  })
313
 
 
 
314
  preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
315
 
316
  return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
 
21
  PRE_PROCESSOR_USER_PATH,
22
  PRE_PROCESSOR_THIRD_PARTY_PATH,
23
  CLIENT_TYPES,
24
+ USER_COLUMNS,
25
+ THIRD_PARTY_COLUMNS,
26
  )
27
 
28
  from utils.client_server_interface import MultiInputsFHEModelClient
 
272
  "Housing_type": [housing_type],
273
  })
274
 
275
+ user_inputs = user_inputs.reindex(USER_COLUMNS, axis=1)
276
+
277
  preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
278
 
279
  return _encrypt_send(client_id, preprocessed_user_inputs, "user")
 
315
  "Years_employed": [years_salaried],
316
  })
317
 
318
+ third_party_inputs = third_party_inputs.reindex(THIRD_PARTY_COLUMNS, axis=1)
319
+
320
  preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
321
 
322
  return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
deployment_files/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4c826efceb2e6c4d9fd1d3876d7adae10537814add6ae3f08b5dab9ae23f76b
3
- size 76339
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c0a655d225d0f31642c20c8f3e5537505b6b6904ad8af7636631024cf6e18b6
3
+ size 76383
deployment_files/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:adc3d696a290278148d2ac906018a3a58d3c545290f6fdb60a82a3f2e7eea531
3
- size 3322
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5570f7dfda2d5ced4a6bd411d9d2eba67b8bcbd523efac803be66abd4368a99
3
+ size 3321
development.py CHANGED
@@ -9,10 +9,20 @@ from sklearn.model_selection import train_test_split
9
  from sklearn.metrics import accuracy_score
10
  from imblearn.over_sampling import SMOTE
11
 
12
- from settings import DEPLOYMENT_PATH, RANDOM_STATE, DATA_PATH, INPUT_SLICES, PRE_PROCESSOR_USER_PATH, PRE_PROCESSOR_THIRD_PARTY_PATH
 
 
 
 
 
 
 
 
 
 
13
  from utils.client_server_interface import MultiInputsFHEModelDev
14
  from utils.model import MultiInputXGBClassifier
15
- from utils.pre_processing import get_pre_processors, select_and_pop_features
16
 
17
 
18
  def get_processed_multi_inputs(data):
@@ -39,9 +49,9 @@ data_y = data.pop("Target").copy()
39
  data_x = data.copy()
40
 
41
  # Get data from all parties
42
- data_third_party = select_and_pop_features(data_x, ["Years_employed", "Salaried"])
43
- data_bank = select_and_pop_features(data_x, ["Account_length"])
44
- data_user = data_x.copy()
45
 
46
  # Feature engineer the data
47
  pre_processor_user, pre_processor_third_party = get_pre_processors()
 
9
  from sklearn.metrics import accuracy_score
10
  from imblearn.over_sampling import SMOTE
11
 
12
+ from settings import (
13
+ DEPLOYMENT_PATH,
14
+ RANDOM_STATE,
15
+ DATA_PATH,
16
+ INPUT_SLICES,
17
+ PRE_PROCESSOR_USER_PATH,
18
+ PRE_PROCESSOR_THIRD_PARTY_PATH,
19
+ USER_COLUMNS,
20
+ BANK_COLUMNS,
21
+ THIRD_PARTY_COLUMNS,
22
+ )
23
  from utils.client_server_interface import MultiInputsFHEModelDev
24
  from utils.model import MultiInputXGBClassifier
25
+ from utils.pre_processing import get_pre_processors
26
 
27
 
28
  def get_processed_multi_inputs(data):
 
49
  data_x = data.copy()
50
 
51
  # Get data from all parties
52
+ data_user = data_x[USER_COLUMNS].copy()
53
+ data_bank = data_x[BANK_COLUMNS].copy()
54
+ data_third_party = data_x[THIRD_PARTY_COLUMNS].copy()
55
 
56
  # Feature engineer the data
57
  pre_processor_user, pre_processor_third_party = get_pre_processors()
settings.py CHANGED
@@ -29,7 +29,7 @@ SERVER_URL = "http://localhost:8000/"
29
  # files
30
  DATA_PATH = "data/data.csv"
31
 
32
- # Developement settings
33
  RANDOM_STATE = 0
34
  INITIAL_INPUT_SHAPE = (1, 49)
35
 
@@ -45,6 +45,14 @@ INPUT_SLICES = {
45
  "third_party": slice(43, 49), # Third position: start from n_feature_user + n_feature_bank
46
  }
47
 
 
 
 
 
 
 
 
 
48
  _data = pandas.read_csv(DATA_PATH, encoding="utf-8")
49
 
50
  def get_min_max(data, column):
 
29
  # files
30
  DATA_PATH = "data/data.csv"
31
 
32
+ # Development settings
33
  RANDOM_STATE = 0
34
  INITIAL_INPUT_SHAPE = (1, 49)
35
 
 
45
  "third_party": slice(43, 49), # Third position: start from n_feature_user + n_feature_bank
46
  }
47
 
48
+ USER_COLUMNS = [
49
+ 'Own_car', 'Own_property', 'Work_phone', 'Phone', 'Email', 'Num_children', 'Household_size',
50
+ 'Total_income', 'Age', 'Income_type', 'Education_type', 'Family_status', 'Housing_type',
51
+ 'Occupation_type',
52
+ ]
53
+ BANK_COLUMNS = ["Account_length"]
54
+ THIRD_PARTY_COLUMNS = ["Years_employed", "Salaried"]
55
+
56
  _data = pandas.read_csv(DATA_PATH, encoding="utf-8")
57
 
58
  def get_min_max(data, column):
utils/pre_processing.py CHANGED
@@ -83,10 +83,4 @@ def get_pre_processors():
83
  verbose_feature_names_out=False,
84
  )
85
 
86
- return pre_processor_user, pre_processor_third_party
87
-
88
-
89
- def select_and_pop_features(data, columns):
90
- new_data = data[columns].copy()
91
- data.drop(columns, axis=1, inplace=True)
92
- return new_data
 
83
  verbose_feature_names_out=False,
84
  )
85
 
86
+ return pre_processor_user, pre_processor_third_party