Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files
- tasks/text.py +12 -12
tasks/text.py
CHANGED
@@ -73,8 +73,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
73 |
# Load a pre-trained Sentence-BERT model
|
74 |
print("loading model")
|
75 |
model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
|
76 |
-
|
77 |
-
sentence_embeddings = model.encode(test_dataset["quote"])
|
78 |
|
79 |
#load the models
|
80 |
with open("xgb_bin.pkl","rb") as f:
|
@@ -84,18 +83,16 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
84 |
xgb_multi = pickle.load(f)
|
85 |
|
86 |
|
87 |
-
# Load the binary model
|
88 |
-
#xgb_bin = xgb.Booster()
|
89 |
-
#xgb_bin.load_model("xgb_model_bin.bin")
|
90 |
-
|
91 |
-
# Load the binary model
|
92 |
-
#xgb_multi = xgb.Booster()
|
93 |
-
#xgb_multi.load_model("xgb_model_muli.bin")
|
94 |
|
95 |
|
96 |
|
|
|
97 |
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
X_train = sentence_embeddings.copy()
|
100 |
|
101 |
y_train = np.array(test_dataset["label"].copy())
|
@@ -105,9 +102,12 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
105 |
y_train_binary[y_train_binary != 0] = 1
|
106 |
|
107 |
|
|
|
|
|
|
|
108 |
#multi class
|
109 |
X_train_multi = X_train[y_train != 0]
|
110 |
-
|
111 |
y_train_multi = y_train[y_train != 0]
|
112 |
|
113 |
logging.info(f"Xtrain_multi_shape:{X_train_multi.shape}")
|
@@ -125,7 +125,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
125 |
logging.info(f"y_pred_bin:{y_pred_bin.shape}")
|
126 |
logging.info(f"y_pred_multi shape:{y_pred_multi.shape}")
|
127 |
|
128 |
-
y_pred_bin[
|
129 |
|
130 |
|
131 |
|
|
|
73 |
# Load a pre-trained Sentence-BERT model
|
74 |
print("loading model")
|
75 |
model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
|
76 |
+
|
|
|
77 |
|
78 |
#load the models
|
79 |
with open("xgb_bin.pkl","rb") as f:
|
|
|
83 |
xgb_multi = pickle.load(f)
|
84 |
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
|
88 |
|
89 |
+
logging.info("generating embedding")
|
90 |
|
91 |
+
# Generate sentence embeddings
|
92 |
+
sentence_embeddings = model.encode(test_dataset["quote"])
|
93 |
+
logging.info(" embedding done")
|
94 |
+
|
95 |
+
|
96 |
X_train = sentence_embeddings.copy()
|
97 |
|
98 |
y_train = np.array(test_dataset["label"].copy())
|
|
|
102 |
y_train_binary[y_train_binary != 0] = 1
|
103 |
|
104 |
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
#multi class
|
109 |
X_train_multi = X_train[y_train != 0]
|
110 |
+
|
111 |
y_train_multi = y_train[y_train != 0]
|
112 |
|
113 |
logging.info(f"Xtrain_multi_shape:{X_train_multi.shape}")
|
|
|
125 |
logging.info(f"y_pred_bin:{y_pred_bin.shape}")
|
126 |
logging.info(f"y_pred_multi shape:{y_pred_multi.shape}")
|
127 |
|
128 |
+
y_pred_bin[y_train==1] = y_pred_multi
|
129 |
|
130 |
|
131 |
|