Spaces:
Runtime error
Runtime error
zetavg
commited on
Commit
·
68255ee
1
Parent(s):
0537112
wandb fix
Browse files- llama_lora/lib/finetune.py +41 -22
- llama_lora/ui/finetune_ui.py +9 -1
llama_lora/lib/finetune.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
from typing import Any, List
|
4 |
|
5 |
import json
|
@@ -54,16 +55,38 @@ def train(
|
|
54 |
# wandb params
|
55 |
wandb_api_key = None,
|
56 |
wandb_project: str = "",
|
|
|
57 |
wandb_run_name: str = "",
|
|
|
58 |
wandb_watch: str = "false", # options: false | gradients | all
|
59 |
wandb_log_model: str = "true", # options: false | true
|
60 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
if wandb_api_key:
|
62 |
os.environ["WANDB_API_KEY"] = wandb_api_key
|
63 |
-
|
64 |
-
|
65 |
-
if
|
66 |
-
|
|
|
|
|
67 |
if wandb_watch:
|
68 |
os.environ["WANDB_WATCH"] = wandb_watch
|
69 |
if wandb_log_model:
|
@@ -73,6 +96,18 @@ def train(
|
|
73 |
)
|
74 |
if use_wandb:
|
75 |
os.environ['WANDB_MODE'] = "online"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
else:
|
77 |
os.environ['WANDB_MODE'] = "disabled"
|
78 |
|
@@ -243,24 +278,8 @@ def train(
|
|
243 |
os.makedirs(output_dir)
|
244 |
with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
|
245 |
json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
|
246 |
-
with open(os.path.join(output_dir, "
|
247 |
-
|
248 |
-
'micro_batch_size': micro_batch_size,
|
249 |
-
'gradient_accumulation_steps': gradient_accumulation_steps,
|
250 |
-
'num_train_epochs': num_train_epochs,
|
251 |
-
'learning_rate': learning_rate,
|
252 |
-
'cutoff_len': cutoff_len,
|
253 |
-
'lora_r': lora_r,
|
254 |
-
'lora_alpha': lora_alpha,
|
255 |
-
'lora_dropout': lora_dropout,
|
256 |
-
'lora_target_modules': lora_target_modules,
|
257 |
-
'train_on_inputs': train_on_inputs,
|
258 |
-
'group_by_length': group_by_length,
|
259 |
-
'save_steps': save_steps,
|
260 |
-
'save_total_limit': save_total_limit,
|
261 |
-
'logging_steps': logging_steps,
|
262 |
-
}
|
263 |
-
json.dump(finetune_params, finetune_params_json_file, indent=2)
|
264 |
|
265 |
# Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
|
266 |
# if train_data:
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
import importlib
|
4 |
from typing import Any, List
|
5 |
|
6 |
import json
|
|
|
55 |
# wandb params
|
56 |
wandb_api_key = None,
|
57 |
wandb_project: str = "",
|
58 |
+
wandb_group = None,
|
59 |
wandb_run_name: str = "",
|
60 |
+
wandb_tags: List[str] = [],
|
61 |
wandb_watch: str = "false", # options: false | gradients | all
|
62 |
wandb_log_model: str = "true", # options: false | true
|
63 |
):
|
64 |
+
# for logging
|
65 |
+
finetune_args = {
|
66 |
+
'micro_batch_size': micro_batch_size,
|
67 |
+
'gradient_accumulation_steps': gradient_accumulation_steps,
|
68 |
+
'num_train_epochs': num_train_epochs,
|
69 |
+
'learning_rate': learning_rate,
|
70 |
+
'cutoff_len': cutoff_len,
|
71 |
+
'lora_r': lora_r,
|
72 |
+
'lora_alpha': lora_alpha,
|
73 |
+
'lora_dropout': lora_dropout,
|
74 |
+
'lora_target_modules': lora_target_modules,
|
75 |
+
'train_on_inputs': train_on_inputs,
|
76 |
+
'group_by_length': group_by_length,
|
77 |
+
'save_steps': save_steps,
|
78 |
+
'save_total_limit': save_total_limit,
|
79 |
+
'logging_steps': logging_steps,
|
80 |
+
}
|
81 |
+
|
82 |
if wandb_api_key:
|
83 |
os.environ["WANDB_API_KEY"] = wandb_api_key
|
84 |
+
|
85 |
+
# wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
|
86 |
+
# if wandb_project:
|
87 |
+
# os.environ["WANDB_PROJECT"] = wandb_project
|
88 |
+
# if wandb_run_name:
|
89 |
+
# os.environ["WANDB_RUN_NAME"] = wandb_run_name
|
90 |
if wandb_watch:
|
91 |
os.environ["WANDB_WATCH"] = wandb_watch
|
92 |
if wandb_log_model:
|
|
|
96 |
)
|
97 |
if use_wandb:
|
98 |
os.environ['WANDB_MODE'] = "online"
|
99 |
+
wandb = importlib.import_module("wandb")
|
100 |
+
wandb.init(
|
101 |
+
project=wandb_project,
|
102 |
+
resume="auto",
|
103 |
+
group=wandb_group,
|
104 |
+
name=wandb_run_name,
|
105 |
+
tags=wandb_tags,
|
106 |
+
reinit=True,
|
107 |
+
magic=True,
|
108 |
+
config={'finetune_args': finetune_args},
|
109 |
+
# id=None # used for resuming
|
110 |
+
)
|
111 |
else:
|
112 |
os.environ['WANDB_MODE'] = "disabled"
|
113 |
|
|
|
278 |
os.makedirs(output_dir)
|
279 |
with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
|
280 |
json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
|
281 |
+
with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
|
282 |
+
json.dump(finetune_args, finetune_args_json_file, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
# Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
|
285 |
# if train_data:
|
llama_lora/ui/finetune_ui.py
CHANGED
@@ -415,6 +415,12 @@ Train data (first 10):
|
|
415 |
if not should_training_progress_track_tqdm:
|
416 |
progress(0, desc="Train starting...")
|
417 |
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
train_output = Global.train_fn(
|
419 |
base_model, # base_model
|
420 |
tokenizer, # tokenizer
|
@@ -440,7 +446,9 @@ Train data (first 10):
|
|
440 |
training_callbacks, # callbacks
|
441 |
Global.wandb_api_key, # wandb_api_key
|
442 |
Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
|
443 |
-
|
|
|
|
|
444 |
)
|
445 |
|
446 |
logs_str = "\n".join([json.dumps(log)
|
|
|
415 |
if not should_training_progress_track_tqdm:
|
416 |
progress(0, desc="Train starting...")
|
417 |
|
418 |
+
wandb_group = template
|
419 |
+
wandb_tags = [f"template:{template}"]
|
420 |
+
if load_dataset_from == "Data Dir" and dataset_from_data_dir:
|
421 |
+
wandb_group += f"/{dataset_from_data_dir}"
|
422 |
+
wandb_tags.append(f"dataset:{dataset_from_data_dir}")
|
423 |
+
|
424 |
train_output = Global.train_fn(
|
425 |
base_model, # base_model
|
426 |
tokenizer, # tokenizer
|
|
|
446 |
training_callbacks, # callbacks
|
447 |
Global.wandb_api_key, # wandb_api_key
|
448 |
Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
|
449 |
+
wandb_group, # wandb_group
|
450 |
+
model_name, # wandb_run_name
|
451 |
+
wandb_tags # wandb_tags
|
452 |
)
|
453 |
|
454 |
logs_str = "\n".join([json.dumps(log)
|