Tenefix 80cols commited on
Commit
43cc119
·
verified ·
1 Parent(s): ad9e544

Update predictor.py (#3)

Browse files

- Update predictor.py (5a3b7b14551571ba0955b56a644f9968440c30e8)


Co-authored-by: Sandro Ferroni <[email protected]>

Files changed (1) hide show
  1. predictor.py +71 -17
predictor.py CHANGED
@@ -3,9 +3,13 @@ import joblib
3
  import numpy as np
4
  from concrete.ml.deployment import FHEModelClient, FHEModelServer
5
  import logging
 
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
 
 
 
9
 
10
  # Paths to required files
11
  SCALER_PATH = os.path.join("models", "scaler.pkl")
@@ -32,40 +36,90 @@ except FileNotFoundError:
32
  # Load evaluation keys
33
  evaluation_keys = client.get_serialized_evaluation_keys()
34
 
35
- def predict(input_data):
36
  """
37
  Perform a local prediction using the compiled FHE model.
38
 
39
- Args:
40
- input_data (dict): User input data as a dictionary.
41
-
42
  Returns:
43
- str: Prediction result ("Fraudulent" or "Non-fraudulent").
 
44
  """
 
 
 
45
  try:
46
- logging.info(f"Input Data: {input_data}")
47
-
48
- # Scale the input data
49
- scaled_data = scaler.transform([list(input_data.values())])
50
- logging.info(f"Scaled Data: {scaled_data}")
51
-
52
- # Encrypt the scaled data
53
- encrypted_data = client.quantize_encrypt_serialize(scaled_data)
54
- logging.info("Data encrypted successfully.")
55
-
56
  # Execute the model locally on encrypted data
57
  encrypted_prediction = server.run(
58
  encrypted_data, serialized_evaluation_keys=evaluation_keys
59
  )
60
  logging.info(f"Encrypted Prediction: {encrypted_prediction}")
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
62
  # Decrypt the prediction result
63
  decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
64
  logging.info(f"Decrypted Prediction: {decrypted_prediction}")
65
 
66
  # Interpret the prediction
67
  binary_prediction = int(np.argmax(decrypted_prediction))
68
- return "Fraudulent" if binary_prediction == 1 else "Non-fraudulent"
 
69
  except Exception as e:
70
  logging.error(f"Error during prediction: {e}")
71
- return "Error during prediction"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  from concrete.ml.deployment import FHEModelClient, FHEModelServer
5
  import logging
6
+ import gradio as gr
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
+ key_already_generated_condition = False
11
+ encrypted_data = None
12
+ encrypted_prediction = None
13
 
14
  # Paths to required files
15
  SCALER_PATH = os.path.join("models", "scaler.pkl")
 
36
  # Load evaluation keys
37
  evaluation_keys = client.get_serialized_evaluation_keys()
38
 
39
+ def predict():
40
  """
41
  Perform a local prediction using the compiled FHE model.
42
 
 
 
 
43
  Returns:
44
+ str: The prediction result.
45
+ str: A message indicating the status of the prediction.
46
  """
47
+ global encrypted_data, encrypted_prediction
48
+ if encrypted_data is None:
49
+ return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
50
  try:
 
 
 
 
 
 
 
 
 
 
51
  # Execute the model locally on encrypted data
52
  encrypted_prediction = server.run(
53
  encrypted_data, serialized_evaluation_keys=evaluation_keys
54
  )
55
  logging.info(f"Encrypted Prediction: {encrypted_prediction}")
56
+ return encrypted_prediction, gr.update(value="FHE evaluation is done. ✅")
57
+
58
+ except Exception as e:
59
+ logging.error(f"Error during prediction: {e}")
60
+ return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
61
+
62
+ def decrypt_prediction():
63
+ """
64
+ Decrypt and interpret the prediction result.
65
 
66
+ Returns:
67
+ str: The interpreted prediction result.
68
+ """
69
+ global encrypted_prediction
70
+ if encrypted_prediction is None:
71
+ return "No prediction to decrypt. Please make a prediction first. ❌", "No prediction to decrypt. Please make a prediction first. ❌"
72
+ try:
73
  # Decrypt the prediction result
74
  decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
75
  logging.info(f"Decrypted Prediction: {decrypted_prediction}")
76
 
77
  # Interpret the prediction
78
  binary_prediction = int(np.argmax(decrypted_prediction))
79
+ return "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊", gr.update(value="Decryption successful ✅")
80
+
81
  except Exception as e:
82
  logging.error(f"Error during prediction: {e}")
83
+ return "Error during prediction", "Error during prediction❌"
84
+
85
+ def key_already_generated():
86
+ """
87
+ Check if the evaluation keys have already been generated.
88
+
89
+ Returns:
90
+ bool: True if the evaluation keys have already been generated, False otherwise.
91
+ """
92
+ global key_already_generated_condition
93
+ if evaluation_keys:
94
+ key_already_generated_condition = True
95
+ return True
96
+ return False
97
+
98
+ def pre_process_encrypt_send_purchase(*inputs):
99
+ """
100
+ Pre-processes, encrypts, and sends the purchase data for prediction.
101
+
102
+ Args:
103
+ *inputs: Variable number of input arguments.
104
+
105
+ Returns:
106
+ (str): A short representation of the encrypted input to send in hex.
107
+ """
108
+ global key_already_generated_condition, encrypted_data
109
+ if key_already_generated_condition == False:
110
+ return None, gr.update(value="Generate your key before. ❌")
111
+ try:
112
+ key_already_generated_condition = True
113
+ logging.info(f"Input Data: {inputs}")
114
+
115
+ # Scale the input data
116
+ scaled_data = scaler.transform([list(inputs)])
117
+ logging.info(f"Scaled Data: {scaled_data}")
118
+
119
+ # Encrypt the scaled data
120
+ encrypted_data = client.quantize_encrypt_serialize(scaled_data)
121
+ logging.info("Data encrypted successfully.")
122
+ return encrypted_data, gr.update(value="Inputs are encrypted and sent to server. ✅")
123
+ except Exception as e:
124
+ logging.error(f"Error during pre-processing: {e}")
125
+ return "Error during pre-processing"