seronk committed on
Commit
5f29886
·
verified ·
1 Parent(s): 2b8c56b

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +22 -4
tasks/text.py CHANGED
@@ -8,7 +8,7 @@ from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  #additional imports
11
- from transformers import Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizerFast
12
  import logging
13
 
14
  router = APIRouter()
@@ -62,11 +62,29 @@ async def evaluate_text(request: TextEvaluationRequest):
62
 
63
  # Make random predictions (placeholder for actual model inference)
64
  true_labels = test_dataset["label"]
65
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
66
 
67
- print("hello_world")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
-
70
 
71
 
72
  #--------------------------------------------------------------------------------------------
 
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  #additional imports
11
+ from transformers import Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizerFast, AutoModelForSequenceClassification,DataCollatorWithPadding
12
  import logging
13
 
14
  router = APIRouter()
 
62
 
63
  # Make random predictions (placeholder for actual model inference)
64
  true_labels = test_dataset["label"]
 
65
 
66
+
67
+ model_name = "seronk/distillbert-frugal-ai"
68
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
69
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
70
+
71
+ def preprocess_function(df):
72
+ return tokenizer(df["quote"], truncation=True)
73
+ tokenized_test = test_dataset.map(preprocess_function, batched=True)
74
+
75
+ training_args = torch.load("./tasks/utils/training_args.bin")
76
+ training_args.eval_strategy='no'
77
+
78
+ trainer = Trainer(
79
+ model=model,
80
+ args=training_args,
81
+ tokenizer=tokenizer
82
+ )
83
+
84
+ ## prediction
85
+ preds = trainer.predict(tokenized_test)
86
+ predictions = np.array([np.argmax(x) for x in preds[0]])
87
 
 
88
 
89
 
90
  #--------------------------------------------------------------------------------------------