Reyad-Ahmmed commited on
Commit
b497dbc
·
verified ·
1 Parent(s): 31d163b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -104,15 +104,15 @@ if (runModel=='1'):
104
  # Create an instance of the custom loss function
105
  training_args = TrainingArguments(
106
  output_dir='./results_' + modelNameToUse,
107
- num_train_epochs=25,
108
- per_device_train_batch_size=2,
109
- per_device_eval_batch_size=2,
110
  warmup_steps=500,
111
  weight_decay=0.02,
112
  logging_dir='./logs_' + modelNameToUse,
113
  logging_steps=10,
114
  evaluation_strategy="epoch",
115
-
116
  )
117
 
118
  trainer = Trainer(
@@ -167,7 +167,13 @@ if (runModel=='1'):
167
  else:
168
  print("\nNo incorrect predictions found.")
169
 
170
- train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
 
 
 
 
 
 
171
  evaluate_and_report_errors(model,train_dataloader, tokenizer)
172
 
173
  model_path = './' + modelNameToUse + '_model'
 
104
  # Create an instance of the custom loss function
105
  training_args = TrainingArguments(
106
  output_dir='./results_' + modelNameToUse,
107
+ num_train_epochs=5,
108
+ per_device_train_batch_size=8,
109
+ per_device_eval_batch_size=8,
110
  warmup_steps=500,
111
  weight_decay=0.02,
112
  logging_dir='./logs_' + modelNameToUse,
113
  logging_steps=10,
114
  evaluation_strategy="epoch",
115
+ load_best_model_at_end=True, # Load the best model based on evaluation
116
  )
117
 
118
  trainer = Trainer(
 
167
  else:
168
  print("\nNo incorrect predictions found.")
169
 
170
+ train_dataloader = DataLoader(
171
+ train_dataset,
172
+ batch_size=10,
173
+ shuffle=True,
174
+ num_workers=4 # Increase workers for faster data loading
175
+ )
176
+
177
  evaluate_and_report_errors(model,train_dataloader, tokenizer)
178
 
179
  model_path = './' + modelNameToUse + '_model'