Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +16 -11
tasks/text.py
CHANGED
@@ -12,6 +12,16 @@ from safetensors.torch import load_file
|
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
router = APIRouter()
|
16 |
|
17 |
DESCRIPTION = "GTE Architecture"
|
@@ -70,16 +80,6 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
70 |
true_labels = test_dataset["label"]
|
71 |
texts = test_dataset["quote"]
|
72 |
|
73 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
74 |
-
|
75 |
-
model_repo = "elucidator8918/frugal-ai-text"
|
76 |
-
model = AutoBertClassifier(num_labels=8)
|
77 |
-
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
78 |
-
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
79 |
-
|
80 |
-
model = model.to(device)
|
81 |
-
model.eval()
|
82 |
-
|
83 |
# Start tracking emissions
|
84 |
tracker.start()
|
85 |
tracker.start_task("inference")
|
@@ -94,6 +94,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
94 |
truncation=True,
|
95 |
padding=True,
|
96 |
return_tensors="pt",
|
|
|
97 |
)
|
98 |
|
99 |
with torch.no_grad():
|
@@ -101,7 +102,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
101 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
102 |
outputs = model(text_input_ids, text_attention_mask)
|
103 |
predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
|
104 |
-
|
105 |
#--------------------------------------------------------------------------------------------
|
106 |
# YOUR MODEL INFERENCE STOPS HERE
|
107 |
#--------------------------------------------------------------------------------------------
|
@@ -111,6 +112,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
111 |
|
112 |
# Calculate accuracy
|
113 |
accuracy = accuracy_score(true_labels, predictions)
|
|
|
|
|
114 |
|
115 |
# Prepare results dictionary
|
116 |
results = {
|
@@ -129,5 +132,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
129 |
"test_seed": request.test_seed
|
130 |
}
|
131 |
}
|
|
|
|
|
132 |
|
133 |
return results
|
|
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
14 |
|
15 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
16 |
+
|
17 |
+
model_repo = "elucidator8918/frugal-ai-text"
|
18 |
+
model = AutoBertClassifier(num_labels=8)
|
19 |
+
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
21 |
+
|
22 |
+
model = model.to(device)
|
23 |
+
model.eval()
|
24 |
+
|
25 |
router = APIRouter()
|
26 |
|
27 |
DESCRIPTION = "GTE Architecture"
|
|
|
80 |
true_labels = test_dataset["label"]
|
81 |
texts = test_dataset["quote"]
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
# Start tracking emissions
|
84 |
tracker.start()
|
85 |
tracker.start_task("inference")
|
|
|
94 |
truncation=True,
|
95 |
padding=True,
|
96 |
return_tensors="pt",
|
97 |
+
max_length=256
|
98 |
)
|
99 |
|
100 |
with torch.no_grad():
|
|
|
102 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
103 |
outputs = model(text_input_ids, text_attention_mask)
|
104 |
predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
|
105 |
+
|
106 |
#--------------------------------------------------------------------------------------------
|
107 |
# YOUR MODEL INFERENCE STOPS HERE
|
108 |
#--------------------------------------------------------------------------------------------
|
|
|
112 |
|
113 |
# Calculate accuracy
|
114 |
accuracy = accuracy_score(true_labels, predictions)
|
115 |
+
|
116 |
+
print(f"Accuracy = {accuracy}")
|
117 |
|
118 |
# Prepare results dictionary
|
119 |
results = {
|
|
|
132 |
"test_seed": request.test_seed
|
133 |
}
|
134 |
}
|
135 |
+
|
136 |
+
print(results)
|
137 |
|
138 |
return results
|