Medissa committed on
Commit 6f98f11 · verified · 1 Parent(s): 3222d5d

Update tasks/text.py

Files changed (1):
  tasks/text.py +67 -3
tasks/text.py CHANGED
@@ -1,3 +1,31 @@
+import os
+import gc
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+
+import torch
+from tqdm import tqdm
+from typing import Optional, Union
+import pandas as pd
+import numpy as np
+from datasets import Dataset, load_dataset
+from dataclasses import dataclass
+from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, AutoModel
+from transformers import EarlyStoppingCallback
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
+from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
+from sklearn.metrics import recall_score, accuracy_score
+from transformers import DataCollatorWithPadding
+
+VER = 1
+MAX_LEN = 256  # token length used when tokenizing quotes
+TOKENIZER_BINARY = "crarojasca/BinaryAugmentedCARDS"
+BINARY_MODEL = "Medissa/Roberta_Binary"  # stage 1: binary relevance filter
+TOKENIZER_MULTI_CLASS = "crarojasca/TaxonomyAugmentedCARDS"
+MULTI_CLASS_MODEL = "Medissa/Deberta_Taxonomy"  # stage 2: taxonomy classifier
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
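
Note: the taxonomy pass in the second hunk looks predictions up in ID2LABEL, which this commit neither adds nor shows, so it presumably already exists elsewhere in tasks/text.py. If it did not, one plausible way to build it would be from the fine-tuned checkpoint's own config. This is a hypothetical sketch, not part of the commit, and it assumes Medissa/Deberta_Taxonomy was saved with its label names, as transformers checkpoints fine-tuned with an explicit id2label usually are:

# Hypothetical sketch, not part of this commit: recover the id-to-label map
# from the checkpoint config instead of hard-coding it.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(MULTI_CLASS_MODEL)
# PretrainedConfig normalizes id2label keys to ints; int() is just defensive.
ID2LABEL = {int(i): label for i, label in config.id2label.items()}
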
@@ -56,10 +84,46 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
 
-    # Make random predictions (placeholder for actual model inference)
-    true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
+    ## 1. Binary Model: label non-relevant quotes directly
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_BINARY)
+    model = AutoModelForSequenceClassification.from_pretrained(BINARY_MODEL)
+    model.to(device)
+    model.eval()
+
+    predictions = []
+    for text in tqdm(dataset_test_df["quote"]):
+        with torch.no_grad():
+            tokenized_text = tokenizer(text, truncation=True, padding='max_length', max_length=MAX_LEN, return_tensors="pt")
+            inputt = {k: v.to(device) for k, v in tokenized_text.items()}
+            # Running Binary Model
+            outputs = model(**inputt)
+            binary_prediction = torch.argmax(outputs.logits, dim=1)
+            binary_prediction = binary_prediction.to('cpu').item()
+
+        prediction = "0_not_relevant" if binary_prediction == 0 else 1  # int 1 = "relevant", resolved by the taxonomy pass below
+        predictions.append(prediction)
+
+    gc.collect()
+
+    ## 2. Taxonomy Model: classify the quotes the binary pass kept
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MULTI_CLASS)
+    model = AutoModelForSequenceClassification.from_pretrained(MULTI_CLASS_MODEL)
+    model.to(device)
+    model.eval()
+
+    for i, text in tqdm(enumerate(dataset_test_df["quote"])):
+        if isinstance(predictions[i], str):
+            continue  # already labelled "0_not_relevant"
+        with torch.no_grad():
+            tokenized_text = tokenizer(text, truncation=True, padding='max_length', max_length=MAX_LEN, return_tensors="pt")
+            inputt = {k: v.to(device) for k, v in tokenized_text.items()}
+            outputs = model(**inputt)
+            taxonomy_prediction = torch.argmax(outputs.logits, dim=1)
+            taxonomy_prediction = taxonomy_prediction.to('cpu').item()
+
+        prediction = ID2LABEL[taxonomy_prediction]
+        predictions[i] = prediction
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
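
Note on the loops above: every quote is padded to MAX_LEN and run at batch size 1, which is the slowest (and most energy-hungry) way to drive both models on a benchmark that tracks emissions. Below is a minimal batched sketch of the same two-pass idea, reusing tokenizer, model, device, and MAX_LEN from the diff; BATCH_SIZE is an assumed tuning knob, not part of the commit:

# Hedged sketch, not part of this commit: batch the forward passes and use
# dynamic padding so each batch is only padded to its own longest sequence.
BATCH_SIZE = 32  # assumption; tune to fit GPU memory

def predict_class_ids(texts, tokenizer, model):
    """Return argmax class ids for a list of strings, batched."""
    ids = []
    for start in range(0, len(texts), BATCH_SIZE):
        batch = list(texts[start:start + BATCH_SIZE])
        enc = tokenizer(batch, truncation=True, padding=True,
                        max_length=MAX_LEN, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**enc).logits
        ids.extend(logits.argmax(dim=1).tolist())
    return ids

Separately, between the two passes, del model followed by torch.cuda.empty_cache() would actually release the binary model's GPU memory; gc.collect() on its own does not return cached CUDA blocks to the device.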