zetavg committed · Commit 6ac1eb1 · 1 parent: 05ad97e

fix
llama_lora/lib/finetune.py  CHANGED  (+13 -11)

@@ -53,16 +53,16 @@ def train(
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
     # either training checkpoint or final adapter
-    resume_from_checkpoint
+    resume_from_checkpoint=None,
     save_steps: int = 200,
     save_total_limit: int = 3,
     logging_steps: int = 10,
     # logging
     callbacks: List[Any] = [],
     # wandb params
-    wandb_api_key
+    wandb_api_key=None,
     wandb_project: str = "",
-    wandb_group
+    wandb_group=None,
     wandb_run_name: str = "",
     wandb_tags: List[str] = [],
     wandb_watch: str = "false",  # options: false | gradients | all
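
The three new keyword arguments above all default to None, so existing callers of train() keep working unchanged. The diff does not show how wandb_api_key is consumed further down, so the helper below is only a sketch of one common pattern (an assumption, not necessarily what finetune.py does): log in to W&B explicitly when a key is supplied, reusing the lazy-import style visible later in this diff.

    import importlib

    def maybe_wandb_login(wandb_api_key=None):
        # Hypothetical helper: only touch wandb when a key was actually passed in.
        if not wandb_api_key:
            return False
        wandb = importlib.import_module("wandb")  # lazy import, as finetune.py does
        wandb.login(key=wandb_api_key)
        return True

Keeping the import lazy preserves the property that wandb only needs to be installed when logging is actually enabled.
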
@@ -115,8 +115,8 @@ def train(
     if wandb_log_model:
         os.environ["WANDB_LOG_MODEL"] = wandb_log_model
     use_wandb = (wandb_project and len(wandb_project) > 0) or (
-
-
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     if use_wandb:
         os.environ['WANDB_MODE'] = "online"
         wandb = importlib.import_module("wandb")
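
With the condition completed, W&B logging is enabled either by passing wandb_project to train() or by exporting a WANDB_PROJECT environment variable for the process. A small standalone check of the same expression:

    import os

    wandb_project = ""  # nothing passed to train()
    os.environ["WANDB_PROJECT"] = "my-project"  # exported before launching the script

    use_wandb = (wandb_project and len(wandb_project) > 0) or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    print(use_wandb)  # True: the environment variable alone turns logging on
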
@@ -130,7 +130,7 @@ def train(
             magic=True,
             config={'finetune_args': finetune_args},
             # id=None  # used for resuming
-
+        )
     else:
         os.environ['WANDB_MODE'] = "disabled"

@@ -177,7 +177,8 @@ def train(
         raise e

     if re.match("[^/]+/llama", tokenizer_name):
-        print(
+        print(
+            f"Setting special tokens for LLaMA tokenizer {tokenizer_name}...")
         tokenizer.pad_token_id = 0
         tokenizer.bos_token_id = 1
         tokenizer.eos_token_id = 2
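
The hard-coded IDs follow the original LLaMA SentencePiece vocabulary, in which token 0 is <unk> (reused here as the pad token), 1 is <s> and 2 is </s>. A quick sanity check against the tokenizer object created earlier in finetune.py (outside this diff):

    # `tokenizer` is the Hugging Face tokenizer loaded earlier in the file
    print(tokenizer.convert_ids_to_tokens([0, 1, 2]))
    # a stock LLaMA tokenizer yields ['<unk>', '<s>', '</s>']
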
@@ -276,17 +277,18 @@ def train(

     # Be more transparent about the % of trainable params.
     trainable_params = 0
-
+    all_params = 0
     for _, param in model.named_parameters():
-
+        all_params += param.numel()
         if param.requires_grad:
             trainable_params += param.numel()
     print(
-        f"trainable params: {trainable_params} || all params: {
+        f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params} (calculated)"
     )
     model.print_trainable_parameters()
     if use_wandb and wandb:
-        wandb.config.update({"model": {
+        wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
+                                       "trainable%": 100 * trainable_params / all_params}})

     if val_set_size > 0:
         train_val = train_data.train_test_split(
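
The added loop counts every parameter once, so the printed percentage can be compared with the model's own print_trainable_parameters() output. A toy, self-contained version of the same accounting to make the arithmetic concrete (torch is assumed to be available, as elsewhere in finetune.py; the tiny model is only an illustration):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
    model[0].weight.requires_grad = False  # freeze one weight matrix, LoRA-style

    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    # all: 200 + 20 + 100 + 5 = 325, trainable: 325 - 200 = 125 (~38.5%)
    print(
        f"trainable params: {trainable_params} || all params: {all_params} "
        f"|| trainable%: {100 * trainable_params / all_params} (calculated)"
    )

On a LoRA-wrapped model the trainable share is usually a fraction of a percent, so having the explicit figure next to the model's own report makes the run configuration easier to verify.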