Kevin Fink committed on
Commit
1ec3de2
·
1 Parent(s): 759ca46
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import spaces
 
2
  import gradio as gr
3
  from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
4
  from transformers import DataCollatorForSeq2Seq, AutoConfig
@@ -231,12 +232,29 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
231
  ##data_collator=data_collator,
232
  ##processing_class=tokenizer,
233
  #)
234
-
235
- # Fine-tune the model
236
- if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):
237
- train_result = trainer.train(resume_from_checkpoint=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  else:
 
239
  train_result = trainer.train()
 
240
  trainer.push_to_hub(commit_message="Training complete!")
241
  except Exception as e:
242
  return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
 
1
  import spaces
2
+ import glob
3
  import gradio as gr
4
  from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
5
  from transformers import DataCollatorForSeq2Seq, AutoConfig
 
232
  ##data_collator=data_collator,
233
  ##processing_class=tokenizer,
234
  #)
235
+ checkpoint_dir = training_args.output_dir
236
+ if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
237
+ # Check if the trainer_state.json file exists in the specified checkpoint
238
+ trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
239
+ if os.path.exists(trainer_state_path):
240
+ train_result = trainer.train(resume_from_checkpoint=True)
241
+ else:
242
+ # If the trainer_state.json is missing, look for the previous checkpoint
243
+ print(f"Checkpoint {checkpoint_dir} is missing 'trainer_state.json'. Looking for previous checkpoints...")
244
+ previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
245
+
246
+ if previous_checkpoints:
247
+ # Load the most recent previous checkpoint
248
+ last_checkpoint = previous_checkpoints[-1]
249
+ print(f"Loading previous checkpoint: {last_checkpoint}")
250
+ train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
251
+ else:
252
+ print("No previous checkpoints found. Starting training from scratch.")
253
+ train_result = trainer.train()
254
  else:
255
+ print("No checkpoints found. Starting training from scratch.")
256
  train_result = trainer.train()
257
+
258
  trainer.push_to_hub(commit_message="Training complete!")
259
  except Exception as e:
260
  return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"