Zen0 committed
Commit 731e8c7 · verified · 1 Parent(s): f1edb98

Update tasks/text.py

Files changed (1)
  1. tasks/text.py (+10, -29)
tasks/text.py CHANGED

@@ -9,7 +9,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "FrugalDisinfoHunter Model"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -18,9 +18,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     """
     Evaluate text classification for climate disinformation detection.
 
-    Current Model: Random Baseline
-    - Makes random predictions from the label space (0-7)
-    - Used as a baseline for comparison
+    Current Model: FrugalDisinfoHunter
     """
     # Get space info
     username, space_url = get_space_info()
@@ -55,46 +53,29 @@ async def evaluate_text(request: TextEvaluationRequest):
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
-
-    #--------------------------------------------------------------------------------------------
-    # Load your model and tokenizer from Hugging Face
-    #--------------------------------------------------------------------------------------------
 
     model_name = "Zen0/FrugalDisinfoHunter"  # Model identifier from Hugging Face
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
-    #--------------------------------------------------------------------------------------------
-    # Load the dataset
-    #--------------------------------------------------------------------------------------------
-
-    # Assuming 'quotaclimat/frugalaichallenge-text-train' is the dataset you are working with
-    dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
-
-    # Access the test dataset (you can change this if you want to use a different split)
-    test_dataset = dataset['test']  # Assuming you have a 'test' split available
-
-    #--------------------------------------------------------------------------------------------
-    # Tokenize the text data
-    #--------------------------------------------------------------------------------------------
-
-    # Tokenize the test data (the text field contains the quotes)
     test_texts = test_dataset["text"]  # The field 'text' contains the climate quotes
-
     inputs = tokenizer(test_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
+
+    dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
 
-    #--------------------------------------------------------------------------------------------
-    # Inference
-    #--------------------------------------------------------------------------------------------
+    # Access the test dataset
+    test_dataset = dataset['test']
 
-    # Run inference on the dataset using the model
     with torch.no_grad():  # Disable gradient calculations
         outputs = model(**inputs)
         logits = outputs.logits
 
-    # Get predictions from the logits (choose the class with the highest logit)
+    # Get predictions from the logits
     predictions = torch.argmax(logits, dim=-1).cpu().numpy()  # Convert to numpy array for use
 
+    # Get true labels for accuracy calculation
+    true_labels = test_dataset["label"]  # Extract true labels from the dataset
+
    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------
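One thing worth flagging in the added lines: `test_texts = test_dataset["text"]` now runs before `test_dataset = dataset['test']` is assigned, so the block as committed would raise a NameError unless `test_dataset` is defined elsewhere in the function (not shown in these hunks). Below is a minimal reordered sketch of the same inference block, using only the identifiers that appear in the diff; it is not the file's actual surrounding code.

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the model and tokenizer named in the diff
model_name = "Zen0/FrugalDisinfoHunter"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Load the dataset before reading from it
dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
test_dataset = dataset["test"]

# Tokenize the climate quotes and run a single forward pass without gradients
test_texts = test_dataset["text"]
inputs = tokenizer(test_texts, padding=True, truncation=True,
                   return_tensors="pt", max_length=512)
with torch.no_grad():
    logits = model(**inputs).logits

predictions = torch.argmax(logits, dim=-1).cpu().numpy()
true_labels = test_dataset["label"]  # kept for the accuracy calculation mentioned in the diff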
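The new `true_labels` line is commented "for accuracy calculation", but the metric computation itself sits outside the hunks shown here. A minimal sketch of how the two arrays could be compared, assuming plain NumPy rather than whatever helper the repository actually uses:

import numpy as np

# predictions: np.ndarray from torch.argmax(...).cpu().numpy()
# true_labels: list of integer labels from test_dataset["label"]
accuracy = float(np.mean(predictions == np.asarray(true_labels)))
print(f"Accuracy: {accuracy:.4f}")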