nhankins commited on
Commit
188ea19
·
verified ·
1 Parent(s): 9685f7b

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +41 -1
tasks/text.py CHANGED
@@ -7,9 +7,12 @@ import random
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
@@ -55,6 +58,43 @@ async def evaluate_text(request: TextEvaluationRequest):
55
  # YOUR MODEL INFERENCE CODE HERE
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Make random predictions (placeholder for actual model inference)
60
  true_labels = test_dataset["label"]
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+ from huggingface_hub import InferenceClient
11
+ import json
12
+
13
  router = APIRouter()
14
 
15
+ DESCRIPTION = "Modified small RoBERTa checkpoint that focuses on emotions"
16
  ROUTE = "/text"
17
 
18
  @router.post(ROUTE, tags=["Text Task"],
 
58
  # YOUR MODEL INFERENCE CODE HERE
59
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
60
  #--------------------------------------------------------------------------------------------
61
+
62
+
63
+ repo_id = "nhankins/frugal_ai_submission"
64
+
65
+
66
+ llm_client = InferenceClient(
67
+
68
+
69
+ model=repo_id,
70
+
71
+ timeout=120,
72
+ )
73
+
74
+
75
+ def call_llm(inference_client: InferenceClient, prompt: str):
76
+
77
+ response = inference_client.post(
78
+
79
+ json={
80
+
81
+ "inputs": prompt,
82
+
83
+ "parameters": {"max_new_tokens": 200},
84
+
85
+ "task": "text-classification",
86
+
87
+ },
88
+
89
+ )
90
+
91
+ return json.loads(response.decode())[0]["generated_label"]
92
+
93
+
94
+
95
+ response=call_llm(llm_client, "climate disinformation here")
96
+
97
+ print (response)
98
 
99
  # Make random predictions (placeholder for actual model inference)
100
  true_labels = test_dataset["label"]