Update tasks/text.py
tasks/text.py   CHANGED   (+14, -4)
@@ -16,6 +16,9 @@ from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
 import numpy as np
 from sklearn.metrics import recall_score, accuracy_score
 from transformers import DataCollatorWithPadding
+import logging
+# import mylib
+logger = logging.getLogger(__name__)
 
 VER = 1
 MAX_LEN = 256
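The new logger is created with logging.getLogger(__name__) but never configured in this hunk. Without configuration, Python's last-resort handler only emits WARNING and above, so the logger.info(...) calls added below would stay silent. A minimal sketch of the setup this assumes, run once at startup (the format string and placement are illustrative, not part of the commit):

import logging

# Configure the root logger once so INFO records reach stderr,
# where the Space's container logs will pick them up.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

logger = logging.getLogger(__name__)
logger.info("logging configured")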
@@ -92,15 +95,17 @@ async def evaluate_text(request: TextEvaluationRequest):
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
-
+    logger.info('Start Binary')
     # Binary Model
     tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_BINARY)
+    logger.info('Loaded Tokenizer')
     model = AutoModelForSequenceClassification.from_pretrained(BINARY_MODEL)
+    logger.info('Loaded Model')
     model.to(device)
     model.eval()
 
     predictions = []
-    for text in tqdm(test_dataset["quote"]):
+    for i,text in tqdm(enumerate(test_dataset["quote"])):
         with torch.no_grad():
             tokenized_text = tokenizer(text, truncation=True, padding='max_length', return_tensors = "pt")
             inputt = {k:v.to(device) for k,v in tokenized_text.items()}
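This loop tokenizes one quote per step with padding='max_length', so every example is padded to the full MAX_LEN and the model runs at batch size 1. A hedged sketch of batched inference under the same names (tokenizer, model, device, and test_dataset come from the file; BATCH_SIZE is an assumption to tune against the Space's memory):

import torch

BATCH_SIZE = 32  # assumption, not from the commit
quotes = list(test_dataset["quote"])
binary_preds = []
with torch.no_grad():
    for start in range(0, len(quotes), BATCH_SIZE):
        batch = quotes[start:start + BATCH_SIZE]
        # Pad dynamically to the longest quote in the batch instead of max_length.
        enc = tokenizer(batch, truncation=True, padding=True, return_tensors="pt").to(device)
        logits = model(**enc).logits
        binary_preds.extend(logits.argmax(dim=-1).tolist())

Dynamic padding plus batching usually cuts the tracked inference time, and therefore the measured energy, substantially compared with per-example max-length padding.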
@@ -111,15 +116,19 @@ async def evaluate_text(request: TextEvaluationRequest):
 
             prediction = "0_not_relevant" if binary_prediction==0 else 1
             predictions.append(prediction)
+            if i%10:
+                logger.info(f'iteration: {i}')
 
     gc.collect()
 
     ## 2. Taxonomy Model
+    logger.info('Start Multi')
     tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MULTI_CLASS)
+    logger.info('Loaded Tokenizer')
     model = AutoModelForSequenceClassification.from_pretrained(MULTI_CLASS_MODEL)
     model.to(device)
     model.eval()
-
+    logger.info('Loaded Model')
     for i,text in tqdm(enumerate(test_dataset["quote"])):
         if isinstance(predictions[i], str):
             continue
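One behavioral detail in the added progress logs: if i%10: is truthy for every index that is not a multiple of 10, so the message fires on 9 out of every 10 iterations. If the intent was to log every tenth iteration, the usual guard inverts the test:

# Log every 10th iteration instead of all but every 10th.
if i % 10 == 0:
    logger.info(f'iteration: {i}')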
@@ -132,7 +141,8 @@ async def evaluate_text(request: TextEvaluationRequest):
 
             prediction = ID2LABEL[taxonomy_prediction]
             predictions[i] = prediction
-
+            if i%10:
+                logger.info(f'iteration: {i}')
     predictions = [LABEL_MAPPING[pred] for pred in predictions]
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
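The gc.collect() between the two stages reclaims Python objects, but while model still references the binary network its weights stay resident (on the GPU, if the Space has one) until the name is rebound by the taxonomy load. A hedged sketch of a fuller cleanup before loading MULTI_CLASS_MODEL, using standard PyTorch calls the commit does not include:

import gc
import torch

del model, tokenizer          # drop the binary-stage references
gc.collect()                  # reclaim Python-side objects
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached CUDA blocks to the driver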
|