Varun Wadhwa commited on
Commit
4fb10c8
·
unverified ·
1 Parent(s): 995cb33
Files changed (1) hide show
  1. app.py +2 -8
app.py CHANGED
@@ -152,11 +152,6 @@ def evaluate_model(model, dataloader, device):
152
  print(len(all_labels))
153
  all_preds = np.asarray(all_preds, dtype=np.float32)
154
  all_labels = np.asarray(all_labels, dtype=np.float32)
155
- print("Flattened sizes")
156
- print(all_preds.size)
157
- print(all_labels.size)
158
- all_preds = all_preds.flatten()
159
- all_labels = all_labels.flatten()
160
  accuracy = accuracy_score(all_labels, all_preds)
161
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
162
 
@@ -199,9 +194,8 @@ dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size, collate_
199
  # create testing data loader
200
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)
201
 
202
- # TEMPORARY - for testing
203
- teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
204
- print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
205
 
206
  # put student model in train mode
207
  student_model.train()
 
152
  print(len(all_labels))
153
  all_preds = np.asarray(all_preds, dtype=np.float32)
154
  all_labels = np.asarray(all_labels, dtype=np.float32)
 
 
 
 
 
155
  accuracy = accuracy_score(all_labels, all_preds)
156
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
157
 
 
194
  # create testing data loader
195
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)
196
 
197
+ untrained_student_accuracy, untrained_student_precision, untrained_student_recall, untrained_student_f1 = evaluate_model(student_model, test_dataloader, device)
198
+ print(f"Untrained Student (test) - Accuracy: {untrained_student_accuracy:.4f}, Precision: {untrained_student_precision:.4f}, Recall: {untrained_student_recall:.4f}, F1 Score: {untrained_student_f1:.4f}")
 
199
 
200
  # put student model in train mode
201
  student_model.train()