zetavg commited on
Commit
1583e8c
·
unverified ·
2 Parent(s): 1e8710e 726fa4d

Merge branch 'main' into hf-ui-demo

Browse files
.gitignore CHANGED
@@ -3,4 +3,5 @@ __pycache__/
3
  /venv
4
  .vscode
5
 
 
6
  /data
 
3
  /venv
4
  .vscode
5
 
6
+ /wandb
7
  /data
LLaMA_LoRA.ipynb CHANGED
@@ -60,12 +60,20 @@
60
  "# @title A small workaround { display-mode: \"form\" }\n",
61
  "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
62
  "# @markdown The error will disappear on the next run.\n",
63
- "!pip install Pillow==9.3.0\n",
 
64
  "import PIL\n",
65
  "major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
66
  "version_float = major + minor / 10**len(str(minor))\n",
67
- "print(version_float)\n",
68
  "if version_float < 9.003:\n",
 
 
 
 
 
 
 
69
  " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
70
  ],
71
  "metadata": {
@@ -281,7 +289,8 @@
281
  "\n",
282
  "# Set Configs\n",
283
  "from llama_lora.llama_lora.globals import Global\n",
284
- "Global.default_base_model_name = base_model\n",
 
285
  "data_dir_realpath = !realpath ./data\n",
286
  "Global.data_dir = data_dir_realpath[0]\n",
287
  "Global.load_8bit = True\n",
 
60
  "# @title A small workaround { display-mode: \"form\" }\n",
61
  "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
62
  "# @markdown The error will disappear on the next run.\n",
63
+ "!pip install Pillow==9.3.0 numpy==1.23.5\n",
64
+ "\n",
65
  "import PIL\n",
66
  "major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
67
  "version_float = major + minor / 10**len(str(minor))\n",
68
+ "print('PIL', version_float)\n",
69
  "if version_float < 9.003:\n",
70
+ " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")\n",
71
+ "\n",
72
+ "import numpy\n",
73
+ "major, minor = map(float, numpy.__version__.split(\".\")[:2])\n",
74
+ "version_float = major + minor / 10**len(str(minor))\n",
75
+ "print('numpy', version_float)\n",
76
+ "if version_float < 1.0023:\n",
77
  " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
78
  ],
79
  "metadata": {
 
289
  "\n",
290
  "# Set Configs\n",
291
  "from llama_lora.llama_lora.globals import Global\n",
292
+ "Global.default_base_model_name = Global.base_model_name = base_model\n",
293
+ "Global.base_model_choices = [base_model]\n",
294
  "data_dir_realpath = !realpath ./data\n",
295
  "Global.data_dir = data_dir_realpath[0]\n",
296
  "Global.load_8bit = True\n",
README.md CHANGED
@@ -34,8 +34,8 @@ Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) e
34
 
35
  * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
36
  * Loads and stores data in Google Drive.
37
- * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/A3kb4VkDWyY"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230272844-09f7a35b-46bf-4101-b15d-4ddf243b8bef.gif" /></a>
38
- * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/5Db9U8PsaUk"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230277315-9a91d983-1690-4594-9d54-912eda8963ee.gif" /></a>
39
  * Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
40
  * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
41
  * Use prompt templates to keep your dataset DRY.
@@ -51,6 +51,8 @@ There are various ways to run this app:
51
 
52
  ### Run On Google Colab
53
 
 
 
54
  Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
55
 
56
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
@@ -81,13 +83,14 @@ file_mounts:
81
  setup: |
82
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
83
  cd llama_lora_tuner && pip install -r requirements.lock.txt
 
84
  cd ..
85
  echo 'Dependencies installed.'
86
 
87
  # Start the app.
88
  run: |
89
  echo 'Starting...'
90
- python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
91
  ```
92
 
93
  Then launch a cluster to run the task:
@@ -135,6 +138,11 @@ For more options, see `python app.py --help`.
135
  </details>
136
 
137
 
 
 
 
 
 
138
  ## Acknowledgements
139
 
140
  * https://github.com/tloen/alpaca-lora
 
34
 
35
  * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
36
  * Loads and stores data in Google Drive.
37
+ * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/IoEMgouZ5xU"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231023326-f28c84e2-df74-4179-b0ac-c25c4e8ca001.gif" /></a>
38
+ * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/IoEMgouZ5xU?t=60"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231026640-b5cf5c79-9fe9-430b-8d4e-7346eb9567ad.gif" /></a>
39
  * Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
40
  * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
41
  * Use prompt templates to keep your dataset DRY.
 
51
 
52
  ### Run On Google Colab
53
 
54
+ *See [video](https://youtu.be/lByYOMdy9h4) for step-by-step instructions.*
55
+
56
  Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
57
 
58
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
 
83
  setup: |
84
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
85
  cd llama_lora_tuner && pip install -r requirements.lock.txt
86
+ pip install wandb
87
  cd ..
88
  echo 'Dependencies installed.'
89
 
90
  # Start the app.
91
  run: |
92
  echo 'Starting...'
93
+ python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
94
  ```
95
 
96
  Then launch a cluster to run the task:
 
138
  </details>
139
 
140
 
141
+ ## Usage
142
+
143
+ See [video on YouTube](https://youtu.be/IoEMgouZ5xU).
144
+
145
+
146
  ## Acknowledgements
147
 
148
  * https://github.com/tloen/alpaca-lora
app.py CHANGED
@@ -5,21 +5,41 @@ import fire
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
 
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
10
 
11
 
 
12
  def main(
13
- load_8bit: bool = False,
14
  base_model: str = "",
15
  data_dir: str = "",
 
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
19
  skip_loading_base_model: bool = False,
 
20
  ui_show_sys_info: bool = True,
21
  ui_dev_mode: bool = False,
 
 
22
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
24
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
25
  assert (
@@ -30,16 +50,35 @@ def main(
30
  data_dir
31
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
32
 
33
- Global.default_base_model_name = base_model
 
 
 
 
 
 
 
 
 
34
  Global.data_dir = os.path.abspath(data_dir)
35
  Global.load_8bit = load_8bit
36
 
 
 
 
 
 
 
 
37
  Global.ui_dev_mode = ui_dev_mode
38
  Global.ui_show_sys_info = ui_show_sys_info
39
 
40
  os.makedirs(data_dir, exist_ok=True)
41
  init_data_dir()
42
 
 
 
 
43
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
44
  main_page()
45
 
 
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
8
+ from llama_lora.models import prepare_base_model
9
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
10
  from llama_lora.utils.data import init_data_dir
11
 
12
 
13
+
14
  def main(
 
15
  base_model: str = "",
16
  data_dir: str = "",
17
+ base_model_choices: str = "",
18
  # Allows to listen on all interfaces by providing '0.0.0.0'.
19
  server_name: str = "127.0.0.1",
20
  share: bool = False,
21
  skip_loading_base_model: bool = False,
22
+ load_8bit: bool = False,
23
  ui_show_sys_info: bool = True,
24
  ui_dev_mode: bool = False,
25
+ wandb_api_key: str = "",
26
+ wandb_project: str = "",
27
  ):
28
+ '''
29
+ Start the LLaMA-LoRA Tuner UI.
30
+
31
+ :param base_model: (required) The name of the default base model to use.
32
+ :param data_dir: (required) The path to the directory to store data.
33
+
34
+ :param base_model_choices: Base model selections to display on the UI, seperated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.
35
+
36
+ :param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
37
+ :param share: Create a public Gradio URL.
38
+
39
+ :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
40
+ :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
41
+ '''
42
+
43
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
44
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
45
  assert (
 
50
  data_dir
51
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
52
 
53
+ Global.default_base_model_name = Global.base_model_name = base_model
54
+
55
+ if base_model_choices:
56
+ base_model_choices = base_model_choices.split(',')
57
+ base_model_choices = [name.strip() for name in base_model_choices]
58
+ Global.base_model_choices = base_model_choices
59
+
60
+ if base_model not in Global.base_model_choices:
61
+ Global.base_model_choices = [base_model] + Global.base_model_choices
62
+
63
  Global.data_dir = os.path.abspath(data_dir)
64
  Global.load_8bit = load_8bit
65
 
66
+ if len(wandb_api_key) > 0:
67
+ Global.enable_wandb = True
68
+ Global.wandb_api_key = wandb_api_key
69
+ if len(wandb_project) > 0:
70
+ Global.enable_wandb = True
71
+ Global.wandb_project = wandb_project
72
+
73
  Global.ui_dev_mode = ui_dev_mode
74
  Global.ui_show_sys_info = ui_show_sys_info
75
 
76
  os.makedirs(data_dir, exist_ok=True)
77
  init_data_dir()
78
 
79
+ if (not skip_loading_base_model) and (not ui_dev_mode):
80
+ prepare_base_model(base_model)
81
+
82
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
83
  main_page()
84
 
download_base_model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+
3
+ from llama_lora.models import get_new_base_model, clear_cache
4
+
5
+
6
+ def main(
7
+ base_model_names: str = "",
8
+ ):
9
+ '''
10
+ Download and cache base models form Hugging Face.
11
+
12
+ :param base_model_names: Names of the base model you want to download, seperated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.
13
+ '''
14
+
15
+ assert (
16
+ base_model_names
17
+ ), "Please specify --base_model_names, e.g. --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'"
18
+
19
+ base_model_names = base_model_names.split(',')
20
+ base_model_names = [name.strip() for name in base_model_names]
21
+
22
+ print(f"Base models: {', '.join(base_model_names)}.")
23
+
24
+ for name in base_model_names:
25
+ print(f"Preparing {name}...")
26
+ get_new_base_model(name)
27
+ clear_cache()
28
+
29
+ print("Done.")
30
+
31
+ if __name__ == "__main__":
32
+ fire.Fire(main)
llama_lora/globals.py CHANGED
@@ -17,6 +17,8 @@ class Global:
17
  load_8bit: bool = False
18
 
19
  default_base_model_name: str = ""
 
 
20
 
21
  # Functions
22
  train_fn: Any = train
@@ -40,6 +42,11 @@ class Global:
40
  gpu_total_cores = None # GPU total cores
41
  gpu_total_memory = None
42
 
 
 
 
 
 
43
  # UI related
44
  ui_title: str = "LLaMA-LoRA Tuner"
45
  ui_emoji: str = "🦙🎛️"
 
17
  load_8bit: bool = False
18
 
19
  default_base_model_name: str = ""
20
+ base_model_name: str = ""
21
+ base_model_choices: List[str] = []
22
 
23
  # Functions
24
  train_fn: Any = train
 
42
  gpu_total_cores = None # GPU total cores
43
  gpu_total_memory = None
44
 
45
+ # WandB
46
+ enable_wandb = False
47
+ wandb_api_key = None
48
+ default_wandb_project = "llama-lora-tuner"
49
+
50
  # UI related
51
  ui_title: str = "LLaMA-LoRA Tuner"
52
  ui_emoji: str = "🦙🎛️"
llama_lora/lib/finetune.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  from typing import Any, List
4
 
5
  import json
@@ -32,7 +33,7 @@ def train(
32
  num_train_epochs: int = 3,
33
  learning_rate: float = 3e-4,
34
  cutoff_len: int = 256,
35
- val_set_size: int = 2000, # TODO: use percentage
36
  # lora hyperparams
37
  lora_r: int = 8,
38
  lora_alpha: int = 16,
@@ -45,13 +46,78 @@ def train(
45
  train_on_inputs: bool = True, # if False, masks out inputs in loss
46
  group_by_length: bool = False, # faster, but produces an odd training loss curve
47
  # either training checkpoint or final adapter
48
- resume_from_checkpoint: str = None,
49
  save_steps: int = 200,
50
  save_total_limit: int = 3,
51
  logging_steps: int = 10,
52
  # logging
53
- callbacks: List[Any] = []
 
 
 
 
 
 
 
 
54
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  if os.path.exists(output_dir):
56
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
57
  raise ValueError(
@@ -138,6 +204,8 @@ def train(
138
 
139
  # If train_dataset_data is a list, convert it to datasets.Dataset
140
  if isinstance(train_dataset_data, list):
 
 
141
  train_dataset_data = Dataset.from_list(train_dataset_data)
142
 
143
  if resume_from_checkpoint:
@@ -158,7 +226,7 @@ def train(
158
  adapters_weights = torch.load(checkpoint_name)
159
  model = set_peft_model_state_dict(model, adapters_weights)
160
  else:
161
- print(f"Checkpoint {checkpoint_name} not found")
162
 
163
  # Be more transparent about the % of trainable params.
164
  model.print_trainable_parameters()
@@ -197,15 +265,15 @@ def train(
197
  optim="adamw_torch",
198
  evaluation_strategy="steps" if val_set_size > 0 else "no",
199
  save_strategy="steps",
200
- eval_steps=200 if val_set_size > 0 else None,
201
  save_steps=save_steps,
202
  output_dir=output_dir,
203
  save_total_limit=save_total_limit,
204
  load_best_model_at_end=True if val_set_size > 0 else False,
205
  ddp_find_unused_parameters=False if ddp else None,
206
  group_by_length=group_by_length,
207
- # report_to="wandb" if use_wandb else None,
208
- # run_name=wandb_run_name if use_wandb else None,
209
  ),
210
  data_collator=transformers.DataCollatorForSeq2Seq(
211
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
@@ -217,24 +285,16 @@ def train(
217
  os.makedirs(output_dir)
218
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
219
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
220
- with open(os.path.join(output_dir, "finetune_params.json"), 'w') as finetune_params_json_file:
221
- finetune_params = {
222
- 'micro_batch_size': micro_batch_size,
223
- 'gradient_accumulation_steps': gradient_accumulation_steps,
224
- 'num_train_epochs': num_train_epochs,
225
- 'learning_rate': learning_rate,
226
- 'cutoff_len': cutoff_len,
227
- 'lora_r': lora_r,
228
- 'lora_alpha': lora_alpha,
229
- 'lora_dropout': lora_dropout,
230
- 'lora_target_modules': lora_target_modules,
231
- 'train_on_inputs': train_on_inputs,
232
- 'group_by_length': group_by_length,
233
- 'save_steps': save_steps,
234
- 'save_total_limit': save_total_limit,
235
- 'logging_steps': logging_steps,
236
- }
237
- json.dump(finetune_params, finetune_params_json_file, indent=2)
238
 
239
  model.config.use_cache = False
240
 
@@ -261,4 +321,7 @@ def train(
261
  with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
262
  json.dump(train_output, train_output_json_file, indent=2)
263
 
 
 
 
264
  return train_output
 
1
  import os
2
  import sys
3
+ import importlib
4
  from typing import Any, List
5
 
6
  import json
 
33
  num_train_epochs: int = 3,
34
  learning_rate: float = 3e-4,
35
  cutoff_len: int = 256,
36
+ val_set_size: int = 2000,
37
  # lora hyperparams
38
  lora_r: int = 8,
39
  lora_alpha: int = 16,
 
46
  train_on_inputs: bool = True, # if False, masks out inputs in loss
47
  group_by_length: bool = False, # faster, but produces an odd training loss curve
48
  # either training checkpoint or final adapter
49
+ resume_from_checkpoint = None,
50
  save_steps: int = 200,
51
  save_total_limit: int = 3,
52
  logging_steps: int = 10,
53
  # logging
54
+ callbacks: List[Any] = [],
55
+ # wandb params
56
+ wandb_api_key = None,
57
+ wandb_project: str = "",
58
+ wandb_group = None,
59
+ wandb_run_name: str = "",
60
+ wandb_tags: List[str] = [],
61
+ wandb_watch: str = "false", # options: false | gradients | all
62
+ wandb_log_model: str = "true", # options: false | true
63
  ):
64
+ # for logging
65
+ finetune_args = {
66
+ 'micro_batch_size': micro_batch_size,
67
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
68
+ 'num_train_epochs': num_train_epochs,
69
+ 'learning_rate': learning_rate,
70
+ 'cutoff_len': cutoff_len,
71
+ 'val_set_size': val_set_size,
72
+ 'lora_r': lora_r,
73
+ 'lora_alpha': lora_alpha,
74
+ 'lora_dropout': lora_dropout,
75
+ 'lora_target_modules': lora_target_modules,
76
+ 'train_on_inputs': train_on_inputs,
77
+ 'group_by_length': group_by_length,
78
+ 'save_steps': save_steps,
79
+ 'save_total_limit': save_total_limit,
80
+ 'logging_steps': logging_steps,
81
+ }
82
+ if val_set_size and val_set_size > 0:
83
+ finetune_args['val_set_size'] = val_set_size
84
+ if resume_from_checkpoint:
85
+ finetune_args['resume_from_checkpoint'] = resume_from_checkpoint
86
+
87
+ wandb = None
88
+ if wandb_api_key:
89
+ os.environ["WANDB_API_KEY"] = wandb_api_key
90
+
91
+ # wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
92
+ # if wandb_project:
93
+ # os.environ["WANDB_PROJECT"] = wandb_project
94
+ # if wandb_run_name:
95
+ # os.environ["WANDB_RUN_NAME"] = wandb_run_name
96
+
97
+ if wandb_watch:
98
+ os.environ["WANDB_WATCH"] = wandb_watch
99
+ if wandb_log_model:
100
+ os.environ["WANDB_LOG_MODEL"] = wandb_log_model
101
+ use_wandb = (wandb_project and len(wandb_project) > 0) or (
102
+ "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
103
+ )
104
+ if use_wandb:
105
+ os.environ['WANDB_MODE'] = "online"
106
+ wandb = importlib.import_module("wandb")
107
+ wandb.init(
108
+ project=wandb_project,
109
+ resume="auto",
110
+ group=wandb_group,
111
+ name=wandb_run_name,
112
+ tags=wandb_tags,
113
+ reinit=True,
114
+ magic=True,
115
+ config={'finetune_args': finetune_args},
116
+ # id=None # used for resuming
117
+ )
118
+ else:
119
+ os.environ['WANDB_MODE'] = "disabled"
120
+
121
  if os.path.exists(output_dir):
122
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
123
  raise ValueError(
 
204
 
205
  # If train_dataset_data is a list, convert it to datasets.Dataset
206
  if isinstance(train_dataset_data, list):
207
+ with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
208
+ json.dump(list(train_dataset_data[:100]), file, indent=2)
209
  train_dataset_data = Dataset.from_list(train_dataset_data)
210
 
211
  if resume_from_checkpoint:
 
226
  adapters_weights = torch.load(checkpoint_name)
227
  model = set_peft_model_state_dict(model, adapters_weights)
228
  else:
229
+ raise ValueError(f"Checkpoint {checkpoint_name} not found")
230
 
231
  # Be more transparent about the % of trainable params.
232
  model.print_trainable_parameters()
 
265
  optim="adamw_torch",
266
  evaluation_strategy="steps" if val_set_size > 0 else "no",
267
  save_strategy="steps",
268
+ eval_steps=save_steps if val_set_size > 0 else None,
269
  save_steps=save_steps,
270
  output_dir=output_dir,
271
  save_total_limit=save_total_limit,
272
  load_best_model_at_end=True if val_set_size > 0 else False,
273
  ddp_find_unused_parameters=False if ddp else None,
274
  group_by_length=group_by_length,
275
+ report_to="wandb" if use_wandb else None,
276
+ run_name=wandb_run_name if use_wandb else None,
277
  ),
278
  data_collator=transformers.DataCollatorForSeq2Seq(
279
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
 
285
  os.makedirs(output_dir)
286
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
287
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
288
+ with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
289
+ json.dump(finetune_args, finetune_args_json_file, indent=2)
290
+
291
+ # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
292
+ # if train_data:
293
+ # with open(os.path.join(output_dir, "train_dataset_samples.json"), 'w') as file:
294
+ # json.dump(list(train_data[:100]), file, indent=2)
295
+ # if val_data:
296
+ # with open(os.path.join(output_dir, "eval_dataset_samples.json"), 'w') as file:
297
+ # json.dump(list(val_data[:100]), file, indent=2)
 
 
 
 
 
 
 
 
298
 
299
  model.config.use_cache = False
300
 
 
321
  with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
322
  json.dump(train_output, train_output_json_file, indent=2)
323
 
324
+ if use_wandb and wandb:
325
+ wandb.finish()
326
+
327
  return train_output
llama_lora/lib/get_device.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_device():
5
+ device ="cpu"
6
+ if torch.cuda.is_available():
7
+ device = "cuda"
8
+
9
+ try:
10
+ if torch.backends.mps.is_available():
11
+ device = "mps"
12
+ except: # noqa: E722
13
+ pass
14
+
15
+ return device
llama_lora/lib/inference.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+
4
+ from .get_device import get_device
5
+ from .streaming_generation_utils import Iteratorize, Stream
6
+
7
+ def generate(
8
+ # model
9
+ model,
10
+ tokenizer,
11
+ # input
12
+ prompt,
13
+ generation_config,
14
+ max_new_tokens,
15
+ stopping_criteria=[],
16
+ # output options
17
+ stream_output=False
18
+ ):
19
+ device = get_device()
20
+
21
+ inputs = tokenizer(prompt, return_tensors="pt")
22
+ input_ids = inputs["input_ids"].to(device)
23
+ generate_params = {
24
+ "input_ids": input_ids,
25
+ "generation_config": generation_config,
26
+ "return_dict_in_generate": True,
27
+ "output_scores": True,
28
+ "max_new_tokens": max_new_tokens,
29
+ "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
30
+ }
31
+
32
+ skip_special_tokens = True
33
+
34
+ if '/dolly' in tokenizer.name_or_path:
35
+ # dolly has additional_special_tokens as ['### End', '### Instruction:', '### Response:'], skipping them will break the prompter's reply extraction.
36
+ skip_special_tokens = False
37
+ # Ensure generation stops once it generates "### End"
38
+ end_key_token_id = tokenizer.encode("### End")
39
+ end_key_token_id = end_key_token_id[0] # 50277
40
+ if isinstance(generate_params['generation_config'].eos_token_id, str):
41
+ generate_params['generation_config'].eos_token_id = [generate_params['generation_config'].eos_token_id]
42
+ elif not generate_params['generation_config'].eos_token_id:
43
+ generate_params['generation_config'].eos_token_id = []
44
+ generate_params['generation_config'].eos_token_id.append(end_key_token_id)
45
+
46
+ if stream_output:
47
+ # Stream the reply 1 token at a time.
48
+ # This is based on the trick of using 'stopping_criteria' to create an iterator,
49
+ # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
50
+ generation_output = None
51
+
52
+ def generate_with_callback(callback=None, **kwargs):
53
+ nonlocal generation_output
54
+ kwargs["stopping_criteria"].insert(
55
+ 0,
56
+ Stream(callback_func=callback)
57
+ )
58
+ with torch.no_grad():
59
+ generation_output = model.generate(**kwargs)
60
+
61
+ def generate_with_streaming(**kwargs):
62
+ return Iteratorize(
63
+ generate_with_callback, kwargs, callback=None
64
+ )
65
+
66
+ with generate_with_streaming(**generate_params) as generator:
67
+ for output in generator:
68
+ decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
69
+ yield decoded_output, output
70
+ if output[-1] in [tokenizer.eos_token_id]:
71
+ break
72
+
73
+ if generation_output:
74
+ output = generation_output.sequences[0]
75
+ decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
76
+ yield decoded_output, output
77
+
78
+ return # early return for stream_output
79
+
80
+ # Without streaming
81
+ with torch.no_grad():
82
+ generation_output = model.generate(**generate_params)
83
+ output = generation_output.sequences[0]
84
+ decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
85
+ yield decoded_output, output
86
+ return
llama_lora/{utils/callbacks.py → lib/streaming_generation_utils.py} RENAMED
File without changes
llama_lora/models.py CHANGED
@@ -2,25 +2,14 @@ import os
2
  import sys
3
  import gc
4
  import json
 
5
 
6
  import torch
7
- from transformers import LlamaForCausalLM, LlamaTokenizer
8
  from peft import PeftModel
9
 
10
  from .globals import Global
11
-
12
-
13
- def get_device():
14
- if torch.cuda.is_available():
15
- return "cuda"
16
- else:
17
- return "cpu"
18
-
19
- try:
20
- if torch.backends.mps.is_available():
21
- return "mps"
22
- except: # noqa: E722
23
- pass
24
 
25
 
26
  def get_new_base_model(base_model_name):
@@ -41,7 +30,7 @@ def get_new_base_model(base_model_name):
41
  device = get_device()
42
 
43
  if device == "cuda":
44
- model = LlamaForCausalLM.from_pretrained(
45
  base_model_name,
46
  load_in_8bit=Global.load_8bit,
47
  torch_dtype=torch.float16,
@@ -50,19 +39,22 @@ def get_new_base_model(base_model_name):
50
  device_map={'': 0},
51
  )
52
  elif device == "mps":
53
- model = LlamaForCausalLM.from_pretrained(
54
  base_model_name,
55
  device_map={"": device},
56
  torch_dtype=torch.float16,
57
  )
58
  else:
59
- model = LlamaForCausalLM.from_pretrained(
60
  base_model_name, device_map={"": device}, low_cpu_mem_usage=True
61
  )
62
 
63
- model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
64
- model.config.bos_token_id = 1
65
- model.config.eos_token_id = 2
 
 
 
66
 
67
  return model
68
 
@@ -75,7 +67,14 @@ def get_tokenizer(base_model_name):
75
  if loaded_tokenizer:
76
  return loaded_tokenizer
77
 
78
- tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
 
 
 
 
 
 
 
79
  Global.loaded_tokenizers.set(base_model_name, tokenizer)
80
 
81
  return tokenizer
@@ -148,9 +147,10 @@ def get_model(
148
  device_map={"": device},
149
  )
150
 
151
- model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
152
- model.config.bos_token_id = 1
153
- model.config.eos_token_id = 2
 
154
 
155
  if not Global.load_8bit:
156
  model.half() # seems to fix bugs for some users.
 
2
  import sys
3
  import gc
4
  import json
5
+ import re
6
 
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
9
  from peft import PeftModel
10
 
11
  from .globals import Global
12
+ from .lib.get_device import get_device
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def get_new_base_model(base_model_name):
 
30
  device = get_device()
31
 
32
  if device == "cuda":
33
+ model = AutoModelForCausalLM.from_pretrained(
34
  base_model_name,
35
  load_in_8bit=Global.load_8bit,
36
  torch_dtype=torch.float16,
 
39
  device_map={'': 0},
40
  )
41
  elif device == "mps":
42
+ model = AutoModelForCausalLM.from_pretrained(
43
  base_model_name,
44
  device_map={"": device},
45
  torch_dtype=torch.float16,
46
  )
47
  else:
48
+ model = AutoModelForCausalLM.from_pretrained(
49
  base_model_name, device_map={"": device}, low_cpu_mem_usage=True
50
  )
51
 
52
+ tokenizer = get_tokenizer(base_model_name)
53
+
54
+ if re.match("[^/]+/llama", base_model_name):
55
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
56
+ model.config.bos_token_id = tokenizer.bos_token_id = 1
57
+ model.config.eos_token_id = tokenizer.eos_token_id = 2
58
 
59
  return model
60
 
 
67
  if loaded_tokenizer:
68
  return loaded_tokenizer
69
 
70
+ try:
71
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
72
+ except Exception as e:
73
+ if 'LLaMATokenizer' in str(e):
74
+ tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
75
+ else:
76
+ raise e
77
+
78
  Global.loaded_tokenizers.set(base_model_name, tokenizer)
79
 
80
  return tokenizer
 
147
  device_map={"": device},
148
  )
149
 
150
+ if re.match("[^/]+/llama", base_model_name):
151
+ model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
152
+ model.config.bos_token_id = 1
153
+ model.config.eos_token_id = 2
154
 
155
  if not Global.load_8bit:
156
  model.half() # seems to fix bugs for some users.
llama_lora/ui/finetune_ui.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import json
3
  import time
 
 
4
  from datetime import datetime
5
  import gradio as gr
6
  import math
@@ -15,7 +17,8 @@ from ..models import (
15
  from ..utils.data import (
16
  get_available_template_names,
17
  get_available_dataset_names,
18
- get_dataset_content
 
19
  )
20
  from ..utils.prompter import Prompter
21
 
@@ -47,13 +50,16 @@ def reload_selections(current_template, current_dataset):
47
  current_dataset = current_dataset or next(
48
  iter(available_dataset_names), None)
49
 
 
 
50
  return (
51
  gr.Dropdown.update(
52
  choices=available_template_names_with_none,
53
  value=current_template),
54
  gr.Dropdown.update(
55
  choices=available_dataset_names,
56
- value=current_dataset)
 
57
  )
58
 
59
 
@@ -79,56 +85,47 @@ def load_sample_dataset_to_text_input(format):
79
  return gr.Code.update(value=sample_plain_text_value)
80
 
81
 
82
- def process_json_dataset(data, only_first_n_items=None):
83
- if not isinstance(data, list):
84
- raise ValueError("The dataset is not an array of objects.")
85
-
86
- if only_first_n_items is not None:
87
- data = data[:only_first_n_items]
88
-
89
- first_item = get_val_from_arr(data, 0, None)
90
-
91
- if first_item is None:
92
- raise ValueError("The dataset is empty.")
93
- if not isinstance(first_item, dict):
94
- raise ValueError("The dataset is not an array of objects.")
95
-
96
- # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
97
- if "completion" in first_item and "output" not in first_item:
98
- data = [
99
- {"output" if k == "completion" else k: v for k, v in d.items()}
100
- for d in data]
101
- first_item = get_val_from_arr(data, 0, None)
102
-
103
- # Flatten Stanford Alpaca style instances
104
- if "instances" in first_item and isinstance(first_item["instances"], list):
105
- data = [
106
- {"output" if k == "completion" else k: v for k, v in d.items()}
107
- for d in data]
108
- flattened_data = []
109
- for item in data:
110
- for instance in item["instances"]:
111
- d = {k: v for k, v in item.items() if k != "instances"}
112
- d.update(instance)
113
- flattened_data.append(d)
114
- data = flattened_data
115
- first_item = get_val_from_arr(data, 0, None)
116
-
117
- if "output" not in first_item:
118
- raise ValueError(
119
- "The data does not contains an \"output\" or \"completion\".")
120
-
121
- # Put all variables under the "variables" key if it does not exists
122
- if "variables" not in first_item:
123
- data = [
124
- {
125
- "variables":
126
- {k: v for k, v in d.items() if k != "output"},
127
- "output":
128
- d["output"]
129
- }
130
- for d in data
131
- ]
132
  return data
133
 
134
 
@@ -141,72 +138,92 @@ def refresh_preview(
141
  dataset_plain_text_input_variables_separator,
142
  dataset_plain_text_input_and_output_separator,
143
  dataset_plain_text_data_separator,
144
- preview_show_actual_prompt,
145
  ):
146
  try:
147
- max_preview_count = 100
148
  prompter = Prompter(template)
149
  variable_names = prompter.get_variable_names()
150
 
151
- if load_dataset_from == "Text Input":
152
- if dataset_text_format == "JSON":
153
- data = json.loads(dataset_text)
154
- data = process_json_dataset(data)
155
-
156
- elif dataset_text_format == "JSON Lines":
157
- lines = dataset_text.split('\n')
158
- data = []
159
- for i, line in enumerate(lines):
160
- line_number = i + 1
161
- try:
162
- data.append(json.loads(line))
163
- except Exception as e:
164
- raise ValueError(
165
- f"Error parsing JSON on line {line_number}: {e}")
166
-
167
- data = process_json_dataset(data)
168
-
169
- else: # Plain Text
170
- data = parse_plain_text_input(
171
- dataset_text,
172
- (
173
- dataset_plain_text_input_variables_separator or
174
- default_dataset_plain_text_input_variables_separator
175
- ).replace("\\n", "\n"),
176
- (
177
- dataset_plain_text_input_and_output_separator or
178
- default_dataset_plain_text_input_and_output_separator
179
- ).replace("\\n", "\n"),
180
- (
181
- dataset_plain_text_data_separator or
182
- default_dataset_plain_text_data_separator
183
- ).replace("\\n", "\n"),
184
- variable_names
185
- )
186
 
187
- else: # Load dataset from data directory
188
- data = get_dataset_content(dataset_from_data_dir)
189
- data = process_json_dataset(data)
190
 
191
  data_count = len(data)
192
- headers = variable_names
 
193
  preview_data = [
194
- [item['variables'].get(name, "") for name in variable_names]
195
- for item in data[:max_preview_count]
196
  ]
197
 
198
- if preview_show_actual_prompt:
199
- headers = headers + ["Prompt (actual input)"]
200
- rendered = [prompter.generate_prompt(
201
- item['variables']) for item in data[:max_preview_count]]
202
- preview_data = result = [d + [i]
203
- for d, i in zip(preview_data, rendered)]
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- headers = headers + ["Completion (output)"]
206
- preview_data = result = [pd + [d['output']]
207
- for pd, d in zip(preview_data, data[:max_preview_count])]
 
 
208
 
209
- preview_info_message = f"The dataset has a total of {data_count} item(s)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if data_count > max_preview_count:
211
  preview_info_message += f" Previewing the first {max_preview_count}."
212
 
@@ -215,11 +232,22 @@ def refresh_preview(
215
  info_message = "This dataset contains " + info_message
216
  update_message = gr.Markdown.update(info_message, visible=True)
217
 
218
- return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
219
  except Exception as e:
220
  update_message = gr.Markdown.update(
221
  f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
222
- return gr.Dataframe.update(value={'data': [], 'headers': []}), gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message
 
 
 
 
 
 
 
 
 
 
 
223
 
224
 
225
  def parse_plain_text_input(
@@ -258,7 +286,7 @@ def do_train(
258
  dataset_plain_text_data_separator,
259
  # Training Options
260
  max_seq_length,
261
- evaluate_data_percentage,
262
  micro_batch_size,
263
  gradient_accumulation_steps,
264
  epochs,
@@ -268,14 +296,27 @@ def do_train(
268
  lora_alpha,
269
  lora_dropout,
270
  lora_target_modules,
271
- model_name,
272
  save_steps,
273
  save_total_limit,
274
  logging_steps,
 
 
 
275
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
276
  ):
277
  try:
278
- base_model_name = Global.default_base_model_name
 
 
 
 
 
 
 
 
 
 
 
279
  output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
280
  if os.path.exists(output_dir):
281
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
@@ -288,56 +329,22 @@ def do_train(
288
  unload_models() # Need RAM for training
289
 
290
  prompter = Prompter(template)
291
- variable_names = prompter.get_variable_names()
292
-
293
- if load_dataset_from == "Text Input":
294
- if dataset_text_format == "JSON":
295
- data = json.loads(dataset_text)
296
- data = process_json_dataset(data)
297
-
298
- elif dataset_text_format == "JSON Lines":
299
- lines = dataset_text.split('\n')
300
- data = []
301
- for i, line in enumerate(lines):
302
- line_number = i + 1
303
- try:
304
- data.append(json.loads(line))
305
- except Exception as e:
306
- raise ValueError(
307
- f"Error parsing JSON on line {line_number}: {e}")
308
-
309
- data = process_json_dataset(data)
310
-
311
- else: # Plain Text
312
- data = parse_plain_text_input(
313
- dataset_text,
314
- (
315
- dataset_plain_text_input_variables_separator or
316
- default_dataset_plain_text_input_variables_separator
317
- ).replace("\\n", "\n"),
318
- (
319
- dataset_plain_text_input_and_output_separator or
320
- default_dataset_plain_text_input_and_output_separator
321
- ).replace("\\n", "\n"),
322
- (
323
- dataset_plain_text_data_separator or
324
- default_dataset_plain_text_data_separator
325
- ).replace("\\n", "\n"),
326
- variable_names
327
- )
328
-
329
- else: # Load dataset from data directory
330
- data = get_dataset_content(dataset_from_data_dir)
331
- data = process_json_dataset(data)
332
 
333
- data_count = len(data)
334
- evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
335
 
336
- train_data = [
337
- {
338
- 'prompt': prompter.generate_prompt(d['variables']),
339
- 'completion': d['output']}
340
- for d in data]
341
 
342
  def get_progress_text(epoch, epochs, last_loss):
343
  progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
@@ -380,6 +387,8 @@ Train options: {json.dumps({
380
  'lora_dropout': lora_dropout,
381
  'lora_target_modules': lora_target_modules,
382
  'model_name': model_name,
 
 
383
  }, indent=2)}
384
 
385
  Train data (first 10):
@@ -390,7 +399,7 @@ Train data (first 10):
390
  return message
391
 
392
  if not should_training_progress_track_tqdm:
393
- progress(0, desc="Preparing model for training...")
394
 
395
  log_history = []
396
 
@@ -449,26 +458,37 @@ Train data (first 10):
449
  'dataset_rows': len(train_data),
450
  'timestamp': time.time(),
451
 
452
- 'max_seq_length': max_seq_length,
453
- 'train_on_inputs': train_on_inputs,
 
454
 
455
- 'micro_batch_size': micro_batch_size,
456
- 'gradient_accumulation_steps': gradient_accumulation_steps,
457
- 'epochs': epochs,
458
- 'learning_rate': learning_rate,
459
 
460
- 'evaluate_data_percentage': evaluate_data_percentage,
461
 
462
- 'lora_r': lora_r,
463
- 'lora_alpha': lora_alpha,
464
- 'lora_dropout': lora_dropout,
465
- 'lora_target_modules': lora_target_modules,
466
  }
 
 
 
 
467
  json.dump(info, info_json_file, indent=2)
468
 
469
  if not should_training_progress_track_tqdm:
470
  progress(0, desc="Train starting...")
471
 
 
 
 
 
 
 
472
  train_output = Global.train_fn(
473
  base_model, # base_model
474
  tokenizer, # tokenizer
@@ -487,11 +507,16 @@ Train data (first 10):
487
  lora_target_modules, # lora_target_modules
488
  train_on_inputs, # train_on_inputs
489
  False, # group_by_length
490
- None, # resume_from_checkpoint
491
  save_steps, # save_steps
492
  save_total_limit, # save_total_limit
493
  logging_steps, # logging_steps
494
- training_callbacks # callbacks
 
 
 
 
 
495
  )
496
 
497
  logs_str = "\n".join([json.dumps(log)
@@ -515,6 +540,146 @@ def do_abort_training():
515
  Global.should_stop_training = True
516
 
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  def finetune_ui():
519
  things_that_might_timeout = []
520
 
@@ -606,9 +771,13 @@ def finetune_ui():
606
  "Set the dataset in the \"Prepare\" tab, then preview it here.",
607
  elem_id="finetune_dataset_preview_info_message"
608
  )
609
- finetune_dataset_preview_show_actual_prompt = gr.Checkbox(
610
- label="Show actual prompt",
611
- elem_id="finetune_dataset_preview_show_actual_prompt"
 
 
 
 
612
  )
613
  finetune_dataset_preview = gr.Dataframe(
614
  wrap=True, elem_id="finetune_dataset_preview")
@@ -633,25 +802,7 @@ def finetune_ui():
633
  dataset_plain_text_data_separator,
634
  ]
635
  dataset_preview_inputs = dataset_inputs + \
636
- [finetune_dataset_preview_show_actual_prompt]
637
- for i in dataset_preview_inputs:
638
- things_that_might_timeout.append(
639
- i.change(
640
- fn=refresh_preview,
641
- inputs=dataset_preview_inputs,
642
- outputs=[finetune_dataset_preview,
643
- finetune_dataset_preview_info_message,
644
- dataset_from_text_message,
645
- dataset_from_data_dir_message
646
- ]
647
- ))
648
-
649
- things_that_might_timeout.append(reload_selections_button.click(
650
- reload_selections,
651
- inputs=[template, dataset_from_data_dir],
652
- outputs=[template, dataset_from_data_dir],
653
- )
654
- )
655
 
656
  with gr.Row():
657
  max_seq_length = gr.Slider(
@@ -704,12 +855,43 @@ def finetune_ui():
704
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
705
  )
706
 
707
- evaluate_data_percentage = gr.Slider(
708
- minimum=0, maximum=0.5, step=0.001, value=0,
709
- label="Evaluation Data Percentage",
710
- info="The percentage of data to be used for evaluation. This percentage of data will not be used for training and will be used to assess the performance of the model during the process."
711
  )
712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  with gr.Column():
714
  lora_r = gr.Slider(
715
  minimum=1, maximum=16, step=1, value=8,
@@ -729,12 +911,31 @@ def finetune_ui():
729
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
730
  )
731
 
 
 
732
  lora_target_modules = gr.CheckboxGroup(
733
  label="LoRA Target Modules",
734
- choices=["q_proj", "k_proj", "v_proj", "o_proj"],
735
  value=["q_proj", "v_proj"],
736
- info="Modules to replace with LoRA."
 
737
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
 
739
  with gr.Row():
740
  logging_steps = gr.Number(
@@ -759,6 +960,7 @@ def finetune_ui():
759
  with gr.Column():
760
  model_name = gr.Textbox(
761
  lines=1, label="LoRA Model Name", value=random_name,
 
762
  info="The name of the new LoRA model.",
763
  elem_id="finetune_model_name",
764
  )
@@ -778,6 +980,59 @@ def finetune_ui():
778
  elem_id="finetune_confirm_stop_btn"
779
  )
780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
  train_output = gr.Text(
782
  "Training results will be shown here.",
783
  label="Train Output",
@@ -785,22 +1040,10 @@ def finetune_ui():
785
 
786
  train_progress = train_btn.click(
787
  fn=do_train,
788
- inputs=(dataset_inputs + [
789
- max_seq_length,
790
- evaluate_data_percentage,
791
- micro_batch_size,
792
- gradient_accumulation_steps,
793
- epochs,
794
- learning_rate,
795
- train_on_inputs,
796
- lora_r,
797
- lora_alpha,
798
- lora_dropout,
799
- lora_target_modules,
800
  model_name,
801
- save_steps,
802
- save_total_limit,
803
- logging_steps,
804
  ]),
805
  outputs=train_output
806
  )
 
1
  import os
2
  import json
3
  import time
4
+ import traceback
5
+ import re
6
  from datetime import datetime
7
  import gradio as gr
8
  import math
 
17
  from ..utils.data import (
18
  get_available_template_names,
19
  get_available_dataset_names,
20
+ get_dataset_content,
21
+ get_available_lora_model_names
22
  )
23
  from ..utils.prompter import Prompter
24
 
 
50
  current_dataset = current_dataset or next(
51
  iter(available_dataset_names), None)
52
 
53
+ available_lora_models = ["-"] + get_available_lora_model_names()
54
+
55
  return (
56
  gr.Dropdown.update(
57
  choices=available_template_names_with_none,
58
  value=current_template),
59
  gr.Dropdown.update(
60
  choices=available_dataset_names,
61
+ value=current_dataset),
62
+ gr.Dropdown.update(choices=available_lora_models)
63
  )
64
 
65
 
 
85
  return gr.Code.update(value=sample_plain_text_value)
86
 
87
 
88
+ def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
89
+ dataset_plain_text_input_variables_separator,
90
+ dataset_plain_text_input_and_output_separator,
91
+ dataset_plain_text_data_separator,
92
+ dataset_from_data_dir, prompter):
93
+ if load_dataset_from == "Text Input":
94
+ if dataset_text_format == "JSON":
95
+ data = json.loads(dataset_text)
96
+
97
+ elif dataset_text_format == "JSON Lines":
98
+ lines = dataset_text.split('\n')
99
+ data = []
100
+ for i, line in enumerate(lines):
101
+ line_number = i + 1
102
+ try:
103
+ data.append(json.loads(line))
104
+ except Exception as e:
105
+ raise ValueError(
106
+ f"Error parsing JSON on line {line_number}: {e}")
107
+
108
+ else: # Plain Text
109
+ data = parse_plain_text_input(
110
+ dataset_text,
111
+ (
112
+ dataset_plain_text_input_variables_separator or
113
+ default_dataset_plain_text_input_variables_separator
114
+ ).replace("\\n", "\n"),
115
+ (
116
+ dataset_plain_text_input_and_output_separator or
117
+ default_dataset_plain_text_input_and_output_separator
118
+ ).replace("\\n", "\n"),
119
+ (
120
+ dataset_plain_text_data_separator or
121
+ default_dataset_plain_text_data_separator
122
+ ).replace("\\n", "\n"),
123
+ prompter.get_variable_names()
124
+ )
125
+
126
+ else: # Load dataset from data directory
127
+ data = get_dataset_content(dataset_from_data_dir)
128
+
 
 
 
 
 
 
 
 
 
129
  return data
130
 
131
 
 
138
  dataset_plain_text_input_variables_separator,
139
  dataset_plain_text_input_and_output_separator,
140
  dataset_plain_text_data_separator,
141
+ max_preview_count,
142
  ):
143
  try:
 
144
  prompter = Prompter(template)
145
  variable_names = prompter.get_variable_names()
146
 
147
+ data = get_data_from_input(
148
+ load_dataset_from=load_dataset_from,
149
+ dataset_text=dataset_text,
150
+ dataset_text_format=dataset_text_format,
151
+ dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
152
+ dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
153
+ dataset_plain_text_data_separator=dataset_plain_text_data_separator,
154
+ dataset_from_data_dir=dataset_from_data_dir,
155
+ prompter=prompter
156
+ )
157
+
158
+ train_data = prompter.get_train_data_from_dataset(
159
+ data, max_preview_count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ train_data = train_data[:max_preview_count]
 
 
162
 
163
  data_count = len(data)
164
+
165
+ headers = ['Prompt', 'Completion']
166
  preview_data = [
167
+ [item.get("prompt", ""), item.get("completion", "")]
168
+ for item in train_data
169
  ]
170
 
171
+ if not prompter.template_module:
172
+ variable_names = prompter.get_variable_names()
173
+ headers += [f"Variable: {variable_name}" for variable_name in variable_names]
174
+ variables = [
175
+ [item.get(f"_var_{name}", "") for name in variable_names]
176
+ for item in train_data
177
+ ]
178
+ preview_data = [d + v for d, v in zip(preview_data, variables)]
179
+
180
+ preview_info_message = f"The dataset has about {data_count} item(s)."
181
+ if data_count > max_preview_count:
182
+ preview_info_message += f" Previewing the first {max_preview_count}."
183
+
184
+ info_message = f"about {data_count} item(s)."
185
+ if load_dataset_from == "Data Dir":
186
+ info_message = "This dataset contains about " + info_message
187
+ update_message = gr.Markdown.update(info_message, visible=True)
188
 
189
+ return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
190
+ except Exception as e:
191
+ update_message = gr.Markdown.update(
192
+ f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
193
+ return gr.Dataframe.update(value={'data': [], 'headers': []}), gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message
194
 
195
+
196
+ def refresh_dataset_items_count(
197
+ template,
198
+ load_dataset_from,
199
+ dataset_from_data_dir,
200
+ dataset_text,
201
+ dataset_text_format,
202
+ dataset_plain_text_input_variables_separator,
203
+ dataset_plain_text_input_and_output_separator,
204
+ dataset_plain_text_data_separator,
205
+ max_preview_count,
206
+ ):
207
+ try:
208
+ prompter = Prompter(template)
209
+ variable_names = prompter.get_variable_names()
210
+
211
+ data = get_data_from_input(
212
+ load_dataset_from=load_dataset_from,
213
+ dataset_text=dataset_text,
214
+ dataset_text_format=dataset_text_format,
215
+ dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
216
+ dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
217
+ dataset_plain_text_data_separator=dataset_plain_text_data_separator,
218
+ dataset_from_data_dir=dataset_from_data_dir,
219
+ prompter=prompter
220
+ )
221
+
222
+ train_data = prompter.get_train_data_from_dataset(
223
+ data)
224
+ data_count = len(train_data)
225
+
226
+ preview_info_message = f"The dataset contains {data_count} item(s)."
227
  if data_count > max_preview_count:
228
  preview_info_message += f" Previewing the first {max_preview_count}."
229
 
 
232
  info_message = "This dataset contains " + info_message
233
  update_message = gr.Markdown.update(info_message, visible=True)
234
 
235
+ return gr.Markdown.update(preview_info_message), update_message, update_message, gr.Slider.update(maximum=math.floor(data_count / 2))
236
  except Exception as e:
237
  update_message = gr.Markdown.update(
238
  f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
239
+
240
+ trace = traceback.format_exc()
241
+ traces = [s.strip() for s in re.split("\n * File ", trace)]
242
+ templates_path = os.path.join(Global.data_dir, "templates")
243
+ traces_to_show = [s for s in traces if os.path.join(
244
+ Global.data_dir, "templates") in s]
245
+ traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
246
+ if len(traces_to_show) > 0:
247
+ update_message = gr.Markdown.update(
248
+ f"<span class=\"finetune_dataset_error_message\">Error: {e} ({','.join(traces_to_show)}).</span>", visible=True)
249
+
250
+ return gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message, gr.Slider.update(maximum=1)
251
 
252
 
253
  def parse_plain_text_input(
 
286
  dataset_plain_text_data_separator,
287
  # Training Options
288
  max_seq_length,
289
+ evaluate_data_count,
290
  micro_batch_size,
291
  gradient_accumulation_steps,
292
  epochs,
 
296
  lora_alpha,
297
  lora_dropout,
298
  lora_target_modules,
 
299
  save_steps,
300
  save_total_limit,
301
  logging_steps,
302
+ model_name,
303
+ continue_from_model,
304
+ continue_from_checkpoint,
305
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
306
  ):
307
  try:
308
+ base_model_name = Global.base_model_name
309
+
310
+ resume_from_checkpoint = None
311
+ if continue_from_model == "-" or continue_from_model == "None":
312
+ continue_from_model = None
313
+ if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
314
+ continue_from_checkpoint = None
315
+ if continue_from_model:
316
+ resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
317
+ if continue_from_checkpoint:
318
+ resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
319
+
320
  output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
321
  if os.path.exists(output_dir):
322
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
 
329
  unload_models() # Need RAM for training
330
 
331
  prompter = Prompter(template)
332
+ # variable_names = prompter.get_variable_names()
333
+
334
+ data = get_data_from_input(
335
+ load_dataset_from=load_dataset_from,
336
+ dataset_text=dataset_text,
337
+ dataset_text_format=dataset_text_format,
338
+ dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
339
+ dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
340
+ dataset_plain_text_data_separator=dataset_plain_text_data_separator,
341
+ dataset_from_data_dir=dataset_from_data_dir,
342
+ prompter=prompter
343
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
+ train_data = prompter.get_train_data_from_dataset(data)
 
346
 
347
+ data_count = len(train_data)
 
 
 
 
348
 
349
  def get_progress_text(epoch, epochs, last_loss):
350
  progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
 
387
  'lora_dropout': lora_dropout,
388
  'lora_target_modules': lora_target_modules,
389
  'model_name': model_name,
390
+ 'continue_from_model': continue_from_model,
391
+ 'continue_from_checkpoint': continue_from_checkpoint,
392
  }, indent=2)}
393
 
394
  Train data (first 10):
 
399
  return message
400
 
401
  if not should_training_progress_track_tqdm:
402
+ progress(0, desc=f"Preparing model {base_model_name} for training...")
403
 
404
  log_history = []
405
 
 
458
  'dataset_rows': len(train_data),
459
  'timestamp': time.time(),
460
 
461
+ # These will be saved in another JSON file by the train function
462
+ # 'max_seq_length': max_seq_length,
463
+ # 'train_on_inputs': train_on_inputs,
464
 
465
+ # 'micro_batch_size': micro_batch_size,
466
+ # 'gradient_accumulation_steps': gradient_accumulation_steps,
467
+ # 'epochs': epochs,
468
+ # 'learning_rate': learning_rate,
469
 
470
+ # 'evaluate_data_count': evaluate_data_count,
471
 
472
+ # 'lora_r': lora_r,
473
+ # 'lora_alpha': lora_alpha,
474
+ # 'lora_dropout': lora_dropout,
475
+ # 'lora_target_modules': lora_target_modules,
476
  }
477
+ if continue_from_model:
478
+ info['continued_from_model'] = continue_from_model
479
+ if continue_from_checkpoint:
480
+ info['continued_from_checkpoint'] = continue_from_checkpoint
481
  json.dump(info, info_json_file, indent=2)
482
 
483
  if not should_training_progress_track_tqdm:
484
  progress(0, desc="Train starting...")
485
 
486
+ wandb_group = template
487
+ wandb_tags = [f"template:{template}"]
488
+ if load_dataset_from == "Data Dir" and dataset_from_data_dir:
489
+ wandb_group += f"/{dataset_from_data_dir}"
490
+ wandb_tags.append(f"dataset:{dataset_from_data_dir}")
491
+
492
  train_output = Global.train_fn(
493
  base_model, # base_model
494
  tokenizer, # tokenizer
 
507
  lora_target_modules, # lora_target_modules
508
  train_on_inputs, # train_on_inputs
509
  False, # group_by_length
510
+ resume_from_checkpoint, # resume_from_checkpoint
511
  save_steps, # save_steps
512
  save_total_limit, # save_total_limit
513
  logging_steps, # logging_steps
514
+ training_callbacks, # callbacks
515
+ Global.wandb_api_key, # wandb_api_key
516
+ Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
517
+ wandb_group, # wandb_group
518
+ model_name, # wandb_run_name
519
+ wandb_tags # wandb_tags
520
  )
521
 
522
  logs_str = "\n".join([json.dumps(log)
 
540
  Global.should_stop_training = True
541
 
542
 
543
+ def handle_continue_from_model_change(model_name):
544
+ try:
545
+ lora_models_directory_path = os.path.join(
546
+ Global.data_dir, "lora_models")
547
+ lora_model_directory_path = os.path.join(
548
+ lora_models_directory_path, model_name)
549
+ all_files = os.listdir(lora_model_directory_path)
550
+ checkpoints = [
551
+ file for file in all_files if file.startswith("checkpoint-")]
552
+ checkpoints = ["-"] + checkpoints
553
+ can_load_params = "finetune_params.json" in all_files or "finetune_args.json" in all_files
554
+ return gr.Dropdown.update(choices=checkpoints, value="-"), gr.Button.update(visible=can_load_params), gr.Markdown.update(value="", visible=False)
555
+ except Exception:
556
+ pass
557
+ return gr.Dropdown.update(choices=["-"], value="-"), gr.Button.update(visible=False), gr.Markdown.update(value="", visible=False)
558
+
559
+
560
+ def handle_load_params_from_model(
561
+ model_name,
562
+ max_seq_length,
563
+ evaluate_data_count,
564
+ micro_batch_size,
565
+ gradient_accumulation_steps,
566
+ epochs,
567
+ learning_rate,
568
+ train_on_inputs,
569
+ lora_r,
570
+ lora_alpha,
571
+ lora_dropout,
572
+ lora_target_modules,
573
+ save_steps,
574
+ save_total_limit,
575
+ logging_steps,
576
+ lora_target_module_choices,
577
+ ):
578
+ error_message = ""
579
+ notice_message = ""
580
+ unknown_keys = []
581
+ try:
582
+ lora_models_directory_path = os.path.join(
583
+ Global.data_dir, "lora_models")
584
+ lora_model_directory_path = os.path.join(
585
+ lora_models_directory_path, model_name)
586
+
587
+ data = {}
588
+ possible_files = ["finetune_params.json", "finetune_args.json"]
589
+ for file in possible_files:
590
+ try:
591
+ with open(os.path.join(lora_model_directory_path, file), "r") as f:
592
+ data = json.load(f)
593
+ except FileNotFoundError:
594
+ pass
595
+
596
+ for key, value in data.items():
597
+ if key == "max_seq_length":
598
+ max_seq_length = value
599
+ if key == "cutoff_len":
600
+ cutoff_len = value
601
+ elif key == "evaluate_data_count":
602
+ evaluate_data_count = value
603
+ elif key == "val_set_size":
604
+ evaluate_data_count = value
605
+ elif key == "micro_batch_size":
606
+ micro_batch_size = value
607
+ elif key == "gradient_accumulation_steps":
608
+ gradient_accumulation_steps = value
609
+ elif key == "epochs":
610
+ epochs = value
611
+ elif key == "num_train_epochs":
612
+ epochs = value
613
+ elif key == "learning_rate":
614
+ learning_rate = value
615
+ elif key == "train_on_inputs":
616
+ train_on_inputs = value
617
+ elif key == "lora_r":
618
+ lora_r = value
619
+ elif key == "lora_alpha":
620
+ lora_alpha = value
621
+ elif key == "lora_dropout":
622
+ lora_dropout = value
623
+ elif key == "lora_target_modules":
624
+ lora_target_modules = value
625
+ for element in value:
626
+ if element not in lora_target_module_choices:
627
+ lora_target_module_choices.append(element)
628
+ elif key == "save_steps":
629
+ save_steps = value
630
+ elif key == "save_total_limit":
631
+ save_total_limit = value
632
+ elif key == "logging_steps":
633
+ logging_steps = value
634
+ elif key == "group_by_length":
635
+ pass
636
+ elif key == "resume_from_checkpoint":
637
+ pass
638
+ else:
639
+ unknown_keys.append(key)
640
+ except Exception as e:
641
+ error_message = str(e)
642
+
643
+ if len(unknown_keys) > 0:
644
+ notice_message = f"Note: cannot restore unknown arg: {', '.join([f'`{x}`' for x in unknown_keys])}"
645
+
646
+ message = ". ".join([x for x in [error_message, notice_message] if x])
647
+
648
+ has_message = False
649
+ if message:
650
+ message += "."
651
+ has_message = True
652
+
653
+ return (
654
+ gr.Markdown.update(value=message, visible=has_message),
655
+ max_seq_length,
656
+ evaluate_data_count,
657
+ micro_batch_size,
658
+ gradient_accumulation_steps,
659
+ epochs,
660
+ learning_rate,
661
+ train_on_inputs,
662
+ lora_r,
663
+ lora_alpha,
664
+ lora_dropout,
665
+ gr.CheckboxGroup.update(value=lora_target_modules, choices=lora_target_module_choices),
666
+ save_steps,
667
+ save_total_limit,
668
+ logging_steps,
669
+ lora_target_module_choices,
670
+ )
671
+
672
+
673
+ default_lora_target_module_choices = ["q_proj", "k_proj", "v_proj", "o_proj"]
674
+
675
+
676
+ def handle_lora_target_modules_add(choices, new_module, selected_modules):
677
+ choices.append(new_module)
678
+ selected_modules.append(new_module)
679
+
680
+ return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
681
+
682
+
683
  def finetune_ui():
684
  things_that_might_timeout = []
685
 
 
771
  "Set the dataset in the \"Prepare\" tab, then preview it here.",
772
  elem_id="finetune_dataset_preview_info_message"
773
  )
774
+ finetune_dataset_preview_count = gr.Number(
775
+ label="Preview items count",
776
+ value=10,
777
+ # minimum=1,
778
+ # maximum=100,
779
+ precision=0,
780
+ elem_id="finetune_dataset_preview_count"
781
  )
782
  finetune_dataset_preview = gr.Dataframe(
783
  wrap=True, elem_id="finetune_dataset_preview")
 
802
  dataset_plain_text_data_separator,
803
  ]
804
  dataset_preview_inputs = dataset_inputs + \
805
+ [finetune_dataset_preview_count]
 
806
 
807
  with gr.Row():
808
  max_seq_length = gr.Slider(
 
855
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
856
  )
857
 
858
+ evaluate_data_count = gr.Slider(
859
+ minimum=0, maximum=1, step=1, value=0,
860
+ label="Evaluation Data Count",
861
+ info="The number of data to be used for evaluation. This amount of data will not be used for training and will be used to assess the performance of the model during the process."
862
  )
863
 
864
+ with gr.Box(elem_id="finetune_continue_from_model_box"):
865
+ with gr.Row():
866
+ continue_from_model = gr.Dropdown(
867
+ value="-",
868
+ label="Continue from Model",
869
+ choices=["-"],
870
+ elem_id="finetune_continue_from_model"
871
+ )
872
+ continue_from_checkpoint = gr.Dropdown(
873
+ value="-", label="Checkpoint", choices=["-"])
874
+ with gr.Column():
875
+ load_params_from_model_btn = gr.Button(
876
+ "Load training parameters from selected model", visible=False)
877
+ load_params_from_model_btn.style(
878
+ full_width=False,
879
+ size="sm")
880
+ load_params_from_model_message = gr.Markdown(
881
+ "", visible=False)
882
+
883
+ things_that_might_timeout.append(
884
+ continue_from_model.change(
885
+ fn=handle_continue_from_model_change,
886
+ inputs=[continue_from_model],
887
+ outputs=[
888
+ continue_from_checkpoint,
889
+ load_params_from_model_btn,
890
+ load_params_from_model_message
891
+ ]
892
+ )
893
+ )
894
+
895
  with gr.Column():
896
  lora_r = gr.Slider(
897
  minimum=1, maximum=16, step=1, value=8,
 
911
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
912
  )
913
 
914
+ lora_target_module_choices = gr.State(value=default_lora_target_module_choices)
915
+
916
  lora_target_modules = gr.CheckboxGroup(
917
  label="LoRA Target Modules",
918
+ choices=default_lora_target_module_choices,
919
  value=["q_proj", "v_proj"],
920
+ info="Modules to replace with LoRA.",
921
+ elem_id="finetune_lora_target_modules"
922
  )
923
+ with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
924
+ with gr.Row():
925
+ lora_target_modules_add = gr.Textbox(
926
+ lines=1, max_lines=1, show_label=False,
927
+ elem_id="finetune_lora_target_modules_add"
928
+ )
929
+ lora_target_modules_add_btn = gr.Button(
930
+ "Add",
931
+ elem_id="finetune_lora_target_modules_add_btn"
932
+ )
933
+ lora_target_modules_add_btn.style(full_width=False, size="sm")
934
+ things_that_might_timeout.append(lora_target_modules_add_btn.click(
935
+ handle_lora_target_modules_add,
936
+ inputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
937
+ outputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
938
+ ))
939
 
940
  with gr.Row():
941
  logging_steps = gr.Number(
 
960
  with gr.Column():
961
  model_name = gr.Textbox(
962
  lines=1, label="LoRA Model Name", value=random_name,
963
+ max_lines=1,
964
  info="The name of the new LoRA model.",
965
  elem_id="finetune_model_name",
966
  )
 
980
  elem_id="finetune_confirm_stop_btn"
981
  )
982
 
983
+ things_that_might_timeout.append(reload_selections_button.click(
984
+ reload_selections,
985
+ inputs=[template, dataset_from_data_dir],
986
+ outputs=[template, dataset_from_data_dir, continue_from_model],
987
+ ))
988
+
989
+ for i in dataset_preview_inputs:
990
+ things_that_might_timeout.append(
991
+ i.change(
992
+ fn=refresh_preview,
993
+ inputs=dataset_preview_inputs,
994
+ outputs=[
995
+ finetune_dataset_preview,
996
+ finetune_dataset_preview_info_message,
997
+ dataset_from_text_message,
998
+ dataset_from_data_dir_message
999
+ ]
1000
+ ).then(
1001
+ fn=refresh_dataset_items_count,
1002
+ inputs=dataset_preview_inputs,
1003
+ outputs=[
1004
+ finetune_dataset_preview_info_message,
1005
+ dataset_from_text_message,
1006
+ dataset_from_data_dir_message,
1007
+ evaluate_data_count,
1008
+ ]
1009
+ ))
1010
+
1011
+ finetune_args = [
1012
+ max_seq_length,
1013
+ evaluate_data_count,
1014
+ micro_batch_size,
1015
+ gradient_accumulation_steps,
1016
+ epochs,
1017
+ learning_rate,
1018
+ train_on_inputs,
1019
+ lora_r,
1020
+ lora_alpha,
1021
+ lora_dropout,
1022
+ lora_target_modules,
1023
+ save_steps,
1024
+ save_total_limit,
1025
+ logging_steps,
1026
+ ]
1027
+
1028
+ things_that_might_timeout.append(
1029
+ load_params_from_model_btn.click(
1030
+ fn=handle_load_params_from_model,
1031
+ inputs=[continue_from_model] + finetune_args + [lora_target_module_choices],
1032
+ outputs=[load_params_from_model_message] + finetune_args + [lora_target_module_choices]
1033
+ )
1034
+ )
1035
+
1036
  train_output = gr.Text(
1037
  "Training results will be shown here.",
1038
  label="Train Output",
 
1040
 
1041
  train_progress = train_btn.click(
1042
  fn=do_train,
1043
+ inputs=(dataset_inputs + finetune_args + [
 
1044
  model_name,
1045
+ continue_from_model,
1046
+ continue_from_checkpoint,
 
1047
  ]),
1048
  outputs=train_output
1049
  )
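
A minimal sketch of how the new options above fit together at runtime, assuming a data dir of `/data`; the model, checkpoint, and dataset names below are illustrative, not values taken from this change:

```python
import os

# Hypothetical selections (illustrative names only).
data_dir = "/data"
continue_from_model = "my-lora"               # "Continue from Model" dropdown
continue_from_checkpoint = "checkpoint-600"   # "Checkpoint" dropdown
template = "alpaca"
dataset_from_data_dir = "stanford_alpaca_seed_tasks.jsonl"
model_name = "my-lora-v2"                     # name of the new LoRA model

# Checkpoint resolution, mirroring do_train above.
resume_from_checkpoint = os.path.join(data_dir, "lora_models", continue_from_model)
resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
# -> /data/lora_models/my-lora/checkpoint-600

# W&B run metadata, mirroring do_train above.
wandb_group = f"{template}/{dataset_from_data_dir}"
wandb_tags = [f"template:{template}", f"dataset:{dataset_from_data_dir}"]
wandb_run_name = model_name
```
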
llama_lora/ui/inference_ui.py CHANGED
@@ -8,12 +8,12 @@ from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
  from ..models import get_model, get_tokenizer, get_device
 
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
  get_info_of_available_lora_model)
15
  from ..utils.prompter import Prompter
16
- from ..utils.callbacks import Iteratorize, Stream
17
 
18
  device = get_device()
19
 
@@ -22,7 +22,7 @@ inference_output_lines = 12
22
 
23
 
24
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
- base_model_name = Global.default_base_model_name
26
 
27
  try:
28
  get_tokenizer(base_model_name)
@@ -48,7 +48,7 @@ def do_inference(
48
  show_raw=False,
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
- base_model_name = Global.default_base_model_name
52
 
53
  try:
54
  if Global.generation_force_stopped_at is not None:
@@ -103,8 +103,6 @@ def do_inference(
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
- inputs = tokenizer(prompt, return_tensors="pt")
107
- input_ids = inputs["input_ids"].to(device)
108
  generation_config = GenerationConfig(
109
  temperature=temperature,
110
  top_p=top_p,
@@ -113,103 +111,55 @@ def do_inference(
113
  num_beams=num_beams,
114
  )
115
 
116
- generate_params = {
117
- "input_ids": input_ids,
118
- "generation_config": generation_config,
119
- "return_dict_in_generate": True,
120
- "output_scores": True,
121
- "max_new_tokens": max_new_tokens,
122
- }
123
-
124
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
125
  if Global.should_stop_generating:
126
  return True
127
  return False
128
 
129
  Global.should_stop_generating = False
130
- generate_params.setdefault(
131
- "stopping_criteria", transformers.StoppingCriteriaList()
132
- )
133
- generate_params["stopping_criteria"].append(
134
- ui_generation_stopping_criteria
135
- )
136
-
137
- if stream_output:
138
- # Stream the reply 1 token at a time.
139
- # This is based on the trick of using 'stopping_criteria' to create an iterator,
140
- # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
141
-
142
- def generate_with_callback(callback=None, **kwargs):
143
- kwargs.setdefault(
144
- "stopping_criteria", transformers.StoppingCriteriaList()
145
- )
146
- kwargs["stopping_criteria"].append(
147
- Stream(callback_func=callback)
148
- )
149
- with torch.no_grad():
150
- model.generate(**kwargs)
151
-
152
- def generate_with_streaming(**kwargs):
153
- return Iteratorize(
154
- generate_with_callback, kwargs, callback=None
155
- )
156
 
157
- with generate_with_streaming(**generate_params) as generator:
158
- for output in generator:
159
- # new_tokens = len(output) - len(input_ids[0])
160
- decoded_output = tokenizer.decode(output)
161
-
162
- if output[-1] in [tokenizer.eos_token_id]:
163
- break
164
-
165
- raw_output = None
166
- if show_raw:
167
- raw_output = str(output)
168
- response = prompter.get_response(decoded_output)
169
 
170
- if Global.should_stop_generating:
171
- return
 
 
 
172
 
173
- yield (
174
- gr.Textbox.update(
175
- value=response, lines=inference_output_lines),
176
- raw_output)
177
-
178
- if Global.should_stop_generating:
179
- # If the user stops the generation, and then clicks the
180
- # generation button again, they may mysteriously landed
181
- # here, in the previous, should-be-stopped generation
182
- # function call, with the new generation function not be
183
- # called at all. To workaround this, we yield a message
184
- # and setting lines=1, and if the front-end JS detects
185
- # that lines has been set to 1 (rows="1" in HTML),
186
- # it will automatically click the generate button again
187
- # (gr.Textbox.update() does not support updating
188
- # elem_classes or elem_id).
189
- # [WORKAROUND-UI01]
190
- yield (
191
- gr.Textbox.update(
192
- value="Please retry", lines=1),
193
- None)
194
- return # early return for stream_output
195
-
196
- # Without streaming
197
- with torch.no_grad():
198
- generation_output = model.generate(**generate_params)
199
- s = generation_output.sequences[0]
200
- output = tokenizer.decode(s)
201
- raw_output = None
202
- if show_raw:
203
- raw_output = str(s)
204
-
205
- response = prompter.get_response(output)
206
- if Global.should_stop_generating:
207
- return
208
 
209
- yield (
210
- gr.Textbox.update(value=response, lines=inference_output_lines),
211
- raw_output)
 
212
 
213
  except Exception as e:
214
  raise gr.Error(e)
215
 
@@ -229,7 +179,7 @@ def reload_selections(current_lora_model, current_prompt_template):
229
  current_prompt_template = current_prompt_template or next(
230
  iter(available_template_names_with_none), None)
231
 
232
- default_lora_models = ["tloen/alpaca-lora-7b"]
233
  available_lora_models = default_lora_models + get_available_lora_model_names()
234
  available_lora_models = available_lora_models + ["None"]
235
 
@@ -255,8 +205,12 @@ def handle_prompt_template_change(prompt_template, lora_model):
255
  "", visible=False)
256
  lora_mode_info = get_info_of_available_lora_model(lora_model)
257
  if lora_mode_info and isinstance(lora_mode_info, dict):
 
258
  model_prompt_template = lora_mode_info.get("prompt_template")
259
- if model_prompt_template and model_prompt_template != prompt_template:
 
 
 
260
  model_prompt_template_message_update = gr.Markdown.update(
261
  f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
262
 
@@ -303,7 +257,7 @@ def inference_ui():
303
  lora_model = gr.Dropdown(
304
  label="LoRA Model",
305
  elem_id="inference_lora_model",
306
- value="tloen/alpaca-lora-7b",
307
  allow_custom_value=True,
308
  )
309
  prompt_template = gr.Dropdown(
@@ -433,6 +387,8 @@ def inference_ui():
433
  interactive=False,
434
  elem_id="inference_raw_output")
435
 
 
 
436
  show_raw_change_event = show_raw.change(
437
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
438
  inputs=[show_raw],
@@ -454,6 +410,14 @@ def inference_ui():
454
  variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
455
  things_that_might_timeout.append(prompt_template_change_event)
456
 
 
457
  lora_model_change_event = lora_model.change(
458
  fn=handle_lora_model_change,
459
  inputs=[lora_model, prompt_template],
@@ -510,7 +474,7 @@ def inference_ui():
510
 
511
  // Workaround default value not shown.
512
  document.querySelector('#inference_lora_model input').value =
513
- 'tloen/alpaca-lora-7b';
514
  }, 100);
515
 
516
  // Add tooltips
@@ -654,6 +618,30 @@ def inference_ui():
654
  }, 500);
655
  }, 0);
656
 
 
657
  // Debounced updating the prompt preview.
658
  setTimeout(function () {
659
  function debounce(func, wait) {
 
8
 
9
  from ..globals import Global
10
  from ..models import get_model, get_tokenizer, get_device
11
+ from ..lib.inference import generate
12
  from ..utils.data import (
13
  get_available_template_names,
14
  get_available_lora_model_names,
15
  get_info_of_available_lora_model)
16
  from ..utils.prompter import Prompter
 
17
 
18
  device = get_device()
19
 
 
22
 
23
 
24
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
+ base_model_name = Global.base_model_name
26
 
27
  try:
28
  get_tokenizer(base_model_name)
 
48
  show_raw=False,
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
+ base_model_name = Global.base_model_name
52
 
53
  try:
54
  if Global.generation_force_stopped_at is not None:
 
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
 
 
106
  generation_config = GenerationConfig(
107
  temperature=temperature,
108
  top_p=top_p,
 
111
  num_beams=num_beams,
112
  )
113
 
 
114
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
115
  if Global.should_stop_generating:
116
  return True
117
  return False
118
 
119
  Global.should_stop_generating = False
 
120
 
121
+ generation_args = {
122
+ 'model': model,
123
+ 'tokenizer': tokenizer,
124
+ 'prompt': prompt,
125
+ 'generation_config': generation_config,
126
+ 'max_new_tokens': max_new_tokens,
127
+ 'stopping_criteria': [ui_generation_stopping_criteria],
128
+ 'stream_output': stream_output
129
+ }
 
 
 
130
 
131
+ for (decoded_output, output) in generate(**generation_args):
132
+ raw_output_str = None
133
+ if show_raw:
134
+ raw_output_str = str(output)
135
+ response = prompter.get_response(decoded_output)
136
 
137
+ if Global.should_stop_generating:
138
+ return
 
139
 
140
+ yield (
141
+ gr.Textbox.update(
142
+ value=response, lines=inference_output_lines),
143
+ raw_output_str)
144
 
145
+ if Global.should_stop_generating:
146
+ # If the user stops the generation and then clicks the
147
+ # generation button again, they may mysteriously land
148
+ # here, in the previous, should-be-stopped generation
149
+ # function call, with the new generation function not
150
+ # being called at all. To work around this, we yield a
151
+ # message and set lines=1, and if the front-end JS detects
152
+ # that lines has been set to 1 (rows="1" in HTML),
153
+ # it will automatically click the generate button again
154
+ # (gr.Textbox.update() does not support updating
155
+ # elem_classes or elem_id).
156
+ # [WORKAROUND-UI01]
157
+ yield (
158
+ gr.Textbox.update(
159
+ value="Please retry", lines=1),
160
+ None)
161
+
162
+ return
163
  except Exception as e:
164
  raise gr.Error(e)
165
 
 
179
  current_prompt_template = current_prompt_template or next(
180
  iter(available_template_names_with_none), None)
181
 
182
+ default_lora_models = []
183
  available_lora_models = default_lora_models + get_available_lora_model_names()
184
  available_lora_models = available_lora_models + ["None"]
185
 
 
205
  "", visible=False)
206
  lora_mode_info = get_info_of_available_lora_model(lora_model)
207
  if lora_mode_info and isinstance(lora_mode_info, dict):
208
+ model_base_model = lora_mode_info.get("base_model")
209
  model_prompt_template = lora_mode_info.get("prompt_template")
210
+ if model_base_model and model_base_model != Global.base_model_name:
211
+ model_prompt_template_message_update = gr.Markdown.update(
212
+ f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.", visible=True)
213
+ elif model_prompt_template and model_prompt_template != prompt_template:
214
  model_prompt_template_message_update = gr.Markdown.update(
215
  f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
216
 
 
257
  lora_model = gr.Dropdown(
258
  label="LoRA Model",
259
  elem_id="inference_lora_model",
260
+ value="None",
261
  allow_custom_value=True,
262
  )
263
  prompt_template = gr.Dropdown(
 
387
  interactive=False,
388
  elem_id="inference_raw_output")
389
 
390
+ reload_selected_models_btn = gr.Button("", elem_id="inference_reload_selected_models_btn")
391
+
392
  show_raw_change_event = show_raw.change(
393
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
394
  inputs=[show_raw],
 
410
  variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
411
  things_that_might_timeout.append(prompt_template_change_event)
412
 
413
+ reload_selected_models_btn_event = reload_selected_models_btn.click(
414
+ fn=handle_prompt_template_change,
415
+ inputs=[prompt_template, lora_model],
416
+ outputs=[
417
+ model_prompt_template_message,
418
+ variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
419
+ things_that_might_timeout.append(reload_selected_models_btn_event)
420
+
421
  lora_model_change_event = lora_model.change(
422
  fn=handle_lora_model_change,
423
  inputs=[lora_model, prompt_template],
 
474
 
475
  // Workaround default value not shown.
476
  document.querySelector('#inference_lora_model input').value =
477
+ 'None';
478
  }, 100);
479
 
480
  // Add tooltips
 
618
  }, 500);
619
  }, 0);
620
 
621
+ // Reload model selection on possible base model change.
622
+ setTimeout(function () {
623
+ const elem = document.getElementById('main_page_tabs_container');
624
+ if (!elem) return;
625
+
626
+ let prevClassList = [];
627
+
628
+ new MutationObserver(function (mutationsList, observer) {
629
+ const currentPrevClassList = prevClassList;
630
+ const currentClassList = Array.from(elem.classList);
631
+ prevClassList = Array.from(elem.classList);
632
+
633
+ if (!currentPrevClassList.includes('hide')) return;
634
+ if (currentClassList.includes('hide')) return;
635
+
636
+ const inference_reload_selected_models_btn_elem = document.getElementById('inference_reload_selected_models_btn');
637
+
638
+ if (inference_reload_selected_models_btn_elem) inference_reload_selected_models_btn_elem.click();
639
+ }).observe(elem, {
640
+ attributes: true,
641
+ attributeFilter: ['class'],
642
+ });
643
+ }, 0);
644
+
645
  // Debounced updating the prompt preview.
646
  setTimeout(function () {
647
  function debounce(func, wait) {
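
The streaming logic removed from `do_inference` now lives behind `generate` in `llama_lora/lib/inference.py`, which is not part of this diff. The sketch below is therefore only an interface guess inferred from the call site: a generator that yields `(decoded_output, output_token_ids)` pairs and honors the supplied stopping criteria.

```python
# Sketch only - inferred from how do_inference calls generate(); the real
# implementation is not shown in this diff.
import torch
import transformers


def generate(model, tokenizer, prompt, generation_config,
             max_new_tokens, stopping_criteria=None, stream_output=False):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    criteria = transformers.StoppingCriteriaList(stopping_criteria or [])

    if not stream_output:
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                max_new_tokens=max_new_tokens,
                stopping_criteria=criteria,
            )[0]
        yield tokenizer.decode(output), output
        return

    # The streaming path would yield partial (decoded_output, output) pairs as
    # tokens arrive, e.g. via the Iteratorize/Stream callback trick that this
    # file previously implemented inline.
    raise NotImplementedError("streaming is left out of this sketch")
```
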
llama_lora/ui/main_page.py CHANGED
@@ -17,25 +17,50 @@ def main_page():
17
  css=main_page_custom_css(),
18
  ) as main_page_blocks:
19
  with gr.Column(elem_id="main_page_content"):
20
- gr.Markdown(f"""
21
- <h1 class="app_title_text">{title}</h1> <wbr />
22
- <h2 class="app_subtitle_text">{Global.ui_subtitle}</h2>
23
- """)
24
- with gr.Tab("Inference"):
25
- inference_ui()
26
- with gr.Tab("Fine-tuning"):
27
- finetune_ui()
28
- with gr.Tab("Tokenizer"):
29
- tokenizer_ui()
30
- info = []
31
- if Global.version:
32
- info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
33
- info.append(f"Base model: `{Global.default_base_model_name}`")
34
- if Global.ui_show_sys_info:
35
- info.append(f"Data dir: `{Global.data_dir}`")
36
- gr.Markdown(f"""
37
- <small>{"&nbsp;&nbsp;·&nbsp;&nbsp;".join(info)}</small>
38
- """)
 
39
  main_page_blocks.load(_js=f"""
40
  function () {{
41
  {popperjs_core_code()}
@@ -61,6 +86,17 @@ def main_page():
61
  });
62
  handle_gradio_container_element_class_change();
63
  }, 500);
 
64
  }
65
  """)
66
 
@@ -127,12 +163,81 @@ def main_page_custom_css():
127
  display: none;
128
  }
129
 
 
130
  #main_page_content > .tabs > .tab-nav * {
131
  font-size: 1rem;
132
  font-weight: 700;
133
  /* text-transform: uppercase; */
134
  }
135
 
 
136
  #inference_lora_model_group {
137
  border-radius: var(--block-radius);
138
  background: var(--block-background-fill);
@@ -147,7 +252,8 @@ def main_page_custom_css():
147
  position: absolute;
148
  bottom: 8px;
149
  left: 20px;
150
- z-index: 1;
 
151
  font-size: 12px;
152
  opacity: 0.7;
153
  }
@@ -413,6 +519,24 @@ def main_page_custom_css():
413
  margin: -32px -16px;
414
  }
415
 
 
416
  .finetune_dataset_error_message {
417
  color: var(--error-text-color) !important;
418
  }
@@ -428,10 +552,43 @@ def main_page_custom_css():
428
  white-space: pre-wrap;
429
  }
430
 
 
431
  #finetune_max_seq_length {
432
  flex: 2;
433
  }
434
 
 
435
  #finetune_save_total_limit,
436
  #finetune_save_steps,
437
  #finetune_logging_steps {
@@ -503,3 +660,28 @@ def main_page_custom_css():
503
  .tippy-box[data-animation=scale-subtle][data-placement^=top]{transform-origin:bottom}.tippy-box[data-animation=scale-subtle][data-placement^=bottom]{transform-origin:top}.tippy-box[data-animation=scale-subtle][data-placement^=left]{transform-origin:right}.tippy-box[data-animation=scale-subtle][data-placement^=right]{transform-origin:left}.tippy-box[data-animation=scale-subtle][data-state=hidden]{transform:scale(.8);opacity:0}
504
  """
505
  return css
 
17
  css=main_page_custom_css(),
18
  ) as main_page_blocks:
19
  with gr.Column(elem_id="main_page_content"):
20
+ with gr.Row():
21
+ gr.Markdown(
22
+ f"""
23
+ <h1 class="app_title_text">{title}</h1> <wbr />
24
+ <h2 class="app_subtitle_text">{Global.ui_subtitle}</h2>
25
+ """,
26
+ elem_id="page_title",
27
+ )
28
+ global_base_model_select = gr.Dropdown(
29
+ label="Base Model",
30
+ elem_id="global_base_model_select",
31
+ choices=Global.base_model_choices,
32
+ value=lambda: Global.base_model_name,
33
+ allow_custom_value=True,
34
+ )
35
+ # global_base_model_select_loading_status = gr.Markdown("", elem_id="global_base_model_select_loading_status")
36
+
37
+ with gr.Column(elem_id="main_page_tabs_container") as main_page_tabs_container:
38
+ with gr.Tab("Inference"):
39
+ inference_ui()
40
+ with gr.Tab("Fine-tuning"):
41
+ finetune_ui()
42
+ with gr.Tab("Tokenizer"):
43
+ tokenizer_ui()
44
+ please_select_a_base_model_message = gr.Markdown("Please select a base model.", visible=False)
45
+ current_base_model_hint = gr.Markdown(lambda: Global.base_model_name, elem_id="current_base_model_hint")
46
+ foot_info = gr.Markdown(get_foot_info)
47
+
48
+ global_base_model_select.change(
49
+ fn=pre_handle_change_base_model,
50
+ inputs=[],
51
+ outputs=[main_page_tabs_container]
52
+ ).then(
53
+ fn=handle_change_base_model,
54
+ inputs=[global_base_model_select],
55
+ outputs=[
56
+ main_page_tabs_container,
57
+ please_select_a_base_model_message,
58
+ current_base_model_hint,
59
+ # global_base_model_select_loading_status,
60
+ foot_info
61
+ ]
62
+ )
63
+
64
  main_page_blocks.load(_js=f"""
65
  function () {{
66
  {popperjs_core_code()}
 
86
  });
87
  handle_gradio_container_element_class_change();
88
  }, 500);
89
+ """ + """
90
+ setTimeout(function () {
91
+ // Workaround default value not shown.
92
+ const current_base_model_hint_elem = document.querySelector('#current_base_model_hint > p');
93
+ if (!current_base_model_hint_elem) return;
94
+
95
+ const base_model_name = current_base_model_hint_elem.innerText;
96
+ document.querySelector('#global_base_model_select input').value = base_model_name;
97
+ document.querySelector('#global_base_model_select').classList.add('show');
98
+ }, 3200);
99
+ """ + """
100
  }
101
  """)
102
 
 
163
  display: none;
164
  }
165
 
166
+ #page_title {
167
+ flex-grow: 3;
168
+ }
169
+ #global_base_model_select {
170
+ position: relative;
171
+ align-self: center;
172
+ min-width: 250px;
173
+ padding: 2px 2px;
174
+ border: 0;
175
+ box-shadow: none;
176
+ opacity: 0;
177
+ pointer-events: none;
178
+ }
179
+ #global_base_model_select.show {
180
+ opacity: 1;
181
+ pointer-events: auto;
182
+ }
183
+ #global_base_model_select label .wrap-inner {
184
+ padding: 2px 8px;
185
+ }
186
+ #global_base_model_select label span {
187
+ margin-bottom: 2px;
188
+ font-size: 80%;
189
+ position: absolute;
190
+ top: -14px;
191
+ left: 8px;
192
+ opacity: 0;
193
+ }
194
+ #global_base_model_select:hover label span {
195
+ opacity: 1;
196
+ }
197
+
198
+ #global_base_model_select_loading_status {
199
+ position: absolute;
200
+ pointer-events: none;
201
+ top: 0;
202
+ left: 0;
203
+ right: 0;
204
+ bottom: 0;
205
+ }
206
+ #global_base_model_select_loading_status > .wrap:not(.hide) {
207
+ z-index: 9999;
208
+ position: absolute;
209
+ top: 112px !important;
210
+ bottom: 0 !important;
211
+ max-height: none;
212
+ background: var(--background-fill-primary);
213
+ opacity: 0.8;
214
+ }
215
+ #global_base_model_select ul {
216
+ z-index: 9999;
217
+ background: var(--block-background-fill);
218
+ }
219
+
220
+ #current_base_model_hint {
221
+ display: none;
222
+ }
223
+
224
  #main_page_content > .tabs > .tab-nav * {
225
  font-size: 1rem;
226
  font-weight: 700;
227
  /* text-transform: uppercase; */
228
  }
229
 
230
+ #inference_reload_selected_models_btn {
231
+ position: absolute;
232
+ top: 0;
233
+ left: 0;
234
+ width: 0;
235
+ height: 0;
236
+ padding: 0;
237
+ opacity: 0;
238
+ pointer-events: none;
239
+ }
240
+
241
  #inference_lora_model_group {
242
  border-radius: var(--block-radius);
243
  background: var(--block-background-fill);
 
252
  position: absolute;
253
  bottom: 8px;
254
  left: 20px;
255
+ z-index: 61;
256
+ width: 999px;
257
  font-size: 12px;
258
  opacity: 0.7;
259
  }
 
519
  margin: -32px -16px;
520
  }
521
 
522
+ #finetune_continue_from_model_box {
523
+ /* padding: 0; */
524
+ }
525
+ #finetune_continue_from_model_box .block {
526
+ border: 0;
527
+ box-shadow: none;
528
+ padding: 0;
529
+ }
530
+ #finetune_continue_from_model_box > * {
531
+ /* gap: 0; */
532
+ }
533
+ #finetune_continue_from_model_box button {
534
+ margin-top: 16px;
535
+ }
536
+ #finetune_continue_from_model {
537
+ flex-grow: 2;
538
+ }
539
+
540
  .finetune_dataset_error_message {
541
  color: var(--error-text-color) !important;
542
  }
 
552
  white-space: pre-wrap;
553
  }
554
 
555
+ /*
556
+ #finetune_dataset_preview {
557
+ max-height: 100vh;
558
+ overflow: auto;
559
+ border: var(--block-border-width) solid var(--border-color-primary);
560
+ border-radius: var(--radius-lg);
561
+ }
562
+ #finetune_dataset_preview .table-wrap {
563
+ border: 0 !important;
564
+ }
565
+ */
566
+
567
  #finetune_max_seq_length {
568
  flex: 2;
569
  }
570
 
571
+ #finetune_lora_target_modules_add_box {
572
+ margin-top: -24px;
573
+ padding-top: 8px;
574
+ border-top-left-radius: 0;
575
+ border-top-right-radius: 0;
576
+ border-top: 0;
577
+ }
578
+ #finetune_lora_target_modules_add_box > * > .form {
579
+ border: 0;
580
+ box-shadow: none;
581
+ }
582
+ #finetune_lora_target_modules_add {
583
+ padding: 0;
584
+ }
585
+ #finetune_lora_target_modules_add input {
586
+ padding: 4px 8px;
587
+ }
588
+ #finetune_lora_target_modules_add_btn {
589
+ min-width: 60px;
590
+ }
591
+
592
  #finetune_save_total_limit,
593
  #finetune_save_steps,
594
  #finetune_logging_steps {
 
660
  .tippy-box[data-animation=scale-subtle][data-placement^=top]{transform-origin:bottom}.tippy-box[data-animation=scale-subtle][data-placement^=bottom]{transform-origin:top}.tippy-box[data-animation=scale-subtle][data-placement^=left]{transform-origin:right}.tippy-box[data-animation=scale-subtle][data-placement^=right]{transform-origin:left}.tippy-box[data-animation=scale-subtle][data-state=hidden]{transform:scale(.8);opacity:0}
661
  """
662
  return css
663
+
664
+
665
+ def pre_handle_change_base_model():
666
+ return gr.Column.update(visible=False)
667
+
668
+
669
+ def handle_change_base_model(selected_base_model_name):
670
+ Global.base_model_name = selected_base_model_name
671
+
672
+ if Global.base_model_name:
673
+ return gr.Column.update(visible=True), gr.Markdown.update(visible=False), Global.base_model_name, get_foot_info()
674
+
675
+ return gr.Column.update(visible=False), gr.Markdown.update(visible=True), Global.base_model_name, get_foot_info()
676
+
677
+
678
+ def get_foot_info():
679
+ info = []
680
+ if Global.version:
681
+ info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
682
+ info.append(f"Base model: `{Global.base_model_name}`")
683
+ if Global.ui_show_sys_info:
684
+ info.append(f"Data dir: `{Global.data_dir}`")
685
+ return f"""\
686
+ <small>{"&nbsp;&nbsp;·&nbsp;&nbsp;".join(info)}</small>
687
+ """
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -7,7 +7,7 @@ from ..models import get_tokenizer
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
10
- base_model_name = Global.default_base_model_name
11
  try:
12
  encoded_tokens = json.loads(encoded_tokens_json)
13
  if Global.ui_dev_mode:
@@ -20,7 +20,7 @@ def handle_decode(encoded_tokens_json):
20
 
21
 
22
  def handle_encode(decoded_tokens):
23
- base_model_name = Global.default_base_model_name
24
  try:
25
  if Global.ui_dev_mode:
26
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
 
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
10
+ base_model_name = Global.base_model_name
11
  try:
12
  encoded_tokens = json.loads(encoded_tokens_json)
13
  if Global.ui_dev_mode:
 
20
 
21
 
22
  def handle_encode(decoded_tokens):
23
+ base_model_name = Global.base_model_name
24
  try:
25
  if Global.ui_dev_mode:
26
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
llama_lora/utils/data.py CHANGED
@@ -30,19 +30,22 @@ def copy_sample_data_if_not_exists(source, destination):
30
  def get_available_template_names():
31
  templates_directory_path = os.path.join(Global.data_dir, "templates")
32
  all_files = os.listdir(templates_directory_path)
33
- return [os.path.splitext(filename)[0] for filename in all_files if fnmatch.fnmatch(filename, "*.json")]
 
34
 
35
 
36
  def get_available_dataset_names():
37
  datasets_directory_path = os.path.join(Global.data_dir, "datasets")
38
  all_files = os.listdir(datasets_directory_path)
39
- return [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
 
40
 
41
 
42
  def get_available_lora_model_names():
43
- datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
44
- all_items = os.listdir(datasets_directory_path)
45
- return [item for item in all_items if os.path.isdir(os.path.join(datasets_directory_path, item))]
 
46
 
47
 
48
  def get_path_of_available_lora_model(name):
 
30
  def get_available_template_names():
31
  templates_directory_path = os.path.join(Global.data_dir, "templates")
32
  all_files = os.listdir(templates_directory_path)
33
+ names = [os.path.splitext(filename)[0] if fnmatch.fnmatch(filename, "*.json") else filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
34
+ return sorted(names)
35
 
36
 
37
  def get_available_dataset_names():
38
  datasets_directory_path = os.path.join(Global.data_dir, "datasets")
39
  all_files = os.listdir(datasets_directory_path)
40
+ names = [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
41
+ return sorted(names)
42
 
43
 
44
  def get_available_lora_model_names():
45
+ lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
46
+ all_items = os.listdir(lora_models_directory_path)
47
+ names = [item for item in all_items if os.path.isdir(os.path.join(lora_models_directory_path, item))]
48
+ return sorted(names)
49
 
50
 
51
  def get_path_of_available_lora_model(name):
llama_lora/utils/prompter.py CHANGED
@@ -5,13 +5,15 @@ From https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
5
 
6
  import json
7
  import os.path as osp
 
 
8
  from typing import Union, List
9
 
10
  from ..globals import Global
11
 
12
 
13
  class Prompter(object):
14
- __slots__ = ("template_name", "template", "_verbose")
15
 
16
  def __init__(self, template_name: str = "", verbose: bool = False):
17
  self._verbose = verbose
@@ -21,12 +23,41 @@ class Prompter(object):
21
  self.template_name = "None"
22
  return
23
  self.template_name = template_name
 
24
 
25
- file_name = osp.join(Global.data_dir, "templates",
26
- f"{template_name}.json")
27
- if not osp.exists(file_name):
28
- raise ValueError(f"Can't read {file_name}")
29
- with open(file_name) as fp:
 
30
  self.template = json.load(fp)
31
  if self._verbose:
32
  print(
@@ -47,23 +78,31 @@ class Prompter(object):
47
  res = variables.get("prompt", "")
48
  elif "variables" in self.template:
49
  variable_names = self.template.get("variables")
50
- if type(variables) == dict:
51
- variables = [variables.get(name, None)
52
- for name in variable_names]
53
- if "default" not in self.template:
54
- raise ValueError(
55
- f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
56
- default_prompt_name = self.template.get("default")
57
- if default_prompt_name not in self.template:
58
- raise ValueError(
59
- f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
60
- prompt_name = get_prompt_name(variables, variable_names)
61
- prompt_template = self.template.get(default_prompt_name)
62
- if prompt_name in self.template:
63
- prompt_template = self.template.get(prompt_name)
64
 
65
- res = prompt_template.format(
66
- **variables_to_dict(variables, variable_names))
 
67
 
68
  else:
69
  if type(variables) == dict:
@@ -92,18 +131,50 @@ class Prompter(object):
92
  def get_response(self, output: str) -> str:
93
  if self.template_name == "None":
94
  return output
 
 
 
 
 
95
  return self.template["response_split"].join(
96
- output.split(self.template["response_split"])[1:]
97
  ).strip()
98
 
99
  def get_variable_names(self) -> List[str]:
100
  if self.template_name == "None":
101
  return ["prompt"]
102
  elif "variables" in self.template:
103
- return self.template.get("variables")
104
  else:
105
  return ["instruction", "input"]
106
 
 
107
 
108
  def get_val(arr, index, default=None):
109
  return arr[index] if -len(arr) <= index < len(arr) else default
@@ -116,4 +187,62 @@ def get_prompt_name(variables, variable_names):
116
 
117
 
118
  def variables_to_dict(variables, variable_names):
119
- return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
 
5
 
6
  import json
7
  import os.path as osp
8
+ import importlib
9
+ import itertools
10
  from typing import Union, List
11
 
12
  from ..globals import Global
13
 
14
 
15
  class Prompter(object):
16
+ __slots__ = ("template_name", "template", "template_module", "_verbose")
17
 
18
  def __init__(self, template_name: str = "", verbose: bool = False):
19
  self._verbose = verbose
 
23
  self.template_name = "None"
24
  return
25
  self.template_name = template_name
26
+ self.template_module = None
27
 
28
+ base_filename, ext = osp.splitext(template_name)
29
+ if ext == "":
30
+ filename = base_filename + ".json"
31
+ else:
32
+ filename = base_filename + ext
33
+
34
+ file_path = osp.join(Global.data_dir, "templates", filename)
35
+
36
+ if not osp.exists(file_path):
37
+ raise ValueError(f"Can't read {file_path}")
38
+
39
+ if ext == ".py":
40
+ template_module_spec = importlib.util.spec_from_file_location(
41
+ "template_module", file_path)
42
+ template_module = importlib.util.module_from_spec(
43
+ template_module_spec)
44
+ template_module_spec.loader.exec_module(template_module)
45
+ self.template_module = template_module
46
+
47
+ if not hasattr(template_module, "variables"):
48
+ raise ValueError(
49
+ "The template module does not have a \"variables\" attribute.")
50
+
51
+ self.template = {
52
+ 'variables': template_module.variables
53
+ }
54
+
55
+ if hasattr(template_module, "response_split"):
56
+ self.template["response_split"] = template_module.response_split
57
+
58
+ return
59
+
60
+ with open(file_path) as fp:
61
  self.template = json.load(fp)
62
  if self._verbose:
63
  print(
 
78
  res = variables.get("prompt", "")
79
  elif "variables" in self.template:
80
  variable_names = self.template.get("variables")
81
+ if self.template_module:
82
+ if type(variables) == list:
83
+ variables = {k: v for k, v in zip(
84
+ variable_names, variables)}
 
85
 
86
+ res = self.template_module.get_prompt(variables)
87
+ else:
88
+ if type(variables) == dict:
89
+ variables = [variables.get(name, None)
90
+ for name in variable_names]
91
+
92
+ if "default" not in self.template:
93
+ raise ValueError(
94
+ f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
95
+ default_prompt_name = self.template.get("default")
96
+ if default_prompt_name not in self.template:
97
+ raise ValueError(
98
+ f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
99
+ prompt_name = get_prompt_name(variables, variable_names)
100
+ prompt_template = self.template.get(default_prompt_name)
101
+ if prompt_name in self.template:
102
+ prompt_template = self.template.get(prompt_name)
103
+
104
+ res = prompt_template.format(
105
+ **variables_to_dict(variables, variable_names))
106
 
107
  else:
108
  if type(variables) == dict:
 
131
  def get_response(self, output: str) -> str:
132
  if self.template_name == "None":
133
  return output
134
+
135
+ splitted_output = output.split(self.template["response_split"])
136
+ # if len(splitted_output) <= 1:
137
+ # return output.strip()
138
+
139
  return self.template["response_split"].join(
140
+ splitted_output[1:]
141
  ).strip()
142
 
143
  def get_variable_names(self) -> List[str]:
144
  if self.template_name == "None":
145
  return ["prompt"]
146
  elif "variables" in self.template:
147
+ return self.template['variables']
148
  else:
149
  return ["instruction", "input"]
150
 
151
+ def get_train_data_from_dataset(self, data, only_first_n_items=None):
152
+ if self.template_module:
153
+ if hasattr(self.template_module,
154
+ "get_train_data_list_from_dataset"):
155
+ data = self.template_module.get_train_data_list_from_dataset(
156
+ data)
157
+ if only_first_n_items:
158
+ data = data[:only_first_n_items]
159
+ return list(itertools.chain(*list(
160
+ map(self.template_module.get_train_data, data)
161
+ )))
162
+
163
+ if only_first_n_items:
164
+ data = data[:only_first_n_items]
165
+
166
+ data = process_json_dataset(data)
167
+
168
+ train_data = [
169
+ {
170
+ 'prompt': self.generate_prompt(d['variables']),
171
+ 'completion': d['output'],
172
+ **{"_var_" + k: v for k, v in d['variables'].items()}
173
+ }
174
+ for d in data]
175
+
176
+ return train_data
177
+
178
 
179
  def get_val(arr, index, default=None):
180
  return arr[index] if -len(arr) <= index < len(arr) else default
 
187
 
188
 
189
  def variables_to_dict(variables, variable_names):
190
+ return {
191
+ key: (variables[i] if i < len(variables)
192
+ and variables[i] is not None else '')
193
+ for i, key in enumerate(variable_names)
194
+ }
195
+
196
+
197
+ def process_json_dataset(data):
198
+ if not isinstance(data, list):
199
+ raise ValueError("The dataset is not an array of objects.")
200
+
201
+ first_item = get_val_from_arr(data, 0, None)
202
+
203
+ if first_item is None:
204
+ raise ValueError("The dataset is empty.")
205
+ if not isinstance(first_item, dict):
206
+ raise ValueError("The dataset is not an array of objects.")
207
+
208
+ # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
209
+ if "completion" in first_item and "output" not in first_item:
210
+ data = [
211
+ {"output" if k == "completion" else k: v for k, v in d.items()}
212
+ for d in data]
213
+ first_item = get_val_from_arr(data, 0, None)
214
+
215
+ # Flatten Stanford Alpaca style instances
216
+ if "instances" in first_item and isinstance(first_item["instances"], list):
217
+ data = [
218
+ {"output" if k == "completion" else k: v for k, v in d.items()}
219
+ for d in data]
220
+ flattened_data = []
221
+ for item in data:
222
+ for instance in item["instances"]:
223
+ d = {k: v for k, v in item.items() if k != "instances"}
224
+ d.update(instance)
225
+ flattened_data.append(d)
226
+ data = flattened_data
227
+ first_item = get_val_from_arr(data, 0, None)
228
+
229
+ if "output" not in first_item:
230
+ raise ValueError(
231
+ "The data does not contains an \"output\" or \"completion\".")
232
+
233
+ # Put all variables under the "variables" key if it does not exist
234
+ if "variables" not in first_item:
235
+ data = [
236
+ {
237
+ "variables":
238
+ {k: v for k, v in d.items() if k != "output"},
239
+ "output":
240
+ d["output"]
241
+ }
242
+ for d in data
243
+ ]
244
+ return data
245
+
246
+
247
+ def get_val_from_arr(arr, index, default=None):
248
+ return arr[index] if -len(arr) <= index < len(arr) else default
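
Since templates can now be Python modules as well as JSON files, here is a minimal sketch of what such a module could look like (file name and prompt wording are illustrative). It relies only on the hooks the updated `Prompter` actually reads: a required `variables` list, `get_prompt`, and the optional `response_split` / train-data hooks.

```python
# data/templates/my_template.py  (illustrative path under the data dir)

variables = ["instruction", "input"]

response_split = "### Response:"


def get_prompt(variables):
    # Prompter passes a dict keyed by the names declared in `variables`.
    instruction = variables.get("instruction") or ""
    inp = variables.get("input") or ""
    if inp:
        return (f"### Instruction:\n{instruction}\n\n"
                f"### Input:\n{inp}\n\n"
                f"{response_split}\n")
    return f"### Instruction:\n{instruction}\n\n{response_split}\n"


# Optional hooks used by Prompter.get_train_data_from_dataset: each dataset
# item maps to a *list* of training records, which are chained together.
def get_train_data_list_from_dataset(data):
    return data


def get_train_data(item):
    return [{
        "prompt": get_prompt(item),
        "completion": item.get("output", ""),
    }]
```
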
lora_models/alpaca-lora-7b/finetune_params.json ADDED
@@ -0,0 +1,19 @@
 
1
+ {
2
+ "num_train_epochs": 10,
3
+ "learning_rate": 0.0003,
4
+ "cutoff_len": 512,
5
+ "lora_r": 16,
6
+ "lora_alpha": 16,
7
+ "lora_dropout": 0.05,
8
+ "lora_target_modules": [
9
+ "q_proj",
10
+ "v_proj",
11
+ "k_proj",
12
+ "o_proj"
13
+ ],
14
+ "train_on_inputs": true,
15
+ "group_by_length": false,
16
+ "save_steps": 2000,
17
+ "save_total_limit": 10,
18
+ "logging_steps": 10
19
+ }
lora_models/alpaca-lora-7b/info.json ADDED
@@ -0,0 +1,6 @@
 
1
+ {
2
+ "hf_model_name": "tloen/alpaca-lora-7b",
3
+ "load_from_hf": true,
4
+ "base_model": "decapoda-research/llama-7b-hf",
5
+ "prompt_template": "alpaca"
6
+ }
lora_models/unhelpful-ai-v01/finetune_params.json ADDED
@@ -0,0 +1,19 @@
 
1
+ {
2
+ "num_train_epochs": 16,
3
+ "learning_rate": 0.0003,
4
+ "cutoff_len": 512,
5
+ "lora_r": 12,
6
+ "lora_alpha": 32,
7
+ "lora_dropout": 0.05,
8
+ "lora_target_modules": [
9
+ "q_proj",
10
+ "v_proj",
11
+ "k_proj",
12
+ "o_proj"
13
+ ],
14
+ "train_on_inputs": false,
15
+ "group_by_length": false,
16
+ "save_steps": 500,
17
+ "save_total_limit": 5,
18
+ "logging_steps": 10
19
+ }
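
As a reading aid for the two bundled `finetune_params.json` files above: these are the legacy argument names that `handle_load_params_from_model` maps back onto the fine-tuning UI fields. A sketch of that mapping, using the component names from `finetune_ui.py`:

```python
# Legacy keys on the left, UI component names on the right.
legacy_key_to_ui_field = {
    "num_train_epochs": "epochs",
    "cutoff_len": "max_seq_length",
    "learning_rate": "learning_rate",
    "lora_r": "lora_r",
    "lora_alpha": "lora_alpha",
    "lora_dropout": "lora_dropout",
    "lora_target_modules": "lora_target_modules",  # unknown modules are appended to the choices
    "train_on_inputs": "train_on_inputs",
    "save_steps": "save_steps",
    "save_total_limit": "save_total_limit",
    "logging_steps": "logging_steps",
    # "group_by_length" is read but intentionally ignored.
}
```
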