Varun Wadhwa commited on
Commit
2343812
·
unverified ·
1 Parent(s): 29d7652
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -5,7 +5,7 @@ from datasets import load_dataset
5
 
6
  import numpy as np
7
  import os
8
- from sklearn.metrics import accuracy_score, precision_recall_fscore_support
9
 
10
  import torch
11
  import torch.nn as nn
@@ -170,10 +170,11 @@ def evaluate_model(model, dataloader, device):
170
  print(len(all_labels))
171
  all_preds = np.asarray(all_preds, dtype=np.float32)
172
  all_labels = np.asarray(all_labels, dtype=np.float32)
 
173
  accuracy = accuracy_score(all_labels, all_preds)
174
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
175
 
176
- return accuracy, precision, recall, f1
177
 
178
  # Function to compute distillation and hard-label loss
179
  def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
@@ -212,8 +213,10 @@ dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size, collate_
212
  # create testing data loader
213
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)
214
 
215
- untrained_student_accuracy, untrained_student_precision, untrained_student_recall, untrained_student_f1 = evaluate_model(student_model, test_dataloader, device)
216
- print(f"Untrained Student (test) - Accuracy: {untrained_student_accuracy:.4f}, Precision: {untrained_student_precision:.4f}, Recall: {untrained_student_recall:.4f}, F1 Score: {untrained_student_f1:.4f}")
 
 
217
 
218
  # put student model in train mode
219
  student_model.train()
@@ -248,28 +251,22 @@ for epoch in range(num_epochs):
248
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator, shuffle=True)
249
 
250
  # Evaluate the teacher model
251
- teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
252
- print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
 
 
 
253
 
254
  # Evaluate the student model
255
- student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
256
- print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
 
 
257
  print("\n")
258
 
259
  # put student model back into train mode
260
  student_model.train()
261
 
262
- #Compare the models
263
- # create testing data loader
264
- validation_dataloader = DataLoader(tokenized_data['test'], batch_size=8, collate_fn=data_collator)
265
- # Evaluate the teacher model
266
- teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
267
- print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
268
- # Evaluate the student model
269
- student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
270
- print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
271
-
272
-
273
  st.write('Pushing model to huggingface')
274
 
275
  # Push model to huggingface
 
5
 
6
  import numpy as np
7
  import os
8
+ from sklearn.metrics import classification_report, accuracy_score, precision_recall_fscore_support
9
 
10
  import torch
11
  import torch.nn as nn
 
170
  print(len(all_labels))
171
  all_preds = np.asarray(all_preds, dtype=np.float32)
172
  all_labels = np.asarray(all_labels, dtype=np.float32)
173
+ report = classification_report(all_labels, all_preds, target_names=id2label.values(), zero_division=0)
174
  accuracy = accuracy_score(all_labels, all_preds)
175
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
176
 
177
+ return report, accuracy, precision, recall, f1
178
 
179
  # Function to compute distillation and hard-label loss
180
  def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
 
213
  # create testing data loader
214
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)
215
 
216
+ untrained_student_report, untrained_student_accuracy, untrained_student_precision, untrained_student_recall, untrained_student_f1 = evaluate_model(student_model, test_dataloader, device)
217
+ print(f"Untrained Student (test) - Report:")
218
+ print(untrained_student_report)
219
+ print(f"Accuracy: {untrained_student_accuracy:.4f}, Precision: {untrained_student_precision:.4f}, Recall: {untrained_student_recall:.4f}, F1 Score: {untrained_student_f1:.4f}")
220
 
221
  # put student model in train mode
222
  student_model.train()
 
251
  test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator, shuffle=True)
252
 
253
  # Evaluate the teacher model
254
+ teacher_report, teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
255
+ print(f"Teacher (test) - Report:")
256
+ print(teacher_report)
257
+ print(f"Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
258
+ print("\n")
259
 
260
  # Evaluate the student model
261
+ student_report, student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
262
+ print(f"Student (test) - Report:")
263
+ print(student_report)
264
+ print(f"Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
265
  print("\n")
266
 
267
  # put student model back into train mode
268
  student_model.train()
269
 
 
 
 
 
 
 
 
 
 
 
 
270
  st.write('Pushing model to huggingface')
271
 
272
  # Push model to huggingface