Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +12 -12
tasks/text.py
CHANGED
@@ -12,18 +12,6 @@ from safetensors.torch import load_file
|
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
14 |
|
15 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
16 |
-
|
17 |
-
model_repo = "elucidator8918/frugal-ai-text"
|
18 |
-
model = AutoBertClassifier(num_labels=8)
|
19 |
-
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
20 |
-
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
21 |
-
|
22 |
-
model = model.to(device)
|
23 |
-
model.eval()
|
24 |
-
|
25 |
-
router = APIRouter()
|
26 |
-
|
27 |
DESCRIPTION = "GTE Architecture"
|
28 |
ROUTE = "/text"
|
29 |
|
@@ -44,6 +32,18 @@ class AutoBertClassifier(nn.Module):
|
|
44 |
logits = self.classifier(pooled_output)
|
45 |
return logits
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
@router.post(ROUTE, tags=["Text Task"],
|
48 |
description=DESCRIPTION)
|
49 |
async def evaluate_text(request: TextEvaluationRequest):
|
|
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
DESCRIPTION = "GTE Architecture"
|
16 |
ROUTE = "/text"
|
17 |
|
|
|
32 |
logits = self.classifier(pooled_output)
|
33 |
return logits
|
34 |
|
35 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
36 |
+
|
37 |
+
model_repo = "elucidator8918/frugal-ai-text"
|
38 |
+
model = AutoBertClassifier(num_labels=8)
|
39 |
+
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
41 |
+
|
42 |
+
model = model.to(device)
|
43 |
+
model.eval()
|
44 |
+
|
45 |
+
router = APIRouter()
|
46 |
+
|
47 |
@router.post(ROUTE, tags=["Text Task"],
|
48 |
description=DESCRIPTION)
|
49 |
async def evaluate_text(request: TextEvaluationRequest):
|