Kevin Fink committed
Commit 1ec3de2 · 1 Parent(s): 759ca46
dev
app.py CHANGED
@@ -1,4 +1,5 @@
 import spaces
+import glob
 import gradio as gr
 from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
 from transformers import DataCollatorForSeq2Seq, AutoConfig
@@ -231,12 +232,29 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         ##data_collator=data_collator,
         ##processing_class=tokenizer,
         #)
-
-
-
-
+        checkpoint_dir = training_args.output_dir
+        if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
+            # Check if the trainer_state.json file exists in the specified checkpoint
+            trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
+            if os.path.exists(trainer_state_path):
+                train_result = trainer.train(resume_from_checkpoint=True)
+            else:
+                # If the trainer_state.json is missing, look for the previous checkpoint
+                print(f"Checkpoint {checkpoint_dir} is missing 'trainer_state.json'. Looking for previous checkpoints...")
+                previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
+
+                if previous_checkpoints:
+                    # Load the most recent previous checkpoint
+                    last_checkpoint = previous_checkpoints[-1]
+                    print(f"Loading previous checkpoint: {last_checkpoint}")
+                    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
+                else:
+                    print("No previous checkpoints found. Starting training from scratch.")
+                    train_result = trainer.train()
         else:
+            print("No checkpoints found. Starting training from scratch.")
             train_result = trainer.train()
+
         trainer.push_to_hub(commit_message="Training complete!")
     except Exception as e:
         return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
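
The substance of the commit is the resume logic added around trainer.train(). A minimal standalone sketch of the same decision flow is below; resume_or_start is a hypothetical helper name used only for illustration, and the trainer and training_args objects are assumed to be built as elsewhere in app.py.

import glob
import os

def resume_or_start(trainer, training_args):
    # Hypothetical helper mirroring the committed logic: resume from
    # training_args.output_dir when it holds a usable checkpoint, otherwise
    # fall back to the newest sibling checkpoint-* directory, else start fresh.
    checkpoint_dir = training_args.output_dir
    if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
        # Trainer can only resume a checkpoint that still has trainer_state.json.
        if os.path.exists(os.path.join(checkpoint_dir, "trainer_state.json")):
            return trainer.train(resume_from_checkpoint=True)
        # Otherwise pick the most recently modified checkpoint-* directory next to it.
        candidates = sorted(
            glob.glob(os.path.join(os.path.dirname(checkpoint_dir), "checkpoint-*")),
            key=os.path.getmtime,
        )
        if candidates:
            return trainer.train(resume_from_checkpoint=candidates[-1])
    # Nothing usable on disk: train from scratch.
    return trainer.train()

In the transformers API, resume_from_checkpoint=True asks Trainer.train to load the latest checkpoint saved under args.output_dir, while passing a string resumes from that exact directory, which is why the fallback branch hands over the most recent checkpoint-* path it can find.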