rohanshaw commited on
Commit
4deb4e8
·
verified ·
1 Parent(s): 2387eb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -74,16 +74,38 @@ def get_response_with_fallback(prompt):
74
  # If none succeed, return an error message
75
  return "Failed to generate diagnosis with any model."
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @app.post("/predict", response_model=PersonalityResponse)
78
  def predict_personality(request: PersonalityRequest):
79
- # input_datas = list(map(five2one, request.responses))
80
 
81
- input_data = np.array(request.responses).reshape(1, -1, 1)
82
-
83
- prediction = model.predict(input_data)
84
 
85
- personality_type = int(np.argmax(prediction, axis=1)[0])
86
- personality_name = personality_mapping[personality_type]
87
 
88
  return PersonalityResponse(personality_type=personality_type, personality_name=personality_name)
89
 
 
74
  # If none succeed, return an error message
75
  return "Failed to generate diagnosis with any model."
76
 
77
+
78
+ def preprocess_data(new_data_path):
79
+ # Load new data
80
+ new_data = pd.DataFrame(new_data_path)
81
+
82
+ # Scale the data using the same scaler used during training
83
+ scaler = MinMaxScaler(feature_range=(0, 1))
84
+ scaled_data = scaler.fit_transform(new_data)
85
+
86
+ # Reshape the data to fit the model input
87
+ reshaped_data = scaled_data.reshape((scaled_data.shape[1], scaled_data.shape[0], 1))
88
+
89
+ return reshaped_data
90
+
91
+ def predict_clusters(model, preprocessed_data):
92
+ # Predict the cluster for each instance
93
+ predictions = model.predict(preprocessed_data)
94
+
95
+ # Get the cluster with the highest probability
96
+ predicted_clusters = np.argmax(predictions, axis=1)
97
+
98
+ return predicted_clusters
99
+
100
  @app.post("/predict", response_model=PersonalityResponse)
101
  def predict_personality(request: PersonalityRequest):
102
+ preprocessed_data = preprocess_data(request.responses)
103
 
104
+ predicted_clusters = predict_clusters(model, preprocessed_data)
105
+ personality_type = predicted_clusters[0]
 
106
 
107
+ personality_name = personality_mapping[predicted_clusters[0]]
108
+ print(personality_name)
109
 
110
  return PersonalityResponse(personality_type=personality_type, personality_name=personality_name)
111