Commit 38fb491 by zetavg
Parent(s): 00263ef

support resume_from_checkpoint

Files changed:
- llama_lora/lib/finetune.py   (+12 -3)
- llama_lora/ui/finetune_ui.py (+23 -2)
llama_lora/lib/finetune.py  CHANGED

@@ -33,7 +33,7 @@ def train(
     num_train_epochs: int = 3,
     learning_rate: float = 3e-4,
     cutoff_len: int = 256,
-    val_set_size: int = 2000,
+    val_set_size: int = 2000,
     # lora hyperparams
     lora_r: int = 8,
     lora_alpha: int = 16,
@@ -46,7 +46,7 @@ def train(
     train_on_inputs: bool = True, # if False, masks out inputs in loss
     group_by_length: bool = False, # faster, but produces an odd training loss curve
     # either training checkpoint or final adapter
-    resume_from_checkpoint
+    resume_from_checkpoint = None,
     save_steps: int = 200,
     save_total_limit: int = 3,
     logging_steps: int = 10,
@@ -68,6 +68,7 @@ def train(
         'num_train_epochs': num_train_epochs,
         'learning_rate': learning_rate,
         'cutoff_len': cutoff_len,
+        'val_set_size': val_set_size,
         'lora_r': lora_r,
         'lora_alpha': lora_alpha,
         'lora_dropout': lora_dropout,
@@ -78,7 +79,12 @@ def train(
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
     }
+    if val_set_size and val_set_size > 0:
+        finetune_args['val_set_size'] = val_set_size
+    if resume_from_checkpoint:
+        finetune_args['resume_from_checkpoint'] = resume_from_checkpoint

+    wandb = None
     if wandb_api_key:
         os.environ["WANDB_API_KEY"] = wandb_api_key

@@ -220,7 +226,7 @@ def train(
             adapters_weights = torch.load(checkpoint_name)
             model = set_peft_model_state_dict(model, adapters_weights)
         else:
-            print(f"Checkpoint {checkpoint_name} not found")
+            raise ValueError(f"Checkpoint {checkpoint_name} not found")

     # Be more transparent about the % of trainable params.
     model.print_trainable_parameters()
@@ -315,4 +321,7 @@ def train(
     with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
         json.dump(train_output, train_output_json_file, indent=2)

+    if use_wandb and wandb:
+        wandb.finish()
+
     return train_output
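For context on how the `resume_from_checkpoint` argument is typically consumed inside `train()`: the checkpoint directory is probed for saved weights, the LoRA adapter state is loaded with `set_peft_model_state_dict`, and the new `raise ValueError` replaces a silent fallback when nothing is found. The sketch below follows the common alpaca-lora-style flow this function appears to be based on; the file names `pytorch_model.bin` / `adapter_model.bin` and the helper name are assumptions, not code taken from this commit.

# Hedged sketch (assumed alpaca-lora-style flow, not verbatim from this repo).
import os
import torch
from peft import set_peft_model_state_dict

def load_adapter_weights(model, resume_from_checkpoint):
    """Load LoRA adapter weights from a checkpoint directory before training."""
    if not resume_from_checkpoint:
        return model
    # A full Trainer checkpoint, if one exists...
    checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
    if not os.path.exists(checkpoint_name):
        # ...otherwise just the saved LoRA adapter weights.
        checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
    if os.path.exists(checkpoint_name):
        adapters_weights = torch.load(checkpoint_name)
        # Mirrors the call shown in the diff above.
        model = set_peft_model_state_dict(model, adapters_weights)
    else:
        # Matches the behavior this commit introduces: fail loudly rather than
        # silently training from scratch.
        raise ValueError(f"Checkpoint {checkpoint_name} not found")
    return model

In the same pattern, the optimizer, scheduler, and global-step state are typically resumed by passing the checkpoint directory to `transformers.Trainer.train(resume_from_checkpoint=...)`.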
llama_lora/ui/finetune_ui.py  CHANGED

@@ -306,6 +306,17 @@ def do_train(
 ):
     try:
         base_model_name = Global.base_model_name
+
+        resume_from_checkpoint = None
+        if continue_from_model == "-" or continue_from_model == "None":
+            continue_from_model = None
+        if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
+            continue_from_checkpoint = None
+        if continue_from_model:
+            resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
+            if continue_from_checkpoint:
+                resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
+
         output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
         if os.path.exists(output_dir):
             if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
@@ -376,6 +387,8 @@ Train options: {json.dumps({
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
     'model_name': model_name,
+    'continue_from_model': continue_from_model,
+    'continue_from_checkpoint': continue_from_checkpoint,
 }, indent=2)}

 Train data (first 10):
@@ -386,7 +399,7 @@ Train data (first 10):
             return message

         if not should_training_progress_track_tqdm:
-            progress(0, desc="Preparing model for training...")
+            progress(0, desc=f"Preparing model {base_model_name} for training...")

         log_history = []

@@ -461,6 +474,10 @@ Train data (first 10):
             # 'lora_dropout': lora_dropout,
             # 'lora_target_modules': lora_target_modules,
         }
+        if continue_from_model:
+            info['continued_from_model'] = continue_from_model
+        if continue_from_checkpoint:
+            info['continued_from_checkpoint'] = continue_from_checkpoint
         json.dump(info, info_json_file, indent=2)

         if not should_training_progress_track_tqdm:
@@ -490,7 +507,7 @@ Train data (first 10):
         lora_target_modules, # lora_target_modules
         train_on_inputs, # train_on_inputs
         False, # group_by_length
-
+        resume_from_checkpoint, # resume_from_checkpoint
         save_steps, # save_steps
         save_total_limit, # save_total_limit
         logging_steps, # logging_steps
@@ -582,6 +599,8 @@ def handle_load_params_from_model(
                 cutoff_len = value
             elif key == "evaluate_data_count":
                 evaluate_data_count = value
+            elif key == "val_set_size":
+                evaluate_data_count = value
             elif key == "micro_batch_size":
                 micro_batch_size = value
             elif key == "gradient_accumulation_steps":
@@ -610,6 +629,8 @@ def handle_load_params_from_model(
                 logging_steps = value
             elif key == "group_by_length":
                 pass
+            elif key == "resume_from_checkpoint":
+                pass
             else:
                 unknown_keys.append(key)
     except Exception as e:
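A note on the UI-side logic added to `do_train()`: the two dropdown values `continue_from_model` and `continue_from_checkpoint` are normalized (the UI uses "-" / "None" as empty choices) and then joined into a single path under `<data_dir>/lora_models/`, which is forwarded as the `resume_from_checkpoint` positional argument of `train()`. A standalone sketch of that mapping follows; the helper name and the example values are hypothetical, not part of the commit.

# Hedged sketch: the path resolution that do_train() now performs inline.
import os

def resolve_resume_from_checkpoint(data_dir, continue_from_model, continue_from_checkpoint):
    """Map the UI's "continue from" selections to a checkpoint path, or None."""
    # The dropdowns use "-" / "None" as their empty choices.
    if continue_from_model in ("-", "None"):
        continue_from_model = None
    if continue_from_checkpoint in ("-", "None"):
        continue_from_checkpoint = None

    resume_from_checkpoint = None
    if continue_from_model:
        # Continue from a previously trained LoRA model...
        resume_from_checkpoint = os.path.join(data_dir, "lora_models", continue_from_model)
        if continue_from_checkpoint:
            # ...or from one of its intermediate checkpoints (e.g. "checkpoint-200").
            resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
    return resume_from_checkpoint

# Example: resolve_resume_from_checkpoint("/data", "my-lora", "checkpoint-200")
# -> "/data/lora_models/my-lora/checkpoint-200"

The new `elif key == "val_set_size"` and `elif key == "resume_from_checkpoint"` branches in `handle_load_params_from_model()` complement this: parameter files that record the keys the library now includes load back into the UI without falling into `unknown_keys`.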