elucidator8918 commited on
Commit
c7e99c2
·
verified ·
1 Parent(s): f94cae2

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +12 -12
tasks/text.py CHANGED
@@ -12,18 +12,6 @@ from safetensors.torch import load_file
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
14
 
15
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
-
17
- model_repo = "elucidator8918/frugal-ai-text"
18
- model = AutoBertClassifier(num_labels=8)
19
- model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
20
- tokenizer = AutoTokenizer.from_pretrained(model_repo)
21
-
22
- model = model.to(device)
23
- model.eval()
24
-
25
- router = APIRouter()
26
-
27
  DESCRIPTION = "GTE Architecture"
28
  ROUTE = "/text"
29
 
@@ -44,6 +32,18 @@ class AutoBertClassifier(nn.Module):
44
  logits = self.classifier(pooled_output)
45
  return logits
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @router.post(ROUTE, tags=["Text Task"],
48
  description=DESCRIPTION)
49
  async def evaluate_text(request: TextEvaluationRequest):
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  DESCRIPTION = "GTE Architecture"
16
  ROUTE = "/text"
17
 
 
32
  logits = self.classifier(pooled_output)
33
  return logits
34
 
35
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
36
+
37
+ model_repo = "elucidator8918/frugal-ai-text"
38
+ model = AutoBertClassifier(num_labels=8)
39
+ model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
40
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
41
+
42
+ model = model.to(device)
43
+ model.eval()
44
+
45
+ router = APIRouter()
46
+
47
  @router.post(ROUTE, tags=["Text Task"],
48
  description=DESCRIPTION)
49
  async def evaluate_text(request: TextEvaluationRequest):