csk99 committed on
Commit
2c1c057
·
verified ·
1 Parent(s): 7cadf7a

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +12 -12
tasks/text.py CHANGED
@@ -73,8 +73,7 @@ async def evaluate_text(request: TextEvaluationRequest):
73
  # Load a pre-trained Sentence-BERT model
74
  print("loading model")
75
  model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
76
- # Generate sentence embeddings
77
- sentence_embeddings = model.encode(test_dataset["quote"])
78
 
79
  #load the models
80
  with open("xgb_bin.pkl","rb") as f:
@@ -84,18 +83,16 @@ async def evaluate_text(request: TextEvaluationRequest):
84
  xgb_multi = pickle.load(f)
85
 
86
 
87
- # Load the binary model
88
- #xgb_bin = xgb.Booster()
89
- #xgb_bin.load_model("xgb_model_bin.bin")
90
-
91
- # Load the binary model
92
- #xgb_multi = xgb.Booster()
93
- #xgb_multi.load_model("xgb_model_muli.bin")
94
 
95
 
96
 
 
97
 
98
-
 
 
 
 
99
  X_train = sentence_embeddings.copy()
100
 
101
  y_train = np.array(test_dataset["label"].copy())
@@ -105,9 +102,12 @@ async def evaluate_text(request: TextEvaluationRequest):
105
  y_train_binary[y_train_binary != 0] = 1
106
 
107
 
 
 
 
108
  #multi class
109
  X_train_multi = X_train[y_train != 0]
110
-
111
  y_train_multi = y_train[y_train != 0]
112
 
113
  logging.info(f"Xtrain_multi_shape:{X_train_multi.shape}")
@@ -125,7 +125,7 @@ async def evaluate_text(request: TextEvaluationRequest):
125
  logging.info(f"y_pred_bin:{y_pred_bin.shape}")
126
  logging.info(f"y_pred_multi shape:{y_pred_multi.shape}")
127
 
128
- y_pred_bin[y_pred_bin==1] = y_pred_multi
129
 
130
 
131
 
 
73
  # Load a pre-trained Sentence-BERT model
74
  print("loading model")
75
  model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
76
+
 
77
 
78
  #load the models
79
  with open("xgb_bin.pkl","rb") as f:
 
83
  xgb_multi = pickle.load(f)
84
 
85
 
 
 
 
 
 
 
 
86
 
87
 
88
 
89
+ logging.info("generating embedding")
90
 
91
+ # Generate sentence embeddings
92
+ sentence_embeddings = model.encode(test_dataset["quote"])
93
+ logging.info(" embedding done")
94
+
95
+
96
  X_train = sentence_embeddings.copy()
97
 
98
  y_train = np.array(test_dataset["label"].copy())
 
102
  y_train_binary[y_train_binary != 0] = 1
103
 
104
 
105
+
106
+
107
+
108
  #multi class
109
  X_train_multi = X_train[y_train != 0]
110
+
111
  y_train_multi = y_train[y_train != 0]
112
 
113
  logging.info(f"Xtrain_multi_shape:{X_train_multi.shape}")
 
125
  logging.info(f"y_pred_bin:{y_pred_bin.shape}")
126
  logging.info(f"y_pred_multi shape:{y_pred_multi.shape}")
127
 
128
+ y_pred_bin[y_train==1] = y_pred_multi
129
 
130
 
131