sumesh4C committed
Commit 6278e08 · verified · 1 Parent(s): e08248e

Create text2.py

Files changed (1): tasks/text2.py (+148, −0)
tasks/text2.py ADDED
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Packages needed for inference
from sentence_transformers import SentenceTransformer
import torch
import os

router = APIRouter()

DESCRIPTION = "Embedding + Neural Network"
ROUTE = "/text"

@router.post(ROUTE, tags=["Text Task"],
             description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection.

    Current model: sentence-transformer embeddings fed to a small
    feed-forward neural network.
    - Predicts one of the eight labels in the label space (0-7)
    - Replaces the random baseline previously used for comparison
    """

    # Get space info
    username, space_url = get_space_info()

    # Define the label mapping
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name)

    # Convert string labels to integers
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Split dataset
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    test_dataset = train_test["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")
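    # Note: `tracker` is assumed here to be a codecarbon EmissionsTracker;
    # its stop_task() call below returns an EmissionsData object whose
    # energy_consumed (kWh) and emissions (kg CO2eq) fields feed the
    # results dictionary at the end of this function.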

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE CODE HERE

    # Set the device to MPS if available, otherwise fall back to CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"  # other Sentence Transformers models can be substituted as needed
    sentence_model = SentenceTransformer(model_name, device=str(device))

    # Convert each quote into a vector representation (embedding); pinning the
    # encoder to `device` keeps the returned tensor on the same device as the
    # classifier loaded below
    embeddings = sentence_model.encode(test_dataset["quote"], convert_to_tensor=True)

    # Ground-truth labels for scoring
    true_labels = test_dataset["label"]

    # Architecture of the saved classifier, kept here for reference:
    """
    from torch import nn, optim

    class SimpleNN2(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(SimpleNN2, self).__init__()
            self.fc1 = nn.Linear(input_dim, 128)   # reduce hidden units
            self.fc2 = nn.Linear(128, 64)          # further reduce units
            self.fc3 = nn.Linear(64, output_dim)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.3)         # add dropout
            self.batch_norm1 = nn.BatchNorm1d(128)
            self.batch_norm2 = nn.BatchNorm1d(64)

        def forward(self, x):
            x = self.relu(self.batch_norm1(self.fc1(x)))
            x = self.dropout(x)  # apply dropout
            x = self.relu(self.batch_norm2(self.fc2(x)))
            x = self.dropout(x)  # apply dropout
            x = self.fc3(x)      # output raw logits
            return x
    """

    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)

    # model_nn = torch.load(os.path.join(current_dir, "model_nn.pth"), map_location=device)
    model_nn = torch.jit.load(os.path.join(current_dir, "model_nn_scripted.pth"), map_location=device)

    # Set the model to evaluation mode
    model_nn.eval()

    # Make predictions
    with torch.no_grad():
        outputs = model_nn(embeddings)
        _, predicted = torch.max(outputs, 1)  # get the class with the highest logit

    # Move the predictions back to the CPU as a NumPy array; they are already
    # integers in the 0-7 space defined by LABEL_MAPPING, so no further
    # decoding is needed
    predictions = predicted.cpu().numpy()

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,  # kWh -> Wh
        "emissions_gco2eq": emissions_data.emissions * 1000,          # kg -> g
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results
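
For local testing, here is a minimal sketch of how the route could be exercised. It assumes the router is mounted on a FastAPI app, that TextEvaluationRequest exposes the dataset_name, test_size, and test_seed fields used above, and that the dataset id shown is a placeholder (any dataset passed in must provide "quote" and "label" columns):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from tasks.text2 import router

    app = FastAPI()
    app.include_router(router)

    client = TestClient(app)
    response = client.post("/text", json={
        "dataset_name": "example-org/climate-quotes",  # hypothetical dataset id
        "test_size": 0.2,
        "test_seed": 42,
    })
    print(response.json()["accuracy"])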