Kevin Fink
committed on
Commit
·
9ac7e52
1
Parent(s):
451a63d
dev
Browse files
app.py
CHANGED
@@ -241,17 +241,14 @@ def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, g
|
|
241 |
def get_checkpoint_int(s):
    """Return the integer step suffix of a checkpoint name.

    The number is read from the text after the last '-' in *s*
    (e.g. ``'checkpoint-500'`` -> ``500``), so names that contain
    additional dashes before the numeric suffix still parse. If *s*
    contains no '-', the whole string is converted.

    Raises:
        ValueError: if the suffix is not a valid integer literal.
    """
    # rpartition splits on the LAST dash; with no dash the tail is all of s.
    _, _, suffix = s.rpartition('-')
    return int(suffix)
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
shutil.rmtree(os.path.join('/data/results/checkpoints', check))
|
253 |
-
except:
|
254 |
-
pass
|
255 |
try:
|
256 |
train_result = trainer.train(resume_from_checkpoint=True)
|
257 |
except Exception as e:
|
@@ -259,12 +256,13 @@ def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, g
|
|
259 |
import shutil
|
260 |
checkpoint_dir = training_args.output_dir
|
261 |
# If the trainer_state.json is missing, look for the previous checkpoint
|
262 |
-
|
|
|
263 |
print(f'CHECKPOINTs: {previous_checkpoints}')
|
264 |
for check in previous_checkpoints:
|
265 |
try:
|
266 |
print(f"Removing previous checkpoint {check}")
|
267 |
-
shutil.rmtree(os.path.join(
|
268 |
train_result = trainer.train(resume_from_checkpoint=True)
|
269 |
trainer.push_to_hub(commit_message="Training complete!")
|
270 |
return 'DONE!'#train_result
|
|
|
241 |
def get_checkpoint_int(s):
    """Return the integer step suffix of a checkpoint name.

    The number is read from the text after the last '-' in *s*
    (e.g. ``'checkpoint-500'`` -> ``500``), so names that contain
    additional dashes before the numeric suffix still parse. If *s*
    contains no '-', the whole string is converted.

    Raises:
        ValueError: if the suffix is not a valid integer literal.
    """
    # rpartition splits on the LAST dash; with no dash the tail is all of s.
    _, _, suffix = s.rpartition('-')
    return int(suffix)
|
244 |
+
|
245 |
+
def filter_checkpoints_dirs(l):
    """Return the entries of *l* whose text contains ``'checkpoint'``.

    Args:
        l: iterable of directory-entry names (strings).

    Returns:
        A new list with the matching entries, in their original order.
    """
    # NOTE(review): this is a plain substring match, so any entry that merely
    # contains 'checkpoint' is kept, whether or not it ends in a step number.
    return [entry for entry in l if 'checkpoint' in entry]
|
251 |
+
|
|
|
|
|
|
|
252 |
try:
|
253 |
train_result = trainer.train(resume_from_checkpoint=True)
|
254 |
except Exception as e:
|
|
|
256 |
import shutil
|
257 |
checkpoint_dir = training_args.output_dir
|
258 |
# If the trainer_state.json is missing, look for the previous checkpoint
|
259 |
+
dir_entries = filter_checkpoints_dirs(os.listdir(checkpoint_dir))
|
260 |
+
previous_checkpoints = sorted(dir_entries, key=get_checkpoint_int, reverse=True)
|
261 |
print(f'CHECKPOINTs: {previous_checkpoints}')
|
262 |
for check in previous_checkpoints:
|
263 |
try:
|
264 |
print(f"Removing previous checkpoint {check}")
|
265 |
+
shutil.rmtree(os.path.join(checkpoint_dir, check))
|
266 |
train_result = trainer.train(resume_from_checkpoint=True)
|
267 |
trainer.push_to_hub(commit_message="Training complete!")
|
268 |
return 'DONE!'#train_result
|