Kevin Fink commited on
Commit
2185441
·
1 Parent(s): 9397fef
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -37,9 +37,11 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
37
  if isinstance(preds, tuple):
38
  preds = preds[0]
39
  # Replace -100s used for padding as we can't decode them
40
- #preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
41
  #preds = np.array(preds)
42
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
 
 
43
  #labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
44
  #labels = np.array(labels)
45
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
@@ -51,7 +53,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
51
  accuracy = accuracy_score(decoded_labels, decoded_preds)
52
  result["eval_accuracy"] = round(accuracy * 100, 4)
53
  return result
54
-
55
  login(api_key.strip())
56
 
57
 
@@ -232,7 +234,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
232
  )
233
 
234
  # Fine-tune the model
235
- trainer.train()
236
  #if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):
237
  #train_result = trainer.train(resume_from_checkpoint=True)
238
  #else:
 
37
  if isinstance(preds, tuple):
38
  preds = preds[0]
39
  # Replace -100s used for padding as we can't decode them
40
+ # preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
41
  #preds = np.array(preds)
42
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
43
+ if isinstance(labels, tuple):
44
+ labels = labels[0]
45
  #labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
46
  #labels = np.array(labels)
47
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
 
53
  accuracy = accuracy_score(decoded_labels, decoded_preds)
54
  result["eval_accuracy"] = round(accuracy * 100, 4)
55
  return result
56
+ tokenizer.
57
  login(api_key.strip())
58
 
59
 
 
234
  )
235
 
236
  # Fine-tune the model
237
+ trainer.evaluate()
238
  #if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):
239
  #train_result = trainer.train(resume_from_checkpoint=True)
240
  #else: