Kevin Fink commited on
Commit
9ac7e52
·
1 Parent(s): 451a63d
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -241,17 +241,14 @@ def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, g
241
  def get_checkpoint_int(s):
242
  int_index = s.find('-')
243
  return int(s[int_index+1:])
244
-
245
- checkpoint_dir = training_args.output_dir
246
- # If the trainer_state.json is missing, look for the previous checkpoint
247
- previous_checkpoints = sorted(os.listdir("/data/results/checkpoints"), reverse=True)
248
- print(f'CHECKPOINTs: {previous_checkpoints}')
249
- for check in previous_checkpoints:
250
- try:
251
- print(f"Removing previous checkpoint {check}")
252
- shutil.rmtree(os.path.join('/data/results/checkpoints', check))
253
- except:
254
- pass
255
  try:
256
  train_result = trainer.train(resume_from_checkpoint=True)
257
  except Exception as e:
@@ -259,12 +256,13 @@ def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, g
259
  import shutil
260
  checkpoint_dir = training_args.output_dir
261
  # If the trainer_state.json is missing, look for the previous checkpoint
262
- previous_checkpoints = sorted(os.listdir("/data/results"), key=get_checkpoint_int, reverse=True)
 
263
  print(f'CHECKPOINTs: {previous_checkpoints}')
264
  for check in previous_checkpoints:
265
  try:
266
  print(f"Removing previous checkpoint {check}")
267
- shutil.rmtree(os.path.join('/data/results', check))
268
  train_result = trainer.train(resume_from_checkpoint=True)
269
  trainer.push_to_hub(commit_message="Training complete!")
270
  return 'DONE!'#train_result
 
241
  def get_checkpoint_int(s):
242
  int_index = s.find('-')
243
  return int(s[int_index+1:])
244
+
245
+ def filter_checkpoints_dirs(l):
246
+ new_list = list()
247
+ for entry in l:
248
+ if 'checkpoint' in entry:
249
+ new_list.append(entry)
250
+ return new_list
251
+
 
 
 
252
  try:
253
  train_result = trainer.train(resume_from_checkpoint=True)
254
  except Exception as e:
 
256
  import shutil
257
  checkpoint_dir = training_args.output_dir
258
  # If the trainer_state.json is missing, look for the previous checkpoint
259
+ dir_entries = filter_checkpoints_dirs(os.listdir(checkpoint_dir))
260
+ previous_checkpoints = sorted(dir_entries, key=get_checkpoint_int, reverse=True)
261
  print(f'CHECKPOINTs: {previous_checkpoints}')
262
  for check in previous_checkpoints:
263
  try:
264
  print(f"Removing previous checkpoint {check}")
265
+ shutil.rmtree(os.path.join(checkpoint_dir, check))
266
  train_result = trainer.train(resume_from_checkpoint=True)
267
  trainer.push_to_hub(commit_message="Training complete!")
268
  return 'DONE!'#train_result