sumesh4C committed on
Commit
3a14b2c
·
verified ·
1 Parent(s): 884821c

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +21 -4
tasks/text.py CHANGED
@@ -11,6 +11,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
  from sentence_transformers import SentenceTransformer
12
  from xgboost import XGBClassifier
13
  import pickle
 
14
 
15
 
16
  router = APIRouter()
@@ -61,10 +62,13 @@ async def evaluate_text(request: TextEvaluationRequest):
61
  # YOUR MODEL INFERENCE CODE HERE
62
 
63
  #Load the embedding model
64
- model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
 
 
65
 
66
  # Convert each sentence into a vector representation (embedding)
67
  embeddings = model.encode(test_dataset['quote'].tolist())
 
68
 
69
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
70
  #--------------------------------------------------------------------------------------------
@@ -74,11 +78,24 @@ async def evaluate_text(request: TextEvaluationRequest):
74
 
75
 
76
  #load the xgboost model
77
- with open("models/stella_400_xgb_500.pkl",'rb') as f:
78
- xgbclassifier = pickle.load(f)
 
 
 
 
 
79
 
80
  #make inference
81
- predictions = xgbclassifier.predict(embeddings)
 
 
 
 
 
 
 
 
82
 
83
  #--------------------------------------------------------------------------------------------
84
  # YOUR MODEL INFERENCE STOPS HERE
 
11
  from sentence_transformers import SentenceTransformer
12
  from xgboost import XGBClassifier
13
  import pickle
14
+ import torch
15
 
16
 
17
  router = APIRouter()
 
62
  # YOUR MODEL INFERENCE CODE HERE
63
 
64
  #Load the embedding model
65
+ #model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
66
+ model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" # You can use other Sentence Transformers models as needed
67
+ sentence_model = SentenceTransformer(model_name)
68
 
69
  # Convert each sentence into a vector representation (embedding)
70
  embeddings = model.encode(test_dataset['quote'].tolist())
71
+ embeddings = embeddings.cpu()
72
 
73
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
74
  #--------------------------------------------------------------------------------------------
 
78
 
79
 
80
  #load the xgboost model
81
+ #with open("models/stella_400_xgb_500.pkl",'rb') as f:
82
+ # xgbclassifier = pickle.load(f)
83
+
84
+ model_nn = torch.load("models/model_nn.pth")
85
+
86
+ # Set the model to evaluation mode
87
+ model_nn.eval()
88
 
89
  #make inference
90
+ #predictions = xgbclassifier.predict(embeddings)
91
+
92
+ # Make predictions
93
+ with torch.no_grad():
94
+ outputs = model_nn(text_embeddings)
95
+ _, predicted = torch.max(outputs, 1) # Get the class with the highest score
96
+
97
+ # Decode the predictions back to original labels using label_encoder
98
+ predicted_labels = label_encoder.inverse_transform(predicted.cpu().numpy())
99
 
100
  #--------------------------------------------------------------------------------------------
101
  # YOUR MODEL INFERENCE STOPS HERE