zetavg committed on
Commit 6ac1eb1 · 1 Parent(s): 05ad97e
Files changed (1)
  1. llama_lora/lib/finetune.py +13 -11
llama_lora/lib/finetune.py CHANGED
@@ -53,16 +53,16 @@ def train(
     train_on_inputs: bool = True, # if False, masks out inputs in loss
     group_by_length: bool = False, # faster, but produces an odd training loss curve
     # either training checkpoint or final adapter
-    resume_from_checkpoint = None,
+    resume_from_checkpoint=None,
     save_steps: int = 200,
     save_total_limit: int = 3,
     logging_steps: int = 10,
     # logging
     callbacks: List[Any] = [],
     # wandb params
-    wandb_api_key = None,
+    wandb_api_key=None,
     wandb_project: str = "",
-    wandb_group = None,
+    wandb_group=None,
     wandb_run_name: str = "",
     wandb_tags: List[str] = [],
     wandb_watch: str = "false", # options: false | gradients | all
@@ -115,8 +115,8 @@ def train(
     if wandb_log_model:
         os.environ["WANDB_LOG_MODEL"] = wandb_log_model
     use_wandb = (wandb_project and len(wandb_project) > 0) or (
-        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
-    )
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     if use_wandb:
         os.environ['WANDB_MODE'] = "online"
         wandb = importlib.import_module("wandb")
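For context, the `use_wandb` flag in this hunk falls back to the `WANDB_PROJECT` environment variable when no project name is passed as an argument. A minimal sketch of the same decision pattern; the `should_use_wandb` helper and the `"llama-lora-demo"` project name are illustrative, not part of the repository:

```python
import importlib
import os


def should_use_wandb(wandb_project: str = "") -> bool:
    # Enable Weights & Biases when a project is given explicitly
    # or is already configured through the environment.
    return bool(wandb_project) or bool(os.environ.get("WANDB_PROJECT"))


if should_use_wandb(wandb_project="llama-lora-demo"):
    os.environ["WANDB_MODE"] = "online"
    wandb = importlib.import_module("wandb")  # lazy import, as in the hunk above
else:
    os.environ["WANDB_MODE"] = "disabled"
    wandb = None
```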
@@ -130,7 +130,7 @@ def train(
             magic=True,
             config={'finetune_args': finetune_args},
             # id=None # used for resuming
-        )
+        )
     else:
         os.environ['WANDB_MODE'] = "disabled"
 
@@ -177,7 +177,8 @@ def train(
         raise e
 
     if re.match("[^/]+/llama", tokenizer_name):
-        print(f"Setting special tokens for LLaMA tokenizer {tokenizer_name}...")
+        print(
+            f"Setting special tokens for LLaMA tokenizer {tokenizer_name}...")
         tokenizer.pad_token_id = 0
         tokenizer.bos_token_id = 1
         tokenizer.eos_token_id = 2
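The regex in this hunk matches base-model style tokenizer names such as `decapoda-research/llama-7b-hf` (used here only as an illustration). LLaMA's SentencePiece vocabulary reserves id 0 for `<unk>`, 1 for `<s>` and 2 for `</s>`, and ships no pad token, which is why the ids are pinned explicitly. A standalone sketch, assuming a recent `transformers` that provides `LlamaTokenizer`:

```python
from transformers import LlamaTokenizer

# Illustrative checkpoint name; any "<org>/llama*" tokenizer matches the regex above.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# 0 = <unk> (reused for padding, since the checkpoint defines no pad token),
# 1 = <s> (BOS), 2 = </s> (EOS) -- the same assignments as in the hunk above.
tokenizer.pad_token_id = 0
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
```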
@@ -276,17 +277,18 @@ def train(
 
     # Be more transparent about the % of trainable params.
     trainable_params = 0
-    all_param = 0
+    all_params = 0
     for _, param in model.named_parameters():
-        all_param += param.numel()
+        all_params += param.numel()
         if param.requires_grad:
             trainable_params += param.numel()
     print(
-        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param} (calculated)"
+        f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params} (calculated)"
     )
     model.print_trainable_parameters()
     if use_wandb and wandb:
-        wandb.config.update({"model": { "all_param": all_param, "trainable_params": trainable_params, "trainable%": 100 * trainable_params / all_param }})
+        wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
+                                       "trainable%": 100 * trainable_params / all_params}})
 
     if val_set_size > 0:
         train_val = train_data.train_test_split(
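The renamed counters compute the same ratio that PEFT's `print_trainable_parameters()` reports, while keeping the raw numbers around so they can be logged to W&B. A sketch of the loop factored into a helper; the `count_parameters` name and the demo module are made up for illustration:

```python
import torch.nn as nn


def count_parameters(model: nn.Module):
    """Return (trainable_params, all_params, trainable_pct) for a PyTorch module."""
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, all_params, 100 * trainable_params / all_params


# Example on a trivial module: frozen parameters count toward all_params only.
demo = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in demo[0].parameters():
    p.requires_grad = False
print(count_parameters(demo))
```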
 