elucidator8918 committed on
Commit f94cae2 · verified · 1 Parent(s): b8aff33

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +16 -11
tasks/text.py CHANGED
@@ -12,6 +12,16 @@ from safetensors.torch import load_file
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+model_repo = "elucidator8918/frugal-ai-text"
+model = AutoBertClassifier(num_labels=8)
+model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
+tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+model = model.to(device)
+model.eval()
+
 router = APIRouter()
 
 DESCRIPTION = "GTE Architecture"
@@ -70,16 +80,6 @@ async def evaluate_text(request: TextEvaluationRequest):
     true_labels = test_dataset["label"]
     texts = test_dataset["quote"]
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-    model_repo = "elucidator8918/frugal-ai-text"
-    model = AutoBertClassifier(num_labels=8)
-    model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
-    tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-    model = model.to(device)
-    model.eval()
-
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
@@ -94,6 +94,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         truncation=True,
         padding=True,
         return_tensors="pt",
+        max_length=256
     )
 
     with torch.no_grad():
@@ -101,7 +102,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         text_attention_mask = text_encoding["attention_mask"].to(device)
         outputs = model(text_input_ids, text_attention_mask)
         predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
@@ -111,6 +112,8 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Calculate accuracy
     accuracy = accuracy_score(true_labels, predictions)
+
+    print(f"Accuracy = {accuracy}")
 
     # Prepare results dictionary
     results = {
@@ -129,5 +132,7 @@ async def evaluate_text(request: TextEvaluationRequest):
             "test_seed": request.test_seed
         }
     }
+
+    print(results)
 
     return results
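
The main change hoists model setup out of the request handler to module scope, so the weights are downloaded and moved to the device once at import rather than on every call to the endpoint. Below is a minimal runnable sketch of that pattern. The `AutoBertClassifier` stand-in is hypothetical (the real class is defined elsewhere in tasks/text.py and is not part of this diff), and the GTE backbone checkpoint is an assumption based on the `DESCRIPTION = "GTE Architecture"` string.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class AutoBertClassifier(nn.Module):
    # Hypothetical stand-in: a GTE-style encoder with a linear head over the
    # first token. The real class is defined elsewhere in tasks/text.py.
    def __init__(self, num_labels: int, backbone: str = "thenlper/gte-base"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)  # backbone choice is an assumption
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logits = self.classifier(hidden[:, 0])  # pool via the first ([CLS]) token
        return SimpleNamespace(logits=logits)  # mirrors the `outputs.logits` access in the diff


# Loaded once at import time, as the commit moves it to module scope; every
# request then reuses the same weights instead of reloading them. Note the
# checkpoint's state-dict keys must match the real class, not this stand-in.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_repo = "elucidator8918/frugal-ai-text"
model = AutoBertClassifier(num_labels=8)
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
tokenizer = AutoTokenizer.from_pretrained(model_repo)

model = model.to(device)
model.eval()  # inference mode: disables dropout
```

Loading inside `evaluate_text`, as before this commit, would re-download and re-initialize the model on every request; at module scope the cost is paid once when the Space starts.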
 
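The other functional change caps tokenization at 256 tokens. `truncation=True` on its own truncates at the model's maximum length; adding `max_length=256` tightens that bound, which shortens padded batches and reduces inference time and tracked emissions on long inputs. A sketch of the per-request path, reusing the module-level `model`, `tokenizer`, and `device` from the sketch above (the example quotes are made up):

```python
# Illustrative inputs standing in for test_dataset["quote"].
texts = [
    "Example quote one for illustration.",
    "Example quote two, somewhat longer, to show padding at work.",
]

text_encoding = tokenizer(
    texts,
    truncation=True,
    padding=True,
    return_tensors="pt",
    max_length=256,  # new in this commit: hard cap on sequence length
)

with torch.no_grad():  # inference only, so skip gradient bookkeeping
    text_input_ids = text_encoding["input_ids"].to(device)
    text_attention_mask = text_encoding["attention_mask"].to(device)
    outputs = model(text_input_ids, text_attention_mask)
    predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()

print(predictions)  # one predicted label index (0-7) per input text
```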