Spaces:
Runtime error
Runtime error
Merge branch 'main' into hf-ui-demo
Browse files- .gitignore +1 -0
- LLaMA_LoRA.ipynb +12 -3
- README.md +11 -3
- app.py +41 -2
- download_base_model.py +32 -0
- llama_lora/globals.py +7 -0
- llama_lora/lib/finetune.py +88 -25
- llama_lora/lib/get_device.py +15 -0
- llama_lora/lib/inference.py +86 -0
- llama_lora/{utils/callbacks.py → lib/streaming_generation_utils.py} +0 -0
- llama_lora/models.py +24 -24
- llama_lora/ui/finetune_ui.py +458 -215
- llama_lora/ui/inference_ui.py +83 -95
- llama_lora/ui/main_page.py +202 -20
- llama_lora/ui/tokenizer_ui.py +2 -2
- llama_lora/utils/data.py +8 -5
- llama_lora/utils/prompter.py +154 -25
- lora_models/alpaca-lora-7b/finetune_params.json +19 -0
- lora_models/alpaca-lora-7b/info.json +6 -0
- lora_models/unhelpful-ai-v01/finetune_params.json +19 -0
.gitignore
CHANGED
@@ -3,4 +3,5 @@ __pycache__/
|
|
3 |
/venv
|
4 |
.vscode
|
5 |
|
|
|
6 |
/data
|
|
|
3 |
/venv
|
4 |
.vscode
|
5 |
|
6 |
+
/wandb
|
7 |
/data
|
LLaMA_LoRA.ipynb
CHANGED
@@ -60,12 +60,20 @@
|
|
60 |
"# @title A small workaround { display-mode: \"form\" }\n",
|
61 |
"# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
|
62 |
"# @markdown The error will disappear on the next run.\n",
|
63 |
-
"!pip install Pillow==9.3.0\n",
|
|
|
64 |
"import PIL\n",
|
65 |
"major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
|
66 |
"version_float = major + minor / 10**len(str(minor))\n",
|
67 |
-
"print(version_float)\n",
|
68 |
"if version_float < 9.003:\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
" raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
|
70 |
],
|
71 |
"metadata": {
|
@@ -281,7 +289,8 @@
|
|
281 |
"\n",
|
282 |
"# Set Configs\n",
|
283 |
"from llama_lora.llama_lora.globals import Global\n",
|
284 |
-
"Global.default_base_model_name = base_model\n",
|
|
|
285 |
"data_dir_realpath = !realpath ./data\n",
|
286 |
"Global.data_dir = data_dir_realpath[0]\n",
|
287 |
"Global.load_8bit = True\n",
|
|
|
60 |
"# @title A small workaround { display-mode: \"form\" }\n",
|
61 |
"# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
|
62 |
"# @markdown The error will disappear on the next run.\n",
|
63 |
+
"!pip install Pillow==9.3.0 numpy==1.23.5\n",
|
64 |
+
"\n",
|
65 |
"import PIL\n",
|
66 |
"major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
|
67 |
"version_float = major + minor / 10**len(str(minor))\n",
|
68 |
+
"print('PIL', version_float)\n",
|
69 |
"if version_float < 9.003:\n",
|
70 |
+
" raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")\n",
|
71 |
+
"\n",
|
72 |
+
"import numpy\n",
|
73 |
+
"major, minor = map(float, numpy.__version__.split(\".\")[:2])\n",
|
74 |
+
"version_float = major + minor / 10**len(str(minor))\n",
|
75 |
+
"print('numpy', version_float)\n",
|
76 |
+
"if version_float < 1.0023:\n",
|
77 |
" raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
|
78 |
],
|
79 |
"metadata": {
|
|
|
289 |
"\n",
|
290 |
"# Set Configs\n",
|
291 |
"from llama_lora.llama_lora.globals import Global\n",
|
292 |
+
"Global.default_base_model_name = Global.base_model_name = base_model\n",
|
293 |
+
"Global.base_model_choices = [base_model]\n",
|
294 |
"data_dir_realpath = !realpath ./data\n",
|
295 |
"Global.data_dir = data_dir_realpath[0]\n",
|
296 |
"Global.load_8bit = True\n",
|
README.md
CHANGED
@@ -34,8 +34,8 @@ Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) e
|
|
34 |
|
35 |
* **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
|
36 |
* Loads and stores data in Google Drive.
|
37 |
-
* Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/
|
38 |
-
* Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/
|
39 |
* Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
|
40 |
* Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
|
41 |
* Use prompt templates to keep your dataset DRY.
|
@@ -51,6 +51,8 @@ There are various ways to run this app:
|
|
51 |
|
52 |
### Run On Google Colab
|
53 |
|
|
|
|
|
54 |
Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
|
55 |
|
56 |
You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
|
@@ -81,13 +83,14 @@ file_mounts:
|
|
81 |
setup: |
|
82 |
git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
|
83 |
cd llama_lora_tuner && pip install -r requirements.lock.txt
|
|
|
84 |
cd ..
|
85 |
echo 'Dependencies installed.'
|
86 |
|
87 |
# Start the app.
|
88 |
run: |
|
89 |
echo 'Starting...'
|
90 |
-
python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
|
91 |
```
|
92 |
|
93 |
Then launch a cluster to run the task:
|
@@ -135,6 +138,11 @@ For more options, see `python app.py --help`.
|
|
135 |
</details>
|
136 |
|
137 |
|
|
|
|
|
|
|
|
|
|
|
138 |
## Acknowledgements
|
139 |
|
140 |
* https://github.com/tloen/alpaca-lora
|
|
|
34 |
|
35 |
* **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
|
36 |
* Loads and stores data in Google Drive.
|
37 |
+
* Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/IoEMgouZ5xU"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231023326-f28c84e2-df74-4179-b0ac-c25c4e8ca001.gif" /></a>
|
38 |
+
* Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/IoEMgouZ5xU?t=60"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231026640-b5cf5c79-9fe9-430b-8d4e-7346eb9567ad.gif" /></a>
|
39 |
* Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
|
40 |
* Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
|
41 |
* Use prompt templates to keep your dataset DRY.
|
|
|
51 |
|
52 |
### Run On Google Colab
|
53 |
|
54 |
+
*See [video](https://youtu.be/lByYOMdy9h4) for step-by-step instructions.*
|
55 |
+
|
56 |
Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
|
57 |
|
58 |
You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
|
|
|
83 |
setup: |
|
84 |
git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
|
85 |
cd llama_lora_tuner && pip install -r requirements.lock.txt
|
86 |
+
pip install wandb
|
87 |
cd ..
|
88 |
echo 'Dependencies installed.'
|
89 |
|
90 |
# Start the app.
|
91 |
run: |
|
92 |
echo 'Starting...'
|
93 |
+
python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
|
94 |
```
|
95 |
|
96 |
Then launch a cluster to run the task:
|
|
|
138 |
</details>
|
139 |
|
140 |
|
141 |
+
## Usage
|
142 |
+
|
143 |
+
See [video on YouTube](https://youtu.be/IoEMgouZ5xU).
|
144 |
+
|
145 |
+
|
146 |
## Acknowledgements
|
147 |
|
148 |
* https://github.com/tloen/alpaca-lora
|
app.py
CHANGED
@@ -5,21 +5,41 @@ import fire
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from llama_lora.globals import Global
|
|
|
8 |
from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
|
9 |
from llama_lora.utils.data import init_data_dir
|
10 |
|
11 |
|
|
|
12 |
def main(
|
13 |
-
load_8bit: bool = False,
|
14 |
base_model: str = "",
|
15 |
data_dir: str = "",
|
|
|
16 |
# Allows to listen on all interfaces by providing '0.0.0.0'.
|
17 |
server_name: str = "127.0.0.1",
|
18 |
share: bool = False,
|
19 |
skip_loading_base_model: bool = False,
|
|
|
20 |
ui_show_sys_info: bool = True,
|
21 |
ui_dev_mode: bool = False,
|
|
|
|
|
22 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
|
24 |
data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
|
25 |
assert (
|
@@ -30,16 +50,35 @@ def main(
|
|
30 |
data_dir
|
31 |
), "Please specify a --data_dir, e.g. --data_dir='./data'"
|
32 |
|
33 |
-
Global.default_base_model_name = base_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
Global.data_dir = os.path.abspath(data_dir)
|
35 |
Global.load_8bit = load_8bit
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
Global.ui_dev_mode = ui_dev_mode
|
38 |
Global.ui_show_sys_info = ui_show_sys_info
|
39 |
|
40 |
os.makedirs(data_dir, exist_ok=True)
|
41 |
init_data_dir()
|
42 |
|
|
|
|
|
|
|
43 |
with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
|
44 |
main_page()
|
45 |
|
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from llama_lora.globals import Global
|
8 |
+
from llama_lora.models import prepare_base_model
|
9 |
from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
|
10 |
from llama_lora.utils.data import init_data_dir
|
11 |
|
12 |
|
13 |
+
|
14 |
def main(
|
|
|
15 |
base_model: str = "",
|
16 |
data_dir: str = "",
|
17 |
+
base_model_choices: str = "",
|
18 |
# Allows to listen on all interfaces by providing '0.0.0.0'.
|
19 |
server_name: str = "127.0.0.1",
|
20 |
share: bool = False,
|
21 |
skip_loading_base_model: bool = False,
|
22 |
+
load_8bit: bool = False,
|
23 |
ui_show_sys_info: bool = True,
|
24 |
ui_dev_mode: bool = False,
|
25 |
+
wandb_api_key: str = "",
|
26 |
+
wandb_project: str = "",
|
27 |
):
|
28 |
+
'''
|
29 |
+
Start the LLaMA-LoRA Tuner UI.
|
30 |
+
|
31 |
+
:param base_model: (required) The name of the default base model to use.
|
32 |
+
:param data_dir: (required) The path to the directory to store data.
|
33 |
+
|
34 |
+
:param base_model_choices: Base model selections to display on the UI, seperated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.
|
35 |
+
|
36 |
+
:param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
|
37 |
+
:param share: Create a public Gradio URL.
|
38 |
+
|
39 |
+
:param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
|
40 |
+
:param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
|
41 |
+
'''
|
42 |
+
|
43 |
base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
|
44 |
data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
|
45 |
assert (
|
|
|
50 |
data_dir
|
51 |
), "Please specify a --data_dir, e.g. --data_dir='./data'"
|
52 |
|
53 |
+
Global.default_base_model_name = Global.base_model_name = base_model
|
54 |
+
|
55 |
+
if base_model_choices:
|
56 |
+
base_model_choices = base_model_choices.split(',')
|
57 |
+
base_model_choices = [name.strip() for name in base_model_choices]
|
58 |
+
Global.base_model_choices = base_model_choices
|
59 |
+
|
60 |
+
if base_model not in Global.base_model_choices:
|
61 |
+
Global.base_model_choices = [base_model] + Global.base_model_choices
|
62 |
+
|
63 |
Global.data_dir = os.path.abspath(data_dir)
|
64 |
Global.load_8bit = load_8bit
|
65 |
|
66 |
+
if len(wandb_api_key) > 0:
|
67 |
+
Global.enable_wandb = True
|
68 |
+
Global.wandb_api_key = wandb_api_key
|
69 |
+
if len(wandb_project) > 0:
|
70 |
+
Global.enable_wandb = True
|
71 |
+
Global.wandb_project = wandb_project
|
72 |
+
|
73 |
Global.ui_dev_mode = ui_dev_mode
|
74 |
Global.ui_show_sys_info = ui_show_sys_info
|
75 |
|
76 |
os.makedirs(data_dir, exist_ok=True)
|
77 |
init_data_dir()
|
78 |
|
79 |
+
if (not skip_loading_base_model) and (not ui_dev_mode):
|
80 |
+
prepare_base_model(base_model)
|
81 |
+
|
82 |
with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
|
83 |
main_page()
|
84 |
|
download_base_model.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
|
3 |
+
from llama_lora.models import get_new_base_model, clear_cache
|
4 |
+
|
5 |
+
|
6 |
+
def main(
|
7 |
+
base_model_names: str = "",
|
8 |
+
):
|
9 |
+
'''
|
10 |
+
Download and cache base models form Hugging Face.
|
11 |
+
|
12 |
+
:param base_model_names: Names of the base model you want to download, seperated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.
|
13 |
+
'''
|
14 |
+
|
15 |
+
assert (
|
16 |
+
base_model_names
|
17 |
+
), "Please specify --base_model_names, e.g. --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'"
|
18 |
+
|
19 |
+
base_model_names = base_model_names.split(',')
|
20 |
+
base_model_names = [name.strip() for name in base_model_names]
|
21 |
+
|
22 |
+
print(f"Base models: {', '.join(base_model_names)}.")
|
23 |
+
|
24 |
+
for name in base_model_names:
|
25 |
+
print(f"Preparing {name}...")
|
26 |
+
get_new_base_model(name)
|
27 |
+
clear_cache()
|
28 |
+
|
29 |
+
print("Done.")
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
fire.Fire(main)
|
llama_lora/globals.py
CHANGED
@@ -17,6 +17,8 @@ class Global:
|
|
17 |
load_8bit: bool = False
|
18 |
|
19 |
default_base_model_name: str = ""
|
|
|
|
|
20 |
|
21 |
# Functions
|
22 |
train_fn: Any = train
|
@@ -40,6 +42,11 @@ class Global:
|
|
40 |
gpu_total_cores = None # GPU total cores
|
41 |
gpu_total_memory = None
|
42 |
|
|
|
|
|
|
|
|
|
|
|
43 |
# UI related
|
44 |
ui_title: str = "LLaMA-LoRA Tuner"
|
45 |
ui_emoji: str = "🦙🎛️"
|
|
|
17 |
load_8bit: bool = False
|
18 |
|
19 |
default_base_model_name: str = ""
|
20 |
+
base_model_name: str = ""
|
21 |
+
base_model_choices: List[str] = []
|
22 |
|
23 |
# Functions
|
24 |
train_fn: Any = train
|
|
|
42 |
gpu_total_cores = None # GPU total cores
|
43 |
gpu_total_memory = None
|
44 |
|
45 |
+
# WandB
|
46 |
+
enable_wandb = False
|
47 |
+
wandb_api_key = None
|
48 |
+
default_wandb_project = "llama-lora-tuner"
|
49 |
+
|
50 |
# UI related
|
51 |
ui_title: str = "LLaMA-LoRA Tuner"
|
52 |
ui_emoji: str = "🦙🎛️"
|
llama_lora/lib/finetune.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
from typing import Any, List
|
4 |
|
5 |
import json
|
@@ -32,7 +33,7 @@ def train(
|
|
32 |
num_train_epochs: int = 3,
|
33 |
learning_rate: float = 3e-4,
|
34 |
cutoff_len: int = 256,
|
35 |
-
val_set_size: int = 2000,
|
36 |
# lora hyperparams
|
37 |
lora_r: int = 8,
|
38 |
lora_alpha: int = 16,
|
@@ -45,13 +46,78 @@ def train(
|
|
45 |
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
46 |
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
47 |
# either training checkpoint or final adapter
|
48 |
-
resume_from_checkpoint
|
49 |
save_steps: int = 200,
|
50 |
save_total_limit: int = 3,
|
51 |
logging_steps: int = 10,
|
52 |
# logging
|
53 |
-
callbacks: List[Any] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
if os.path.exists(output_dir):
|
56 |
if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
|
57 |
raise ValueError(
|
@@ -138,6 +204,8 @@ def train(
|
|
138 |
|
139 |
# If train_dataset_data is a list, convert it to datasets.Dataset
|
140 |
if isinstance(train_dataset_data, list):
|
|
|
|
|
141 |
train_dataset_data = Dataset.from_list(train_dataset_data)
|
142 |
|
143 |
if resume_from_checkpoint:
|
@@ -158,7 +226,7 @@ def train(
|
|
158 |
adapters_weights = torch.load(checkpoint_name)
|
159 |
model = set_peft_model_state_dict(model, adapters_weights)
|
160 |
else:
|
161 |
-
|
162 |
|
163 |
# Be more transparent about the % of trainable params.
|
164 |
model.print_trainable_parameters()
|
@@ -197,15 +265,15 @@ def train(
|
|
197 |
optim="adamw_torch",
|
198 |
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
199 |
save_strategy="steps",
|
200 |
-
eval_steps=
|
201 |
save_steps=save_steps,
|
202 |
output_dir=output_dir,
|
203 |
save_total_limit=save_total_limit,
|
204 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
205 |
ddp_find_unused_parameters=False if ddp else None,
|
206 |
group_by_length=group_by_length,
|
207 |
-
|
208 |
-
|
209 |
),
|
210 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
211 |
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
@@ -217,24 +285,16 @@ def train(
|
|
217 |
os.makedirs(output_dir)
|
218 |
with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
|
219 |
json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
|
220 |
-
with open(os.path.join(output_dir, "
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
'lora_target_modules': lora_target_modules,
|
231 |
-
'train_on_inputs': train_on_inputs,
|
232 |
-
'group_by_length': group_by_length,
|
233 |
-
'save_steps': save_steps,
|
234 |
-
'save_total_limit': save_total_limit,
|
235 |
-
'logging_steps': logging_steps,
|
236 |
-
}
|
237 |
-
json.dump(finetune_params, finetune_params_json_file, indent=2)
|
238 |
|
239 |
model.config.use_cache = False
|
240 |
|
@@ -261,4 +321,7 @@ def train(
|
|
261 |
with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
|
262 |
json.dump(train_output, train_output_json_file, indent=2)
|
263 |
|
|
|
|
|
|
|
264 |
return train_output
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
import importlib
|
4 |
from typing import Any, List
|
5 |
|
6 |
import json
|
|
|
33 |
num_train_epochs: int = 3,
|
34 |
learning_rate: float = 3e-4,
|
35 |
cutoff_len: int = 256,
|
36 |
+
val_set_size: int = 2000,
|
37 |
# lora hyperparams
|
38 |
lora_r: int = 8,
|
39 |
lora_alpha: int = 16,
|
|
|
46 |
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
47 |
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
48 |
# either training checkpoint or final adapter
|
49 |
+
resume_from_checkpoint = None,
|
50 |
save_steps: int = 200,
|
51 |
save_total_limit: int = 3,
|
52 |
logging_steps: int = 10,
|
53 |
# logging
|
54 |
+
callbacks: List[Any] = [],
|
55 |
+
# wandb params
|
56 |
+
wandb_api_key = None,
|
57 |
+
wandb_project: str = "",
|
58 |
+
wandb_group = None,
|
59 |
+
wandb_run_name: str = "",
|
60 |
+
wandb_tags: List[str] = [],
|
61 |
+
wandb_watch: str = "false", # options: false | gradients | all
|
62 |
+
wandb_log_model: str = "true", # options: false | true
|
63 |
):
|
64 |
+
# for logging
|
65 |
+
finetune_args = {
|
66 |
+
'micro_batch_size': micro_batch_size,
|
67 |
+
'gradient_accumulation_steps': gradient_accumulation_steps,
|
68 |
+
'num_train_epochs': num_train_epochs,
|
69 |
+
'learning_rate': learning_rate,
|
70 |
+
'cutoff_len': cutoff_len,
|
71 |
+
'val_set_size': val_set_size,
|
72 |
+
'lora_r': lora_r,
|
73 |
+
'lora_alpha': lora_alpha,
|
74 |
+
'lora_dropout': lora_dropout,
|
75 |
+
'lora_target_modules': lora_target_modules,
|
76 |
+
'train_on_inputs': train_on_inputs,
|
77 |
+
'group_by_length': group_by_length,
|
78 |
+
'save_steps': save_steps,
|
79 |
+
'save_total_limit': save_total_limit,
|
80 |
+
'logging_steps': logging_steps,
|
81 |
+
}
|
82 |
+
if val_set_size and val_set_size > 0:
|
83 |
+
finetune_args['val_set_size'] = val_set_size
|
84 |
+
if resume_from_checkpoint:
|
85 |
+
finetune_args['resume_from_checkpoint'] = resume_from_checkpoint
|
86 |
+
|
87 |
+
wandb = None
|
88 |
+
if wandb_api_key:
|
89 |
+
os.environ["WANDB_API_KEY"] = wandb_api_key
|
90 |
+
|
91 |
+
# wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
|
92 |
+
# if wandb_project:
|
93 |
+
# os.environ["WANDB_PROJECT"] = wandb_project
|
94 |
+
# if wandb_run_name:
|
95 |
+
# os.environ["WANDB_RUN_NAME"] = wandb_run_name
|
96 |
+
|
97 |
+
if wandb_watch:
|
98 |
+
os.environ["WANDB_WATCH"] = wandb_watch
|
99 |
+
if wandb_log_model:
|
100 |
+
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
|
101 |
+
use_wandb = (wandb_project and len(wandb_project) > 0) or (
|
102 |
+
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
|
103 |
+
)
|
104 |
+
if use_wandb:
|
105 |
+
os.environ['WANDB_MODE'] = "online"
|
106 |
+
wandb = importlib.import_module("wandb")
|
107 |
+
wandb.init(
|
108 |
+
project=wandb_project,
|
109 |
+
resume="auto",
|
110 |
+
group=wandb_group,
|
111 |
+
name=wandb_run_name,
|
112 |
+
tags=wandb_tags,
|
113 |
+
reinit=True,
|
114 |
+
magic=True,
|
115 |
+
config={'finetune_args': finetune_args},
|
116 |
+
# id=None # used for resuming
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
os.environ['WANDB_MODE'] = "disabled"
|
120 |
+
|
121 |
if os.path.exists(output_dir):
|
122 |
if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
|
123 |
raise ValueError(
|
|
|
204 |
|
205 |
# If train_dataset_data is a list, convert it to datasets.Dataset
|
206 |
if isinstance(train_dataset_data, list):
|
207 |
+
with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
|
208 |
+
json.dump(list(train_dataset_data[:100]), file, indent=2)
|
209 |
train_dataset_data = Dataset.from_list(train_dataset_data)
|
210 |
|
211 |
if resume_from_checkpoint:
|
|
|
226 |
adapters_weights = torch.load(checkpoint_name)
|
227 |
model = set_peft_model_state_dict(model, adapters_weights)
|
228 |
else:
|
229 |
+
raise ValueError(f"Checkpoint {checkpoint_name} not found")
|
230 |
|
231 |
# Be more transparent about the % of trainable params.
|
232 |
model.print_trainable_parameters()
|
|
|
265 |
optim="adamw_torch",
|
266 |
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
267 |
save_strategy="steps",
|
268 |
+
eval_steps=save_steps if val_set_size > 0 else None,
|
269 |
save_steps=save_steps,
|
270 |
output_dir=output_dir,
|
271 |
save_total_limit=save_total_limit,
|
272 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
273 |
ddp_find_unused_parameters=False if ddp else None,
|
274 |
group_by_length=group_by_length,
|
275 |
+
report_to="wandb" if use_wandb else None,
|
276 |
+
run_name=wandb_run_name if use_wandb else None,
|
277 |
),
|
278 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
279 |
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
|
|
285 |
os.makedirs(output_dir)
|
286 |
with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
|
287 |
json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
|
288 |
+
with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
|
289 |
+
json.dump(finetune_args, finetune_args_json_file, indent=2)
|
290 |
+
|
291 |
+
# Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
|
292 |
+
# if train_data:
|
293 |
+
# with open(os.path.join(output_dir, "train_dataset_samples.json"), 'w') as file:
|
294 |
+
# json.dump(list(train_data[:100]), file, indent=2)
|
295 |
+
# if val_data:
|
296 |
+
# with open(os.path.join(output_dir, "eval_dataset_samples.json"), 'w') as file:
|
297 |
+
# json.dump(list(val_data[:100]), file, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
model.config.use_cache = False
|
300 |
|
|
|
321 |
with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
|
322 |
json.dump(train_output, train_output_json_file, indent=2)
|
323 |
|
324 |
+
if use_wandb and wandb:
|
325 |
+
wandb.finish()
|
326 |
+
|
327 |
return train_output
|
llama_lora/lib/get_device.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def get_device():
|
5 |
+
device ="cpu"
|
6 |
+
if torch.cuda.is_available():
|
7 |
+
device = "cuda"
|
8 |
+
|
9 |
+
try:
|
10 |
+
if torch.backends.mps.is_available():
|
11 |
+
device = "mps"
|
12 |
+
except: # noqa: E722
|
13 |
+
pass
|
14 |
+
|
15 |
+
return device
|
llama_lora/lib/inference.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
|
4 |
+
from .get_device import get_device
|
5 |
+
from .streaming_generation_utils import Iteratorize, Stream
|
6 |
+
|
7 |
+
def generate(
|
8 |
+
# model
|
9 |
+
model,
|
10 |
+
tokenizer,
|
11 |
+
# input
|
12 |
+
prompt,
|
13 |
+
generation_config,
|
14 |
+
max_new_tokens,
|
15 |
+
stopping_criteria=[],
|
16 |
+
# output options
|
17 |
+
stream_output=False
|
18 |
+
):
|
19 |
+
device = get_device()
|
20 |
+
|
21 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
22 |
+
input_ids = inputs["input_ids"].to(device)
|
23 |
+
generate_params = {
|
24 |
+
"input_ids": input_ids,
|
25 |
+
"generation_config": generation_config,
|
26 |
+
"return_dict_in_generate": True,
|
27 |
+
"output_scores": True,
|
28 |
+
"max_new_tokens": max_new_tokens,
|
29 |
+
"stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
|
30 |
+
}
|
31 |
+
|
32 |
+
skip_special_tokens = True
|
33 |
+
|
34 |
+
if '/dolly' in tokenizer.name_or_path:
|
35 |
+
# dolly has additional_special_tokens as ['### End', '### Instruction:', '### Response:'], skipping them will break the prompter's reply extraction.
|
36 |
+
skip_special_tokens = False
|
37 |
+
# Ensure generation stops once it generates "### End"
|
38 |
+
end_key_token_id = tokenizer.encode("### End")
|
39 |
+
end_key_token_id = end_key_token_id[0] # 50277
|
40 |
+
if isinstance(generate_params['generation_config'].eos_token_id, str):
|
41 |
+
generate_params['generation_config'].eos_token_id = [generate_params['generation_config'].eos_token_id]
|
42 |
+
elif not generate_params['generation_config'].eos_token_id:
|
43 |
+
generate_params['generation_config'].eos_token_id = []
|
44 |
+
generate_params['generation_config'].eos_token_id.append(end_key_token_id)
|
45 |
+
|
46 |
+
if stream_output:
|
47 |
+
# Stream the reply 1 token at a time.
|
48 |
+
# This is based on the trick of using 'stopping_criteria' to create an iterator,
|
49 |
+
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
|
50 |
+
generation_output = None
|
51 |
+
|
52 |
+
def generate_with_callback(callback=None, **kwargs):
|
53 |
+
nonlocal generation_output
|
54 |
+
kwargs["stopping_criteria"].insert(
|
55 |
+
0,
|
56 |
+
Stream(callback_func=callback)
|
57 |
+
)
|
58 |
+
with torch.no_grad():
|
59 |
+
generation_output = model.generate(**kwargs)
|
60 |
+
|
61 |
+
def generate_with_streaming(**kwargs):
|
62 |
+
return Iteratorize(
|
63 |
+
generate_with_callback, kwargs, callback=None
|
64 |
+
)
|
65 |
+
|
66 |
+
with generate_with_streaming(**generate_params) as generator:
|
67 |
+
for output in generator:
|
68 |
+
decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
|
69 |
+
yield decoded_output, output
|
70 |
+
if output[-1] in [tokenizer.eos_token_id]:
|
71 |
+
break
|
72 |
+
|
73 |
+
if generation_output:
|
74 |
+
output = generation_output.sequences[0]
|
75 |
+
decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
|
76 |
+
yield decoded_output, output
|
77 |
+
|
78 |
+
return # early return for stream_output
|
79 |
+
|
80 |
+
# Without streaming
|
81 |
+
with torch.no_grad():
|
82 |
+
generation_output = model.generate(**generate_params)
|
83 |
+
output = generation_output.sequences[0]
|
84 |
+
decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
|
85 |
+
yield decoded_output, output
|
86 |
+
return
|
llama_lora/{utils/callbacks.py → lib/streaming_generation_utils.py}
RENAMED
File without changes
|
llama_lora/models.py
CHANGED
@@ -2,25 +2,14 @@ import os
|
|
2 |
import sys
|
3 |
import gc
|
4 |
import json
|
|
|
5 |
|
6 |
import torch
|
7 |
-
from transformers import
|
8 |
from peft import PeftModel
|
9 |
|
10 |
from .globals import Global
|
11 |
-
|
12 |
-
|
13 |
-
def get_device():
|
14 |
-
if torch.cuda.is_available():
|
15 |
-
return "cuda"
|
16 |
-
else:
|
17 |
-
return "cpu"
|
18 |
-
|
19 |
-
try:
|
20 |
-
if torch.backends.mps.is_available():
|
21 |
-
return "mps"
|
22 |
-
except: # noqa: E722
|
23 |
-
pass
|
24 |
|
25 |
|
26 |
def get_new_base_model(base_model_name):
|
@@ -41,7 +30,7 @@ def get_new_base_model(base_model_name):
|
|
41 |
device = get_device()
|
42 |
|
43 |
if device == "cuda":
|
44 |
-
model =
|
45 |
base_model_name,
|
46 |
load_in_8bit=Global.load_8bit,
|
47 |
torch_dtype=torch.float16,
|
@@ -50,19 +39,22 @@ def get_new_base_model(base_model_name):
|
|
50 |
device_map={'': 0},
|
51 |
)
|
52 |
elif device == "mps":
|
53 |
-
model =
|
54 |
base_model_name,
|
55 |
device_map={"": device},
|
56 |
torch_dtype=torch.float16,
|
57 |
)
|
58 |
else:
|
59 |
-
model =
|
60 |
base_model_name, device_map={"": device}, low_cpu_mem_usage=True
|
61 |
)
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
|
67 |
return model
|
68 |
|
@@ -75,7 +67,14 @@ def get_tokenizer(base_model_name):
|
|
75 |
if loaded_tokenizer:
|
76 |
return loaded_tokenizer
|
77 |
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
Global.loaded_tokenizers.set(base_model_name, tokenizer)
|
80 |
|
81 |
return tokenizer
|
@@ -148,9 +147,10 @@ def get_model(
|
|
148 |
device_map={"": device},
|
149 |
)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
154 |
|
155 |
if not Global.load_8bit:
|
156 |
model.half() # seems to fix bugs for some users.
|
|
|
2 |
import sys
|
3 |
import gc
|
4 |
import json
|
5 |
+
import re
|
6 |
|
7 |
import torch
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
|
9 |
from peft import PeftModel
|
10 |
|
11 |
from .globals import Global
|
12 |
+
from .lib.get_device import get_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
def get_new_base_model(base_model_name):
|
|
|
30 |
device = get_device()
|
31 |
|
32 |
if device == "cuda":
|
33 |
+
model = AutoModelForCausalLM.from_pretrained(
|
34 |
base_model_name,
|
35 |
load_in_8bit=Global.load_8bit,
|
36 |
torch_dtype=torch.float16,
|
|
|
39 |
device_map={'': 0},
|
40 |
)
|
41 |
elif device == "mps":
|
42 |
+
model = AutoModelForCausalLM.from_pretrained(
|
43 |
base_model_name,
|
44 |
device_map={"": device},
|
45 |
torch_dtype=torch.float16,
|
46 |
)
|
47 |
else:
|
48 |
+
model = AutoModelForCausalLM.from_pretrained(
|
49 |
base_model_name, device_map={"": device}, low_cpu_mem_usage=True
|
50 |
)
|
51 |
|
52 |
+
tokenizer = get_tokenizer(base_model_name)
|
53 |
+
|
54 |
+
if re.match("[^/]+/llama", base_model_name):
|
55 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0
|
56 |
+
model.config.bos_token_id = tokenizer.bos_token_id = 1
|
57 |
+
model.config.eos_token_id = tokenizer.eos_token_id = 2
|
58 |
|
59 |
return model
|
60 |
|
|
|
67 |
if loaded_tokenizer:
|
68 |
return loaded_tokenizer
|
69 |
|
70 |
+
try:
|
71 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
72 |
+
except Exception as e:
|
73 |
+
if 'LLaMATokenizer' in str(e):
|
74 |
+
tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
|
75 |
+
else:
|
76 |
+
raise e
|
77 |
+
|
78 |
Global.loaded_tokenizers.set(base_model_name, tokenizer)
|
79 |
|
80 |
return tokenizer
|
|
|
147 |
device_map={"": device},
|
148 |
)
|
149 |
|
150 |
+
if re.match("[^/]+/llama", base_model_name):
|
151 |
+
model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
|
152 |
+
model.config.bos_token_id = 1
|
153 |
+
model.config.eos_token_id = 2
|
154 |
|
155 |
if not Global.load_8bit:
|
156 |
model.half() # seems to fix bugs for some users.
|
llama_lora/ui/finetune_ui.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
|
|
|
|
4 |
from datetime import datetime
|
5 |
import gradio as gr
|
6 |
import math
|
@@ -15,7 +17,8 @@ from ..models import (
|
|
15 |
from ..utils.data import (
|
16 |
get_available_template_names,
|
17 |
get_available_dataset_names,
|
18 |
-
get_dataset_content
|
|
|
19 |
)
|
20 |
from ..utils.prompter import Prompter
|
21 |
|
@@ -47,13 +50,16 @@ def reload_selections(current_template, current_dataset):
|
|
47 |
current_dataset = current_dataset or next(
|
48 |
iter(available_dataset_names), None)
|
49 |
|
|
|
|
|
50 |
return (
|
51 |
gr.Dropdown.update(
|
52 |
choices=available_template_names_with_none,
|
53 |
value=current_template),
|
54 |
gr.Dropdown.update(
|
55 |
choices=available_dataset_names,
|
56 |
-
value=current_dataset)
|
|
|
57 |
)
|
58 |
|
59 |
|
@@ -79,56 +85,47 @@ def load_sample_dataset_to_text_input(format):
|
|
79 |
return gr.Code.update(value=sample_plain_text_value)
|
80 |
|
81 |
|
82 |
-
def
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
data = [
|
124 |
-
{
|
125 |
-
"variables":
|
126 |
-
{k: v for k, v in d.items() if k != "output"},
|
127 |
-
"output":
|
128 |
-
d["output"]
|
129 |
-
}
|
130 |
-
for d in data
|
131 |
-
]
|
132 |
return data
|
133 |
|
134 |
|
@@ -141,72 +138,92 @@ def refresh_preview(
|
|
141 |
dataset_plain_text_input_variables_separator,
|
142 |
dataset_plain_text_input_and_output_separator,
|
143 |
dataset_plain_text_data_separator,
|
144 |
-
|
145 |
):
|
146 |
try:
|
147 |
-
max_preview_count = 100
|
148 |
prompter = Prompter(template)
|
149 |
variable_names = prompter.get_variable_names()
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
raise ValueError(
|
165 |
-
f"Error parsing JSON on line {line_number}: {e}")
|
166 |
-
|
167 |
-
data = process_json_dataset(data)
|
168 |
-
|
169 |
-
else: # Plain Text
|
170 |
-
data = parse_plain_text_input(
|
171 |
-
dataset_text,
|
172 |
-
(
|
173 |
-
dataset_plain_text_input_variables_separator or
|
174 |
-
default_dataset_plain_text_input_variables_separator
|
175 |
-
).replace("\\n", "\n"),
|
176 |
-
(
|
177 |
-
dataset_plain_text_input_and_output_separator or
|
178 |
-
default_dataset_plain_text_input_and_output_separator
|
179 |
-
).replace("\\n", "\n"),
|
180 |
-
(
|
181 |
-
dataset_plain_text_data_separator or
|
182 |
-
default_dataset_plain_text_data_separator
|
183 |
-
).replace("\\n", "\n"),
|
184 |
-
variable_names
|
185 |
-
)
|
186 |
|
187 |
-
|
188 |
-
data = get_dataset_content(dataset_from_data_dir)
|
189 |
-
data = process_json_dataset(data)
|
190 |
|
191 |
data_count = len(data)
|
192 |
-
|
|
|
193 |
preview_data = [
|
194 |
-
[item
|
195 |
-
for item in
|
196 |
]
|
197 |
|
198 |
-
if
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
208 |
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
if data_count > max_preview_count:
|
211 |
preview_info_message += f" Previewing the first {max_preview_count}."
|
212 |
|
@@ -215,11 +232,22 @@ def refresh_preview(
|
|
215 |
info_message = "This dataset contains " + info_message
|
216 |
update_message = gr.Markdown.update(info_message, visible=True)
|
217 |
|
218 |
-
return gr.
|
219 |
except Exception as e:
|
220 |
update_message = gr.Markdown.update(
|
221 |
f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
|
225 |
def parse_plain_text_input(
|
@@ -258,7 +286,7 @@ def do_train(
|
|
258 |
dataset_plain_text_data_separator,
|
259 |
# Training Options
|
260 |
max_seq_length,
|
261 |
-
|
262 |
micro_batch_size,
|
263 |
gradient_accumulation_steps,
|
264 |
epochs,
|
@@ -268,14 +296,27 @@ def do_train(
|
|
268 |
lora_alpha,
|
269 |
lora_dropout,
|
270 |
lora_target_modules,
|
271 |
-
model_name,
|
272 |
save_steps,
|
273 |
save_total_limit,
|
274 |
logging_steps,
|
|
|
|
|
|
|
275 |
progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
|
276 |
):
|
277 |
try:
|
278 |
-
base_model_name = Global.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
|
280 |
if os.path.exists(output_dir):
|
281 |
if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
|
@@ -288,56 +329,22 @@ def do_train(
|
|
288 |
unload_models() # Need RAM for training
|
289 |
|
290 |
prompter = Prompter(template)
|
291 |
-
variable_names = prompter.get_variable_names()
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
try:
|
304 |
-
data.append(json.loads(line))
|
305 |
-
except Exception as e:
|
306 |
-
raise ValueError(
|
307 |
-
f"Error parsing JSON on line {line_number}: {e}")
|
308 |
-
|
309 |
-
data = process_json_dataset(data)
|
310 |
-
|
311 |
-
else: # Plain Text
|
312 |
-
data = parse_plain_text_input(
|
313 |
-
dataset_text,
|
314 |
-
(
|
315 |
-
dataset_plain_text_input_variables_separator or
|
316 |
-
default_dataset_plain_text_input_variables_separator
|
317 |
-
).replace("\\n", "\n"),
|
318 |
-
(
|
319 |
-
dataset_plain_text_input_and_output_separator or
|
320 |
-
default_dataset_plain_text_input_and_output_separator
|
321 |
-
).replace("\\n", "\n"),
|
322 |
-
(
|
323 |
-
dataset_plain_text_data_separator or
|
324 |
-
default_dataset_plain_text_data_separator
|
325 |
-
).replace("\\n", "\n"),
|
326 |
-
variable_names
|
327 |
-
)
|
328 |
-
|
329 |
-
else: # Load dataset from data directory
|
330 |
-
data = get_dataset_content(dataset_from_data_dir)
|
331 |
-
data = process_json_dataset(data)
|
332 |
|
333 |
-
|
334 |
-
evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
|
335 |
|
336 |
-
|
337 |
-
{
|
338 |
-
'prompt': prompter.generate_prompt(d['variables']),
|
339 |
-
'completion': d['output']}
|
340 |
-
for d in data]
|
341 |
|
342 |
def get_progress_text(epoch, epochs, last_loss):
|
343 |
progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
|
@@ -380,6 +387,8 @@ Train options: {json.dumps({
|
|
380 |
'lora_dropout': lora_dropout,
|
381 |
'lora_target_modules': lora_target_modules,
|
382 |
'model_name': model_name,
|
|
|
|
|
383 |
}, indent=2)}
|
384 |
|
385 |
Train data (first 10):
|
@@ -390,7 +399,7 @@ Train data (first 10):
|
|
390 |
return message
|
391 |
|
392 |
if not should_training_progress_track_tqdm:
|
393 |
-
progress(0, desc="Preparing model for training...")
|
394 |
|
395 |
log_history = []
|
396 |
|
@@ -449,26 +458,37 @@ Train data (first 10):
|
|
449 |
'dataset_rows': len(train_data),
|
450 |
'timestamp': time.time(),
|
451 |
|
452 |
-
|
453 |
-
'
|
|
|
454 |
|
455 |
-
'micro_batch_size': micro_batch_size,
|
456 |
-
'gradient_accumulation_steps': gradient_accumulation_steps,
|
457 |
-
'epochs': epochs,
|
458 |
-
'learning_rate': learning_rate,
|
459 |
|
460 |
-
'
|
461 |
|
462 |
-
'lora_r': lora_r,
|
463 |
-
'lora_alpha': lora_alpha,
|
464 |
-
'lora_dropout': lora_dropout,
|
465 |
-
'lora_target_modules': lora_target_modules,
|
466 |
}
|
|
|
|
|
|
|
|
|
467 |
json.dump(info, info_json_file, indent=2)
|
468 |
|
469 |
if not should_training_progress_track_tqdm:
|
470 |
progress(0, desc="Train starting...")
|
471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
train_output = Global.train_fn(
|
473 |
base_model, # base_model
|
474 |
tokenizer, # tokenizer
|
@@ -487,11 +507,16 @@ Train data (first 10):
|
|
487 |
lora_target_modules, # lora_target_modules
|
488 |
train_on_inputs, # train_on_inputs
|
489 |
False, # group_by_length
|
490 |
-
|
491 |
save_steps, # save_steps
|
492 |
save_total_limit, # save_total_limit
|
493 |
logging_steps, # logging_steps
|
494 |
-
training_callbacks # callbacks
|
|
|
|
|
|
|
|
|
|
|
495 |
)
|
496 |
|
497 |
logs_str = "\n".join([json.dumps(log)
|
@@ -515,6 +540,146 @@ def do_abort_training():
|
|
515 |
Global.should_stop_training = True
|
516 |
|
517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
def finetune_ui():
|
519 |
things_that_might_timeout = []
|
520 |
|
@@ -606,9 +771,13 @@ def finetune_ui():
|
|
606 |
"Set the dataset in the \"Prepare\" tab, then preview it here.",
|
607 |
elem_id="finetune_dataset_preview_info_message"
|
608 |
)
|
609 |
-
|
610 |
-
label="
|
611 |
-
|
|
|
|
|
|
|
|
|
612 |
)
|
613 |
finetune_dataset_preview = gr.Dataframe(
|
614 |
wrap=True, elem_id="finetune_dataset_preview")
|
@@ -633,25 +802,7 @@ def finetune_ui():
|
|
633 |
dataset_plain_text_data_separator,
|
634 |
]
|
635 |
dataset_preview_inputs = dataset_inputs + \
|
636 |
-
[
|
637 |
-
for i in dataset_preview_inputs:
|
638 |
-
things_that_might_timeout.append(
|
639 |
-
i.change(
|
640 |
-
fn=refresh_preview,
|
641 |
-
inputs=dataset_preview_inputs,
|
642 |
-
outputs=[finetune_dataset_preview,
|
643 |
-
finetune_dataset_preview_info_message,
|
644 |
-
dataset_from_text_message,
|
645 |
-
dataset_from_data_dir_message
|
646 |
-
]
|
647 |
-
))
|
648 |
-
|
649 |
-
things_that_might_timeout.append(reload_selections_button.click(
|
650 |
-
reload_selections,
|
651 |
-
inputs=[template, dataset_from_data_dir],
|
652 |
-
outputs=[template, dataset_from_data_dir],
|
653 |
-
)
|
654 |
-
)
|
655 |
|
656 |
with gr.Row():
|
657 |
max_seq_length = gr.Slider(
|
@@ -704,12 +855,43 @@ def finetune_ui():
|
|
704 |
info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
|
705 |
)
|
706 |
|
707 |
-
|
708 |
-
minimum=0, maximum=
|
709 |
-
label="Evaluation Data
|
710 |
-
info="The
|
711 |
)
|
712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
713 |
with gr.Column():
|
714 |
lora_r = gr.Slider(
|
715 |
minimum=1, maximum=16, step=1, value=8,
|
@@ -729,12 +911,31 @@ def finetune_ui():
|
|
729 |
info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
|
730 |
)
|
731 |
|
|
|
|
|
732 |
lora_target_modules = gr.CheckboxGroup(
|
733 |
label="LoRA Target Modules",
|
734 |
-
choices=
|
735 |
value=["q_proj", "v_proj"],
|
736 |
-
info="Modules to replace with LoRA."
|
|
|
737 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
738 |
|
739 |
with gr.Row():
|
740 |
logging_steps = gr.Number(
|
@@ -759,6 +960,7 @@ def finetune_ui():
|
|
759 |
with gr.Column():
|
760 |
model_name = gr.Textbox(
|
761 |
lines=1, label="LoRA Model Name", value=random_name,
|
|
|
762 |
info="The name of the new LoRA model.",
|
763 |
elem_id="finetune_model_name",
|
764 |
)
|
@@ -778,6 +980,59 @@ def finetune_ui():
|
|
778 |
elem_id="finetune_confirm_stop_btn"
|
779 |
)
|
780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
train_output = gr.Text(
|
782 |
"Training results will be shown here.",
|
783 |
label="Train Output",
|
@@ -785,22 +1040,10 @@ def finetune_ui():
|
|
785 |
|
786 |
train_progress = train_btn.click(
|
787 |
fn=do_train,
|
788 |
-
inputs=(dataset_inputs + [
|
789 |
-
max_seq_length,
|
790 |
-
evaluate_data_percentage,
|
791 |
-
micro_batch_size,
|
792 |
-
gradient_accumulation_steps,
|
793 |
-
epochs,
|
794 |
-
learning_rate,
|
795 |
-
train_on_inputs,
|
796 |
-
lora_r,
|
797 |
-
lora_alpha,
|
798 |
-
lora_dropout,
|
799 |
-
lora_target_modules,
|
800 |
model_name,
|
801 |
-
|
802 |
-
|
803 |
-
logging_steps,
|
804 |
]),
|
805 |
outputs=train_output
|
806 |
)
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
4 |
+
import traceback
|
5 |
+
import re
|
6 |
from datetime import datetime
|
7 |
import gradio as gr
|
8 |
import math
|
|
|
17 |
from ..utils.data import (
|
18 |
get_available_template_names,
|
19 |
get_available_dataset_names,
|
20 |
+
get_dataset_content,
|
21 |
+
get_available_lora_model_names
|
22 |
)
|
23 |
from ..utils.prompter import Prompter
|
24 |
|
|
|
50 |
current_dataset = current_dataset or next(
|
51 |
iter(available_dataset_names), None)
|
52 |
|
53 |
+
available_lora_models = ["-"] + get_available_lora_model_names()
|
54 |
+
|
55 |
return (
|
56 |
gr.Dropdown.update(
|
57 |
choices=available_template_names_with_none,
|
58 |
value=current_template),
|
59 |
gr.Dropdown.update(
|
60 |
choices=available_dataset_names,
|
61 |
+
value=current_dataset),
|
62 |
+
gr.Dropdown.update(choices=available_lora_models)
|
63 |
)
|
64 |
|
65 |
|
|
|
85 |
return gr.Code.update(value=sample_plain_text_value)
|
86 |
|
87 |
|
88 |
+
def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
|
89 |
+
dataset_plain_text_input_variables_separator,
|
90 |
+
dataset_plain_text_input_and_output_separator,
|
91 |
+
dataset_plain_text_data_separator,
|
92 |
+
dataset_from_data_dir, prompter):
|
93 |
+
if load_dataset_from == "Text Input":
|
94 |
+
if dataset_text_format == "JSON":
|
95 |
+
data = json.loads(dataset_text)
|
96 |
+
|
97 |
+
elif dataset_text_format == "JSON Lines":
|
98 |
+
lines = dataset_text.split('\n')
|
99 |
+
data = []
|
100 |
+
for i, line in enumerate(lines):
|
101 |
+
line_number = i + 1
|
102 |
+
try:
|
103 |
+
data.append(json.loads(line))
|
104 |
+
except Exception as e:
|
105 |
+
raise ValueError(
|
106 |
+
f"Error parsing JSON on line {line_number}: {e}")
|
107 |
+
|
108 |
+
else: # Plain Text
|
109 |
+
data = parse_plain_text_input(
|
110 |
+
dataset_text,
|
111 |
+
(
|
112 |
+
dataset_plain_text_input_variables_separator or
|
113 |
+
default_dataset_plain_text_input_variables_separator
|
114 |
+
).replace("\\n", "\n"),
|
115 |
+
(
|
116 |
+
dataset_plain_text_input_and_output_separator or
|
117 |
+
default_dataset_plain_text_input_and_output_separator
|
118 |
+
).replace("\\n", "\n"),
|
119 |
+
(
|
120 |
+
dataset_plain_text_data_separator or
|
121 |
+
default_dataset_plain_text_data_separator
|
122 |
+
).replace("\\n", "\n"),
|
123 |
+
prompter.get_variable_names()
|
124 |
+
)
|
125 |
+
|
126 |
+
else: # Load dataset from data directory
|
127 |
+
data = get_dataset_content(dataset_from_data_dir)
|
128 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
return data
|
130 |
|
131 |
|
|
|
138 |
dataset_plain_text_input_variables_separator,
|
139 |
dataset_plain_text_input_and_output_separator,
|
140 |
dataset_plain_text_data_separator,
|
141 |
+
max_preview_count,
|
142 |
):
|
143 |
try:
|
|
|
144 |
prompter = Prompter(template)
|
145 |
variable_names = prompter.get_variable_names()
|
146 |
|
147 |
+
data = get_data_from_input(
|
148 |
+
load_dataset_from=load_dataset_from,
|
149 |
+
dataset_text=dataset_text,
|
150 |
+
dataset_text_format=dataset_text_format,
|
151 |
+
dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
|
152 |
+
dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
|
153 |
+
dataset_plain_text_data_separator=dataset_plain_text_data_separator,
|
154 |
+
dataset_from_data_dir=dataset_from_data_dir,
|
155 |
+
prompter=prompter
|
156 |
+
)
|
157 |
+
|
158 |
+
train_data = prompter.get_train_data_from_dataset(
|
159 |
+
data, max_preview_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
+
train_data = train_data[:max_preview_count]
|
|
|
|
|
162 |
|
163 |
data_count = len(data)
|
164 |
+
|
165 |
+
headers = ['Prompt', 'Completion']
|
166 |
preview_data = [
|
167 |
+
[item.get("prompt", ""), item.get("completion", "")]
|
168 |
+
for item in train_data
|
169 |
]
|
170 |
|
171 |
+
if not prompter.template_module:
|
172 |
+
variable_names = prompter.get_variable_names()
|
173 |
+
headers += [f"Variable: {variable_name}" for variable_name in variable_names]
|
174 |
+
variables = [
|
175 |
+
[item.get(f"_var_{name}", "") for name in variable_names]
|
176 |
+
for item in train_data
|
177 |
+
]
|
178 |
+
preview_data = [d + v for d, v in zip(preview_data, variables)]
|
179 |
+
|
180 |
+
preview_info_message = f"The dataset has about {data_count} item(s)."
|
181 |
+
if data_count > max_preview_count:
|
182 |
+
preview_info_message += f" Previewing the first {max_preview_count}."
|
183 |
+
|
184 |
+
info_message = f"about {data_count} item(s)."
|
185 |
+
if load_dataset_from == "Data Dir":
|
186 |
+
info_message = "This dataset contains about " + info_message
|
187 |
+
update_message = gr.Markdown.update(info_message, visible=True)
|
188 |
|
189 |
+
return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
|
190 |
+
except Exception as e:
|
191 |
+
update_message = gr.Markdown.update(
|
192 |
+
f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
|
193 |
+
return gr.Dataframe.update(value={'data': [], 'headers': []}), gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message
|
194 |
|
195 |
+
|
196 |
+
def refresh_dataset_items_count(
|
197 |
+
template,
|
198 |
+
load_dataset_from,
|
199 |
+
dataset_from_data_dir,
|
200 |
+
dataset_text,
|
201 |
+
dataset_text_format,
|
202 |
+
dataset_plain_text_input_variables_separator,
|
203 |
+
dataset_plain_text_input_and_output_separator,
|
204 |
+
dataset_plain_text_data_separator,
|
205 |
+
max_preview_count,
|
206 |
+
):
|
207 |
+
try:
|
208 |
+
prompter = Prompter(template)
|
209 |
+
variable_names = prompter.get_variable_names()
|
210 |
+
|
211 |
+
data = get_data_from_input(
|
212 |
+
load_dataset_from=load_dataset_from,
|
213 |
+
dataset_text=dataset_text,
|
214 |
+
dataset_text_format=dataset_text_format,
|
215 |
+
dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
|
216 |
+
dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
|
217 |
+
dataset_plain_text_data_separator=dataset_plain_text_data_separator,
|
218 |
+
dataset_from_data_dir=dataset_from_data_dir,
|
219 |
+
prompter=prompter
|
220 |
+
)
|
221 |
+
|
222 |
+
train_data = prompter.get_train_data_from_dataset(
|
223 |
+
data)
|
224 |
+
data_count = len(train_data)
|
225 |
+
|
226 |
+
preview_info_message = f"The dataset contains {data_count} item(s)."
|
227 |
if data_count > max_preview_count:
|
228 |
preview_info_message += f" Previewing the first {max_preview_count}."
|
229 |
|
|
|
232 |
info_message = "This dataset contains " + info_message
|
233 |
update_message = gr.Markdown.update(info_message, visible=True)
|
234 |
|
235 |
+
return gr.Markdown.update(preview_info_message), update_message, update_message, gr.Slider.update(maximum=math.floor(data_count / 2))
|
236 |
except Exception as e:
|
237 |
update_message = gr.Markdown.update(
|
238 |
f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
|
239 |
+
|
240 |
+
trace = traceback.format_exc()
|
241 |
+
traces = [s.strip() for s in re.split("\n * File ", trace)]
|
242 |
+
templates_path = os.path.join(Global.data_dir, "templates")
|
243 |
+
traces_to_show = [s for s in traces if os.path.join(
|
244 |
+
Global.data_dir, "templates") in s]
|
245 |
+
traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
|
246 |
+
if len(traces_to_show) > 0:
|
247 |
+
update_message = gr.Markdown.update(
|
248 |
+
f"<span class=\"finetune_dataset_error_message\">Error: {e} ({','.join(traces_to_show)}).</span>", visible=True)
|
249 |
+
|
250 |
+
return gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message, gr.Slider.update(maximum=1)
|
251 |
|
252 |
|
253 |
def parse_plain_text_input(
|
|
|
286 |
dataset_plain_text_data_separator,
|
287 |
# Training Options
|
288 |
max_seq_length,
|
289 |
+
evaluate_data_count,
|
290 |
micro_batch_size,
|
291 |
gradient_accumulation_steps,
|
292 |
epochs,
|
|
|
296 |
lora_alpha,
|
297 |
lora_dropout,
|
298 |
lora_target_modules,
|
|
|
299 |
save_steps,
|
300 |
save_total_limit,
|
301 |
logging_steps,
|
302 |
+
model_name,
|
303 |
+
continue_from_model,
|
304 |
+
continue_from_checkpoint,
|
305 |
progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
|
306 |
):
|
307 |
try:
|
308 |
+
base_model_name = Global.base_model_name
|
309 |
+
|
310 |
+
resume_from_checkpoint = None
|
311 |
+
if continue_from_model == "-" or continue_from_model == "None":
|
312 |
+
continue_from_model = None
|
313 |
+
if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
|
314 |
+
continue_from_checkpoint = None
|
315 |
+
if continue_from_model:
|
316 |
+
resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
|
317 |
+
if continue_from_checkpoint:
|
318 |
+
resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
|
319 |
+
|
320 |
output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
|
321 |
if os.path.exists(output_dir):
|
322 |
if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
|
|
|
329 |
unload_models() # Need RAM for training
|
330 |
|
331 |
prompter = Prompter(template)
|
332 |
+
# variable_names = prompter.get_variable_names()
|
333 |
+
|
334 |
+
data = get_data_from_input(
|
335 |
+
load_dataset_from=load_dataset_from,
|
336 |
+
dataset_text=dataset_text,
|
337 |
+
dataset_text_format=dataset_text_format,
|
338 |
+
dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
|
339 |
+
dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
|
340 |
+
dataset_plain_text_data_separator=dataset_plain_text_data_separator,
|
341 |
+
dataset_from_data_dir=dataset_from_data_dir,
|
342 |
+
prompter=prompter
|
343 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
+
train_data = prompter.get_train_data_from_dataset(data)
|
|
|
346 |
|
347 |
+
data_count = len(train_data)
|
|
|
|
|
|
|
|
|
348 |
|
349 |
def get_progress_text(epoch, epochs, last_loss):
|
350 |
progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
|
|
|
                'lora_dropout': lora_dropout,
                'lora_target_modules': lora_target_modules,
                'model_name': model_name,
+                'continue_from_model': continue_from_model,
+                'continue_from_checkpoint': continue_from_checkpoint,
            }, indent=2)}

Train data (first 10):
...
            return message

        if not should_training_progress_track_tqdm:
+            progress(0, desc=f"Preparing model {base_model_name} for training...")

        log_history = []

...
                'dataset_rows': len(train_data),
                'timestamp': time.time(),

+                # These will be saved in another JSON file by the train function
+                # 'max_seq_length': max_seq_length,
+                # 'train_on_inputs': train_on_inputs,

+                # 'micro_batch_size': micro_batch_size,
+                # 'gradient_accumulation_steps': gradient_accumulation_steps,
+                # 'epochs': epochs,
+                # 'learning_rate': learning_rate,

+                # 'evaluate_data_count': evaluate_data_count,

+                # 'lora_r': lora_r,
+                # 'lora_alpha': lora_alpha,
+                # 'lora_dropout': lora_dropout,
+                # 'lora_target_modules': lora_target_modules,
            }
+            if continue_from_model:
+                info['continued_from_model'] = continue_from_model
+            if continue_from_checkpoint:
+                info['continued_from_checkpoint'] = continue_from_checkpoint
            json.dump(info, info_json_file, indent=2)

        if not should_training_progress_track_tqdm:
            progress(0, desc="Train starting...")

+        wandb_group = template
+        wandb_tags = [f"template:{template}"]
+        if load_dataset_from == "Data Dir" and dataset_from_data_dir:
+            wandb_group += f"/{dataset_from_data_dir}"
+            wandb_tags.append(f"dataset:{dataset_from_data_dir}")
+
        train_output = Global.train_fn(
            base_model, # base_model
            tokenizer, # tokenizer
...
            lora_target_modules, # lora_target_modules
            train_on_inputs, # train_on_inputs
            False, # group_by_length
+            resume_from_checkpoint, # resume_from_checkpoint
            save_steps, # save_steps
            save_total_limit, # save_total_limit
            logging_steps, # logging_steps
+            training_callbacks, # callbacks
+            Global.wandb_api_key, # wandb_api_key
+            Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
+            wandb_group, # wandb_group
+            model_name, # wandb_run_name
+            wandb_tags # wandb_tags
        )

        logs_str = "\n".join([json.dumps(log)
...
    Global.should_stop_training = True

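For reference, the W&B grouping in the hunk above derives the run group and tags from the selected prompt template and data-dir dataset; a minimal standalone sketch with example values (the template and dataset names are illustrative, not taken from the diff):

# Illustrative only: how wandb_group / wandb_tags are composed above.
template = "alpaca"
dataset_from_data_dir = "stanford_alpaca_seed_tasks.jsonl"  # example value

wandb_group = template
wandb_tags = [f"template:{template}"]
if dataset_from_data_dir:
    wandb_group += f"/{dataset_from_data_dir}"
    wandb_tags.append(f"dataset:{dataset_from_data_dir}")

print(wandb_group)  # alpaca/stanford_alpaca_seed_tasks.jsonl
print(wandb_tags)   # ['template:alpaca', 'dataset:stanford_alpaca_seed_tasks.jsonl']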
+def handle_continue_from_model_change(model_name):
+    try:
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
+        lora_model_directory_path = os.path.join(
+            lora_models_directory_path, model_name)
+        all_files = os.listdir(lora_model_directory_path)
+        checkpoints = [
+            file for file in all_files if file.startswith("checkpoint-")]
+        checkpoints = ["-"] + checkpoints
+        can_load_params = "finetune_params.json" in all_files or "finetune_args.json" in all_files
+        return gr.Dropdown.update(choices=checkpoints, value="-"), gr.Button.update(visible=can_load_params), gr.Markdown.update(value="", visible=False)
+    except Exception:
+        pass
+    return gr.Dropdown.update(choices=["-"], value="-"), gr.Button.update(visible=False), gr.Markdown.update(value="", visible=False)
+
+
+def handle_load_params_from_model(
+    model_name,
+    max_seq_length,
+    evaluate_data_count,
+    micro_batch_size,
+    gradient_accumulation_steps,
+    epochs,
+    learning_rate,
+    train_on_inputs,
+    lora_r,
+    lora_alpha,
+    lora_dropout,
+    lora_target_modules,
+    save_steps,
+    save_total_limit,
+    logging_steps,
+    lora_target_module_choices,
+):
+    error_message = ""
+    notice_message = ""
+    unknown_keys = []
+    try:
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
+        lora_model_directory_path = os.path.join(
+            lora_models_directory_path, model_name)
+
+        data = {}
+        possible_files = ["finetune_params.json", "finetune_args.json"]
+        for file in possible_files:
+            try:
+                with open(os.path.join(lora_model_directory_path, file), "r") as f:
+                    data = json.load(f)
+            except FileNotFoundError:
+                pass
+
+        for key, value in data.items():
+            if key == "max_seq_length":
+                max_seq_length = value
+            if key == "cutoff_len":
+                cutoff_len = value
+            elif key == "evaluate_data_count":
+                evaluate_data_count = value
+            elif key == "val_set_size":
+                evaluate_data_count = value
+            elif key == "micro_batch_size":
+                micro_batch_size = value
+            elif key == "gradient_accumulation_steps":
+                gradient_accumulation_steps = value
+            elif key == "epochs":
+                epochs = value
+            elif key == "num_train_epochs":
+                epochs = value
+            elif key == "learning_rate":
+                learning_rate = value
+            elif key == "train_on_inputs":
+                train_on_inputs = value
+            elif key == "lora_r":
+                lora_r = value
+            elif key == "lora_alpha":
+                lora_alpha = value
+            elif key == "lora_dropout":
+                lora_dropout = value
+            elif key == "lora_target_modules":
+                lora_target_modules = value
+                for element in value:
+                    if element not in lora_target_module_choices:
+                        lora_target_module_choices.append(element)
+            elif key == "save_steps":
+                save_steps = value
+            elif key == "save_total_limit":
+                save_total_limit = value
+            elif key == "logging_steps":
+                logging_steps = value
+            elif key == "group_by_length":
+                pass
+            elif key == "resume_from_checkpoint":
+                pass
+            else:
+                unknown_keys.append(key)
+    except Exception as e:
+        error_message = str(e)
+
+    if len(unknown_keys) > 0:
+        notice_message = f"Note: cannot restore unknown arg: {', '.join([f'`{x}`' for x in unknown_keys])}"
+
+    message = ". ".join([x for x in [error_message, notice_message] if x])
+
+    has_message = False
+    if message:
+        message += "."
+        has_message = True
+
+    return (
+        gr.Markdown.update(value=message, visible=has_message),
+        max_seq_length,
+        evaluate_data_count,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        epochs,
+        learning_rate,
+        train_on_inputs,
+        lora_r,
+        lora_alpha,
+        lora_dropout,
+        gr.CheckboxGroup.update(value=lora_target_modules, choices=lora_target_module_choices),
+        save_steps,
+        save_total_limit,
+        logging_steps,
+        lora_target_module_choices,
+    )
+
+
+default_lora_target_module_choices = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
+
+def handle_lora_target_modules_add(choices, new_module, selected_modules):
+    choices.append(new_module)
+    selected_modules.append(new_module)
+
+    return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
+
+
def finetune_ui():
    things_that_might_timeout = []

...
771 |
"Set the dataset in the \"Prepare\" tab, then preview it here.",
|
772 |
elem_id="finetune_dataset_preview_info_message"
|
773 |
)
|
774 |
+
finetune_dataset_preview_count = gr.Number(
|
775 |
+
label="Preview items count",
|
776 |
+
value=10,
|
777 |
+
# minimum=1,
|
778 |
+
# maximum=100,
|
779 |
+
precision=0,
|
780 |
+
elem_id="finetune_dataset_preview_count"
|
781 |
)
|
782 |
finetune_dataset_preview = gr.Dataframe(
|
783 |
wrap=True, elem_id="finetune_dataset_preview")
|
|
|
802 |
dataset_plain_text_data_separator,
|
803 |
]
|
804 |
dataset_preview_inputs = dataset_inputs + \
|
805 |
+
[finetune_dataset_preview_count]
|
806 |
|
807 |
with gr.Row():
|
808 |
max_seq_length = gr.Slider(
|
|
|
855 |
info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
|
856 |
)
|
857 |
|
858 |
+
evaluate_data_count = gr.Slider(
|
859 |
+
minimum=0, maximum=1, step=1, value=0,
|
860 |
+
label="Evaluation Data Count",
|
861 |
+
info="The number of data to be used for evaluation. This amount of data will not be used for training and will be used to assess the performance of the model during the process."
|
862 |
)
|
863 |
|
864 |
+
with gr.Box(elem_id="finetune_continue_from_model_box"):
|
865 |
+
with gr.Row():
|
866 |
+
continue_from_model = gr.Dropdown(
|
867 |
+
value="-",
|
868 |
+
label="Continue from Model",
|
869 |
+
choices=["-"],
|
870 |
+
elem_id="finetune_continue_from_model"
|
871 |
+
)
|
872 |
+
continue_from_checkpoint = gr.Dropdown(
|
873 |
+
value="-", label="Checkpoint", choices=["-"])
|
874 |
+
with gr.Column():
|
875 |
+
load_params_from_model_btn = gr.Button(
|
876 |
+
"Load training parameters from selected model", visible=False)
|
877 |
+
load_params_from_model_btn.style(
|
878 |
+
full_width=False,
|
879 |
+
size="sm")
|
880 |
+
load_params_from_model_message = gr.Markdown(
|
881 |
+
"", visible=False)
|
882 |
+
|
883 |
+
things_that_might_timeout.append(
|
884 |
+
continue_from_model.change(
|
885 |
+
fn=handle_continue_from_model_change,
|
886 |
+
inputs=[continue_from_model],
|
887 |
+
outputs=[
|
888 |
+
continue_from_checkpoint,
|
889 |
+
load_params_from_model_btn,
|
890 |
+
load_params_from_model_message
|
891 |
+
]
|
892 |
+
)
|
893 |
+
)
|
894 |
+
|
895 |
with gr.Column():
|
896 |
lora_r = gr.Slider(
|
897 |
minimum=1, maximum=16, step=1, value=8,
|
|
|
911 |
info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
|
912 |
)
|
913 |
|
914 |
+
lora_target_module_choices = gr.State(value=default_lora_target_module_choices)
|
915 |
+
|
916 |
lora_target_modules = gr.CheckboxGroup(
|
917 |
label="LoRA Target Modules",
|
918 |
+
choices=default_lora_target_module_choices,
|
919 |
value=["q_proj", "v_proj"],
|
920 |
+
info="Modules to replace with LoRA.",
|
921 |
+
elem_id="finetune_lora_target_modules"
|
922 |
)
|
923 |
+
with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
|
924 |
+
with gr.Row():
|
925 |
+
lora_target_modules_add = gr.Textbox(
|
926 |
+
lines=1, max_lines=1, show_label=False,
|
927 |
+
elem_id="finetune_lora_target_modules_add"
|
928 |
+
)
|
929 |
+
lora_target_modules_add_btn = gr.Button(
|
930 |
+
"Add",
|
931 |
+
elem_id="finetune_lora_target_modules_add_btn"
|
932 |
+
)
|
933 |
+
lora_target_modules_add_btn.style(full_width=False, size="sm")
|
934 |
+
things_that_might_timeout.append(lora_target_modules_add_btn.click(
|
935 |
+
handle_lora_target_modules_add,
|
936 |
+
inputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
|
937 |
+
outputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
|
938 |
+
))
|
939 |
|
940 |
with gr.Row():
|
941 |
logging_steps = gr.Number(
|
|
|
960 |
with gr.Column():
|
961 |
model_name = gr.Textbox(
|
962 |
lines=1, label="LoRA Model Name", value=random_name,
|
963 |
+
max_lines=1,
|
964 |
info="The name of the new LoRA model.",
|
965 |
elem_id="finetune_model_name",
|
966 |
)
|
|
|
980 |
elem_id="finetune_confirm_stop_btn"
|
981 |
)
|
982 |
|
983 |
+
things_that_might_timeout.append(reload_selections_button.click(
|
984 |
+
reload_selections,
|
985 |
+
inputs=[template, dataset_from_data_dir],
|
986 |
+
outputs=[template, dataset_from_data_dir, continue_from_model],
|
987 |
+
))
|
988 |
+
|
989 |
+
for i in dataset_preview_inputs:
|
990 |
+
things_that_might_timeout.append(
|
991 |
+
i.change(
|
992 |
+
fn=refresh_preview,
|
993 |
+
inputs=dataset_preview_inputs,
|
994 |
+
outputs=[
|
995 |
+
finetune_dataset_preview,
|
996 |
+
finetune_dataset_preview_info_message,
|
997 |
+
dataset_from_text_message,
|
998 |
+
dataset_from_data_dir_message
|
999 |
+
]
|
1000 |
+
).then(
|
1001 |
+
fn=refresh_dataset_items_count,
|
1002 |
+
inputs=dataset_preview_inputs,
|
1003 |
+
outputs=[
|
1004 |
+
finetune_dataset_preview_info_message,
|
1005 |
+
dataset_from_text_message,
|
1006 |
+
dataset_from_data_dir_message,
|
1007 |
+
evaluate_data_count,
|
1008 |
+
]
|
1009 |
+
))
|
1010 |
+
|
1011 |
+
finetune_args = [
|
1012 |
+
max_seq_length,
|
1013 |
+
evaluate_data_count,
|
1014 |
+
micro_batch_size,
|
1015 |
+
gradient_accumulation_steps,
|
1016 |
+
epochs,
|
1017 |
+
learning_rate,
|
1018 |
+
train_on_inputs,
|
1019 |
+
lora_r,
|
1020 |
+
lora_alpha,
|
1021 |
+
lora_dropout,
|
1022 |
+
lora_target_modules,
|
1023 |
+
save_steps,
|
1024 |
+
save_total_limit,
|
1025 |
+
logging_steps,
|
1026 |
+
]
|
1027 |
+
|
1028 |
+
things_that_might_timeout.append(
|
1029 |
+
load_params_from_model_btn.click(
|
1030 |
+
fn=handle_load_params_from_model,
|
1031 |
+
inputs=[continue_from_model] + finetune_args + [lora_target_module_choices],
|
1032 |
+
outputs=[load_params_from_model_message] + finetune_args + [lora_target_module_choices]
|
1033 |
+
)
|
1034 |
+
)
|
1035 |
+
|
1036 |
train_output = gr.Text(
|
1037 |
"Training results will be shown here.",
|
1038 |
label="Train Output",
|
|
|
1040 |
|
1041 |
train_progress = train_btn.click(
|
1042 |
fn=do_train,
|
1043 |
+
inputs=(dataset_inputs + finetune_args + [
|
|
1044 |
model_name,
|
1045 |
+
continue_from_model,
|
1046 |
+
continue_from_checkpoint,
|
|
|
1047 |
]),
|
1048 |
outputs=train_output
|
1049 |
)
|
llama_lora/ui/inference_ui.py
CHANGED
@@ -8,12 +8,12 @@ from transformers import GenerationConfig

from ..globals import Global
from ..models import get_model, get_tokenizer, get_device
from ..utils.data import (
    get_available_template_names,
    get_available_lora_model_names,
    get_info_of_available_lora_model)
from ..utils.prompter import Prompter
-from ..utils.callbacks import Iteratorize, Stream

device = get_device()

@@ -22,7 +22,7 @@ inference_output_lines = 12
|
|
22 |
|
23 |
|
24 |
def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
|
25 |
-
base_model_name = Global.
|
26 |
|
27 |
try:
|
28 |
get_tokenizer(base_model_name)
|
@@ -48,7 +48,7 @@ def do_inference(
|
|
48 |
show_raw=False,
|
49 |
progress=gr.Progress(track_tqdm=True),
|
50 |
):
|
51 |
-
base_model_name = Global.
|
52 |
|
53 |
try:
|
54 |
if Global.generation_force_stopped_at is not None:
|
@@ -103,8 +103,6 @@ def do_inference(
|
|
103 |
tokenizer = get_tokenizer(base_model_name)
|
104 |
model = get_model(base_model_name, lora_model_name)
|
105 |
|
106 |
-
inputs = tokenizer(prompt, return_tensors="pt")
|
107 |
-
input_ids = inputs["input_ids"].to(device)
|
108 |
generation_config = GenerationConfig(
|
109 |
temperature=temperature,
|
110 |
top_p=top_p,
|
@@ -113,103 +111,55 @@ def do_inference(
|
|
113 |
num_beams=num_beams,
|
114 |
)
|
115 |
|
116 |
-
generate_params = {
|
117 |
-
"input_ids": input_ids,
|
118 |
-
"generation_config": generation_config,
|
119 |
-
"return_dict_in_generate": True,
|
120 |
-
"output_scores": True,
|
121 |
-
"max_new_tokens": max_new_tokens,
|
122 |
-
}
|
123 |
-
|
124 |
def ui_generation_stopping_criteria(input_ids, score, **kwargs):
|
125 |
if Global.should_stop_generating:
|
126 |
return True
|
127 |
return False
|
128 |
|
129 |
Global.should_stop_generating = False
|
130 |
-
generate_params.setdefault(
|
131 |
-
"stopping_criteria", transformers.StoppingCriteriaList()
|
132 |
-
)
|
133 |
-
generate_params["stopping_criteria"].append(
|
134 |
-
ui_generation_stopping_criteria
|
135 |
-
)
|
136 |
-
|
137 |
-
if stream_output:
|
138 |
-
# Stream the reply 1 token at a time.
|
139 |
-
# This is based on the trick of using 'stopping_criteria' to create an iterator,
|
140 |
-
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
|
141 |
-
|
142 |
-
def generate_with_callback(callback=None, **kwargs):
|
143 |
-
kwargs.setdefault(
|
144 |
-
"stopping_criteria", transformers.StoppingCriteriaList()
|
145 |
-
)
|
146 |
-
kwargs["stopping_criteria"].append(
|
147 |
-
Stream(callback_func=callback)
|
148 |
-
)
|
149 |
-
with torch.no_grad():
|
150 |
-
model.generate(**kwargs)
|
151 |
-
|
152 |
-
def generate_with_streaming(**kwargs):
|
153 |
-
return Iteratorize(
|
154 |
-
generate_with_callback, kwargs, callback=None
|
155 |
-
)
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
if show_raw:
|
167 |
-
raw_output = str(output)
|
168 |
-
response = prompter.get_response(decoded_output)
|
169 |
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
value=response, lines=inference_output_lines),
|
176 |
-
raw_output)
|
177 |
-
|
178 |
-
if Global.should_stop_generating:
|
179 |
-
# If the user stops the generation, and then clicks the
|
180 |
-
# generation button again, they may mysteriously landed
|
181 |
-
# here, in the previous, should-be-stopped generation
|
182 |
-
# function call, with the new generation function not be
|
183 |
-
# called at all. To workaround this, we yield a message
|
184 |
-
# and setting lines=1, and if the front-end JS detects
|
185 |
-
# that lines has been set to 1 (rows="1" in HTML),
|
186 |
-
# it will automatically click the generate button again
|
187 |
-
# (gr.Textbox.update() does not support updating
|
188 |
-
# elem_classes or elem_id).
|
189 |
-
# [WORKAROUND-UI01]
|
190 |
-
yield (
|
191 |
-
gr.Textbox.update(
|
192 |
-
value="Please retry", lines=1),
|
193 |
-
None)
|
194 |
-
return # early return for stream_output
|
195 |
-
|
196 |
-
# Without streaming
|
197 |
-
with torch.no_grad():
|
198 |
-
generation_output = model.generate(**generate_params)
|
199 |
-
s = generation_output.sequences[0]
|
200 |
-
output = tokenizer.decode(s)
|
201 |
-
raw_output = None
|
202 |
-
if show_raw:
|
203 |
-
raw_output = str(s)
|
204 |
-
|
205 |
-
response = prompter.get_response(output)
|
206 |
-
if Global.should_stop_generating:
|
207 |
-
return
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
212 |
|
213 |
except Exception as e:
|
214 |
raise gr.Error(e)
|
215 |
|
@@ -229,7 +179,7 @@ def reload_selections(current_lora_model, current_prompt_template):
|
|
229 |
current_prompt_template = current_prompt_template or next(
|
230 |
iter(available_template_names_with_none), None)
|
231 |
|
232 |
-
default_lora_models = [
|
233 |
available_lora_models = default_lora_models + get_available_lora_model_names()
|
234 |
available_lora_models = available_lora_models + ["None"]
|
235 |
|
@@ -255,8 +205,12 @@ def handle_prompt_template_change(prompt_template, lora_model):
|
|
255 |
"", visible=False)
|
256 |
lora_mode_info = get_info_of_available_lora_model(lora_model)
|
257 |
if lora_mode_info and isinstance(lora_mode_info, dict):
|
|
|
258 |
model_prompt_template = lora_mode_info.get("prompt_template")
|
259 |
-
if
|
|
|
|
|
|
|
260 |
model_prompt_template_message_update = gr.Markdown.update(
|
261 |
f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
|
262 |
|
@@ -303,7 +257,7 @@ def inference_ui():
|
|
303 |
lora_model = gr.Dropdown(
|
304 |
label="LoRA Model",
|
305 |
elem_id="inference_lora_model",
|
306 |
-
value="
|
307 |
allow_custom_value=True,
|
308 |
)
|
309 |
prompt_template = gr.Dropdown(
|
@@ -433,6 +387,8 @@ def inference_ui():
|
|
433 |
interactive=False,
|
434 |
elem_id="inference_raw_output")
|
435 |
|
|
|
|
|
436 |
show_raw_change_event = show_raw.change(
|
437 |
fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
|
438 |
inputs=[show_raw],
|
@@ -454,6 +410,14 @@ def inference_ui():
|
|
454 |
variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
|
455 |
things_that_might_timeout.append(prompt_template_change_event)
|
456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
lora_model_change_event = lora_model.change(
|
458 |
fn=handle_lora_model_change,
|
459 |
inputs=[lora_model, prompt_template],
|
@@ -510,7 +474,7 @@ def inference_ui():
|
|
510 |
|
511 |
// Workaround default value not shown.
|
512 |
document.querySelector('#inference_lora_model input').value =
|
513 |
-
'
|
514 |
}, 100);
|
515 |
|
516 |
// Add tooltips
|
@@ -654,6 +618,30 @@ def inference_ui():
|
|
654 |
}, 500);
|
655 |
}, 0);
|
656 |
|
|
657 |
// Debounced updating the prompt preview.
|
658 |
setTimeout(function () {
|
659 |
function debounce(func, wait) {
|
|
|
8 |
|
9 |
from ..globals import Global
|
10 |
from ..models import get_model, get_tokenizer, get_device
|
11 |
+
from ..lib.inference import generate
|
12 |
from ..utils.data import (
|
13 |
get_available_template_names,
|
14 |
get_available_lora_model_names,
|
15 |
get_info_of_available_lora_model)
|
16 |
from ..utils.prompter import Prompter
|
|
|
17 |
|
18 |
device = get_device()
|
19 |
|
|
|
22 |
|
23 |
|
24 |
def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
|
25 |
+
base_model_name = Global.base_model_name
|
26 |
|
27 |
try:
|
28 |
get_tokenizer(base_model_name)
|
|
|
48 |
show_raw=False,
|
49 |
progress=gr.Progress(track_tqdm=True),
|
50 |
):
|
51 |
+
base_model_name = Global.base_model_name
|
52 |
|
53 |
try:
|
54 |
if Global.generation_force_stopped_at is not None:
|
|
|
103 |
tokenizer = get_tokenizer(base_model_name)
|
104 |
model = get_model(base_model_name, lora_model_name)
|
105 |
|
|
|
|
|
106 |
generation_config = GenerationConfig(
|
107 |
temperature=temperature,
|
108 |
top_p=top_p,
|
|
|
111 |
num_beams=num_beams,
|
112 |
)
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
def ui_generation_stopping_criteria(input_ids, score, **kwargs):
|
115 |
if Global.should_stop_generating:
|
116 |
return True
|
117 |
return False
|
118 |
|
119 |
Global.should_stop_generating = False
|
|
120 |
|
121 |
+
generation_args = {
|
122 |
+
'model': model,
|
123 |
+
'tokenizer': tokenizer,
|
124 |
+
'prompt': prompt,
|
125 |
+
'generation_config': generation_config,
|
126 |
+
'max_new_tokens': max_new_tokens,
|
127 |
+
'stopping_criteria': [ui_generation_stopping_criteria],
|
128 |
+
'stream_output': stream_output
|
129 |
+
}
|
|
|
|
|
|
|
130 |
|
131 |
+
for (decoded_output, output) in generate(**generation_args):
|
132 |
+
raw_output_str = None
|
133 |
+
if show_raw:
|
134 |
+
raw_output_str = str(output)
|
135 |
+
response = prompter.get_response(decoded_output)
|
136 |
|
137 |
+
if Global.should_stop_generating:
|
138 |
+
return
|
139 |
|
140 |
+
yield (
|
141 |
+
gr.Textbox.update(
|
142 |
+
value=response, lines=inference_output_lines),
|
143 |
+
raw_output_str)
|
144 |
|
145 |
+
if Global.should_stop_generating:
|
146 |
+
# If the user stops the generation, and then clicks the
|
147 |
+
# generation button again, they may mysteriously landed
|
148 |
+
# here, in the previous, should-be-stopped generation
|
149 |
+
# function call, with the new generation function not be
|
150 |
+
# called at all. To workaround this, we yield a message
|
151 |
+
# and setting lines=1, and if the front-end JS detects
|
152 |
+
# that lines has been set to 1 (rows="1" in HTML),
|
153 |
+
# it will automatically click the generate button again
|
154 |
+
# (gr.Textbox.update() does not support updating
|
155 |
+
# elem_classes or elem_id).
|
156 |
+
# [WORKAROUND-UI01]
|
157 |
+
yield (
|
158 |
+
gr.Textbox.update(
|
159 |
+
value="Please retry", lines=1),
|
160 |
+
None)
|
161 |
+
|
162 |
+
return
|
163 |
except Exception as e:
|
164 |
raise gr.Error(e)
|
165 |
|
|
|
179 |
current_prompt_template = current_prompt_template or next(
|
180 |
iter(available_template_names_with_none), None)
|
181 |
|
182 |
+
default_lora_models = []
|
183 |
available_lora_models = default_lora_models + get_available_lora_model_names()
|
184 |
available_lora_models = available_lora_models + ["None"]
|
185 |
|
|
|
205 |
"", visible=False)
|
206 |
lora_mode_info = get_info_of_available_lora_model(lora_model)
|
207 |
if lora_mode_info and isinstance(lora_mode_info, dict):
|
208 |
+
model_base_model = lora_mode_info.get("base_model")
|
209 |
model_prompt_template = lora_mode_info.get("prompt_template")
|
210 |
+
if model_base_model and model_base_model != Global.base_model_name:
|
211 |
+
model_prompt_template_message_update = gr.Markdown.update(
|
212 |
+
f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.", visible=True)
|
213 |
+
elif model_prompt_template and model_prompt_template != prompt_template:
|
214 |
model_prompt_template_message_update = gr.Markdown.update(
|
215 |
f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
|
216 |
|
|
|
257 |
lora_model = gr.Dropdown(
|
258 |
label="LoRA Model",
|
259 |
elem_id="inference_lora_model",
|
260 |
+
value="None",
|
261 |
allow_custom_value=True,
|
262 |
)
|
263 |
prompt_template = gr.Dropdown(
|
|
|
387 |
interactive=False,
|
388 |
elem_id="inference_raw_output")
|
389 |
|
390 |
+
reload_selected_models_btn = gr.Button("", elem_id="inference_reload_selected_models_btn")
|
391 |
+
|
392 |
show_raw_change_event = show_raw.change(
|
393 |
fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
|
394 |
inputs=[show_raw],
|
|
|
410 |
variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
|
411 |
things_that_might_timeout.append(prompt_template_change_event)
|
412 |
|
413 |
+
reload_selected_models_btn_event = reload_selected_models_btn.click(
|
414 |
+
fn=handle_prompt_template_change,
|
415 |
+
inputs=[prompt_template, lora_model],
|
416 |
+
outputs=[
|
417 |
+
model_prompt_template_message,
|
418 |
+
variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
|
419 |
+
things_that_might_timeout.append(reload_selected_models_btn_event)
|
420 |
+
|
421 |
lora_model_change_event = lora_model.change(
|
422 |
fn=handle_lora_model_change,
|
423 |
inputs=[lora_model, prompt_template],
|
|
|
474 |
|
475 |
// Workaround default value not shown.
|
476 |
document.querySelector('#inference_lora_model input').value =
|
477 |
+
'None';
|
478 |
}, 100);
|
479 |
|
480 |
// Add tooltips
|
|
|
618 |
}, 500);
|
619 |
}, 0);
|
620 |
|
621 |
+
// Reload model selection on possible base model change.
|
622 |
+
setTimeout(function () {
|
623 |
+
const elem = document.getElementById('main_page_tabs_container');
|
624 |
+
if (!elem) return;
|
625 |
+
|
626 |
+
let prevClassList = [];
|
627 |
+
|
628 |
+
new MutationObserver(function (mutationsList, observer) {
|
629 |
+
const currentPrevClassList = prevClassList;
|
630 |
+
const currentClassList = Array.from(elem.classList);
|
631 |
+
prevClassList = Array.from(elem.classList);
|
632 |
+
|
633 |
+
if (!currentPrevClassList.includes('hide')) return;
|
634 |
+
if (currentClassList.includes('hide')) return;
|
635 |
+
|
636 |
+
const inference_reload_selected_models_btn_elem = document.getElementById('inference_reload_selected_models_btn');
|
637 |
+
|
638 |
+
if (inference_reload_selected_models_btn_elem) inference_reload_selected_models_btn_elem.click();
|
639 |
+
}).observe(elem, {
|
640 |
+
attributes: true,
|
641 |
+
attributeFilter: ['class'],
|
642 |
+
});
|
643 |
+
}, 0);
|
644 |
+
|
645 |
// Debounced updating the prompt preview.
|
646 |
setTimeout(function () {
|
647 |
function debounce(func, wait) {
|
llama_lora/ui/main_page.py
CHANGED
@@ -17,25 +17,50 @@ def main_page():
|
|
17 |
css=main_page_custom_css(),
|
18 |
) as main_page_blocks:
|
19 |
with gr.Column(elem_id="main_page_content"):
|
20 |
-
gr.
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
""
|
39 |
main_page_blocks.load(_js=f"""
|
40 |
function () {{
|
41 |
{popperjs_core_code()}
|
@@ -61,6 +86,17 @@ def main_page():
|
|
61 |
});
|
62 |
handle_gradio_container_element_class_change();
|
63 |
}, 500);
|
64 |
}
|
65 |
""")
|
66 |
|
@@ -127,12 +163,81 @@ def main_page_custom_css():
|
|
127 |
display: none;
|
128 |
}
|
129 |
|
130 |
#main_page_content > .tabs > .tab-nav * {
|
131 |
font-size: 1rem;
|
132 |
font-weight: 700;
|
133 |
/* text-transform: uppercase; */
|
134 |
}
|
135 |
|
136 |
#inference_lora_model_group {
|
137 |
border-radius: var(--block-radius);
|
138 |
background: var(--block-background-fill);
|
@@ -147,7 +252,8 @@ def main_page_custom_css():
|
|
147 |
position: absolute;
|
148 |
bottom: 8px;
|
149 |
left: 20px;
|
150 |
-
z-index:
|
|
|
151 |
font-size: 12px;
|
152 |
opacity: 0.7;
|
153 |
}
|
@@ -413,6 +519,24 @@ def main_page_custom_css():
|
|
413 |
margin: -32px -16px;
|
414 |
}
|
415 |
|
416 |
.finetune_dataset_error_message {
|
417 |
color: var(--error-text-color) !important;
|
418 |
}
|
@@ -428,10 +552,43 @@ def main_page_custom_css():
|
|
428 |
white-space: pre-wrap;
|
429 |
}
|
430 |
|
431 |
#finetune_max_seq_length {
|
432 |
flex: 2;
|
433 |
}
|
434 |
|
|
435 |
#finetune_save_total_limit,
|
436 |
#finetune_save_steps,
|
437 |
#finetune_logging_steps {
|
@@ -503,3 +660,28 @@ def main_page_custom_css():
|
|
503 |
.tippy-box[data-animation=scale-subtle][data-placement^=top]{transform-origin:bottom}.tippy-box[data-animation=scale-subtle][data-placement^=bottom]{transform-origin:top}.tippy-box[data-animation=scale-subtle][data-placement^=left]{transform-origin:right}.tippy-box[data-animation=scale-subtle][data-placement^=right]{transform-origin:left}.tippy-box[data-animation=scale-subtle][data-state=hidden]{transform:scale(.8);opacity:0}
|
504 |
"""
|
505 |
return css
|
17 |
css=main_page_custom_css(),
|
18 |
) as main_page_blocks:
|
19 |
with gr.Column(elem_id="main_page_content"):
|
20 |
+
with gr.Row():
|
21 |
+
gr.Markdown(
|
22 |
+
f"""
|
23 |
+
<h1 class="app_title_text">{title}</h1> <wbr />
|
24 |
+
<h2 class="app_subtitle_text">{Global.ui_subtitle}</h2>
|
25 |
+
""",
|
26 |
+
elem_id="page_title",
|
27 |
+
)
|
28 |
+
global_base_model_select = gr.Dropdown(
|
29 |
+
label="Base Model",
|
30 |
+
elem_id="global_base_model_select",
|
31 |
+
choices=Global.base_model_choices,
|
32 |
+
value=lambda: Global.base_model_name,
|
33 |
+
allow_custom_value=True,
|
34 |
+
)
|
35 |
+
# global_base_model_select_loading_status = gr.Markdown("", elem_id="global_base_model_select_loading_status")
|
36 |
+
|
37 |
+
with gr.Column(elem_id="main_page_tabs_container") as main_page_tabs_container:
|
38 |
+
with gr.Tab("Inference"):
|
39 |
+
inference_ui()
|
40 |
+
with gr.Tab("Fine-tuning"):
|
41 |
+
finetune_ui()
|
42 |
+
with gr.Tab("Tokenizer"):
|
43 |
+
tokenizer_ui()
|
44 |
+
please_select_a_base_model_message = gr.Markdown("Please select a base model.", visible=False)
|
45 |
+
current_base_model_hint = gr.Markdown(lambda: Global.base_model_name, elem_id="current_base_model_hint")
|
46 |
+
foot_info = gr.Markdown(get_foot_info)
|
47 |
+
|
48 |
+
global_base_model_select.change(
|
49 |
+
fn=pre_handle_change_base_model,
|
50 |
+
inputs=[],
|
51 |
+
outputs=[main_page_tabs_container]
|
52 |
+
).then(
|
53 |
+
fn=handle_change_base_model,
|
54 |
+
inputs=[global_base_model_select],
|
55 |
+
outputs=[
|
56 |
+
main_page_tabs_container,
|
57 |
+
please_select_a_base_model_message,
|
58 |
+
current_base_model_hint,
|
59 |
+
# global_base_model_select_loading_status,
|
60 |
+
foot_info
|
61 |
+
]
|
62 |
+
)
|
63 |
+
|
64 |
main_page_blocks.load(_js=f"""
|
65 |
function () {{
|
66 |
{popperjs_core_code()}
|
|
|
86 |
});
|
87 |
handle_gradio_container_element_class_change();
|
88 |
}, 500);
|
89 |
+
""" + """
|
90 |
+
setTimeout(function () {
|
91 |
+
// Workaround default value not shown.
|
92 |
+
const current_base_model_hint_elem = document.querySelector('#current_base_model_hint > p');
|
93 |
+
if (!current_base_model_hint_elem) return;
|
94 |
+
|
95 |
+
const base_model_name = current_base_model_hint_elem.innerText;
|
96 |
+
document.querySelector('#global_base_model_select input').value = base_model_name;
|
97 |
+
document.querySelector('#global_base_model_select').classList.add('show');
|
98 |
+
}, 3200);
|
99 |
+
""" + """
|
100 |
}
|
101 |
""")
|
102 |
|
|
|
163 |
display: none;
|
164 |
}
|
165 |
|
166 |
+
#page_title {
|
167 |
+
flex-grow: 3;
|
168 |
+
}
|
169 |
+
#global_base_model_select {
|
170 |
+
position: relative;
|
171 |
+
align-self: center;
|
172 |
+
min-width: 250px;
|
173 |
+
padding: 2px 2px;
|
174 |
+
border: 0;
|
175 |
+
box-shadow: none;
|
176 |
+
opacity: 0;
|
177 |
+
pointer-events: none;
|
178 |
+
}
|
179 |
+
#global_base_model_select.show {
|
180 |
+
opacity: 1;
|
181 |
+
pointer-events: auto;
|
182 |
+
}
|
183 |
+
#global_base_model_select label .wrap-inner {
|
184 |
+
padding: 2px 8px;
|
185 |
+
}
|
186 |
+
#global_base_model_select label span {
|
187 |
+
margin-bottom: 2px;
|
188 |
+
font-size: 80%;
|
189 |
+
position: absolute;
|
190 |
+
top: -14px;
|
191 |
+
left: 8px;
|
192 |
+
opacity: 0;
|
193 |
+
}
|
194 |
+
#global_base_model_select:hover label span {
|
195 |
+
opacity: 1;
|
196 |
+
}
|
197 |
+
|
198 |
+
#global_base_model_select_loading_status {
|
199 |
+
position: absolute;
|
200 |
+
pointer-events: none;
|
201 |
+
top: 0;
|
202 |
+
left: 0;
|
203 |
+
right: 0;
|
204 |
+
bottom: 0;
|
205 |
+
}
|
206 |
+
#global_base_model_select_loading_status > .wrap:not(.hide) {
|
207 |
+
z-index: 9999;
|
208 |
+
position: absolute;
|
209 |
+
top: 112px !important;
|
210 |
+
bottom: 0 !important;
|
211 |
+
max-height: none;
|
212 |
+
background: var(--background-fill-primary);
|
213 |
+
opacity: 0.8;
|
214 |
+
}
|
215 |
+
#global_base_model_select ul {
|
216 |
+
z-index: 9999;
|
217 |
+
background: var(--block-background-fill);
|
218 |
+
}
|
219 |
+
|
220 |
+
#current_base_model_hint {
|
221 |
+
display: none;
|
222 |
+
}
|
223 |
+
|
224 |
#main_page_content > .tabs > .tab-nav * {
|
225 |
font-size: 1rem;
|
226 |
font-weight: 700;
|
227 |
/* text-transform: uppercase; */
|
228 |
}
|
229 |
|
230 |
+
#inference_reload_selected_models_btn {
|
231 |
+
position: absolute;
|
232 |
+
top: 0;
|
233 |
+
left: 0;
|
234 |
+
width: 0;
|
235 |
+
height: 0;
|
236 |
+
padding: 0;
|
237 |
+
opacity: 0;
|
238 |
+
pointer-events: none;
|
239 |
+
}
|
240 |
+
|
241 |
#inference_lora_model_group {
|
242 |
border-radius: var(--block-radius);
|
243 |
background: var(--block-background-fill);
|
|
|
252 |
position: absolute;
|
253 |
bottom: 8px;
|
254 |
left: 20px;
|
255 |
+
z-index: 61;
|
256 |
+
width: 999px;
|
257 |
font-size: 12px;
|
258 |
opacity: 0.7;
|
259 |
}
|
|
|
519 |
margin: -32px -16px;
|
520 |
}
|
521 |
|
522 |
+
#finetune_continue_from_model_box {
|
523 |
+
/* padding: 0; */
|
524 |
+
}
|
525 |
+
#finetune_continue_from_model_box .block {
|
526 |
+
border: 0;
|
527 |
+
box-shadow: none;
|
528 |
+
padding: 0;
|
529 |
+
}
|
530 |
+
#finetune_continue_from_model_box > * {
|
531 |
+
/* gap: 0; */
|
532 |
+
}
|
533 |
+
#finetune_continue_from_model_box button {
|
534 |
+
margin-top: 16px;
|
535 |
+
}
|
536 |
+
#finetune_continue_from_model {
|
537 |
+
flex-grow: 2;
|
538 |
+
}
|
539 |
+
|
540 |
.finetune_dataset_error_message {
|
541 |
color: var(--error-text-color) !important;
|
542 |
}
|
|
|
552 |
white-space: pre-wrap;
|
553 |
}
|
554 |
|
555 |
+
/*
|
556 |
+
#finetune_dataset_preview {
|
557 |
+
max-height: 100vh;
|
558 |
+
overflow: auto;
|
559 |
+
border: var(--block-border-width) solid var(--border-color-primary);
|
560 |
+
border-radius: var(--radius-lg);
|
561 |
+
}
|
562 |
+
#finetune_dataset_preview .table-wrap {
|
563 |
+
border: 0 !important;
|
564 |
+
}
|
565 |
+
*/
|
566 |
+
|
567 |
#finetune_max_seq_length {
|
568 |
flex: 2;
|
569 |
}
|
570 |
|
571 |
+
#finetune_lora_target_modules_add_box {
|
572 |
+
margin-top: -24px;
|
573 |
+
padding-top: 8px;
|
574 |
+
border-top-left-radius: 0;
|
575 |
+
border-top-right-radius: 0;
|
576 |
+
border-top: 0;
|
577 |
+
}
|
578 |
+
#finetune_lora_target_modules_add_box > * > .form {
|
579 |
+
border: 0;
|
580 |
+
box-shadow: none;
|
581 |
+
}
|
582 |
+
#finetune_lora_target_modules_add {
|
583 |
+
padding: 0;
|
584 |
+
}
|
585 |
+
#finetune_lora_target_modules_add input {
|
586 |
+
padding: 4px 8px;
|
587 |
+
}
|
588 |
+
#finetune_lora_target_modules_add_btn {
|
589 |
+
min-width: 60px;
|
590 |
+
}
|
591 |
+
|
592 |
#finetune_save_total_limit,
|
593 |
#finetune_save_steps,
|
594 |
#finetune_logging_steps {
|
|
|
660 |
.tippy-box[data-animation=scale-subtle][data-placement^=top]{transform-origin:bottom}.tippy-box[data-animation=scale-subtle][data-placement^=bottom]{transform-origin:top}.tippy-box[data-animation=scale-subtle][data-placement^=left]{transform-origin:right}.tippy-box[data-animation=scale-subtle][data-placement^=right]{transform-origin:left}.tippy-box[data-animation=scale-subtle][data-state=hidden]{transform:scale(.8);opacity:0}
|
661 |
"""
|
662 |
return css
|
663 |
+
|
664 |
+
|
665 |
+
def pre_handle_change_base_model():
|
666 |
+
return gr.Column.update(visible=False)
|
667 |
+
|
668 |
+
|
669 |
+
def handle_change_base_model(selected_base_model_name):
|
670 |
+
Global.base_model_name = selected_base_model_name
|
671 |
+
|
672 |
+
if Global.base_model_name:
|
673 |
+
return gr.Column.update(visible=True), gr.Markdown.update(visible=False), Global.base_model_name, get_foot_info()
|
674 |
+
|
675 |
+
return gr.Column.update(visible=False), gr.Markdown.update(visible=True), Global.base_model_name, get_foot_info()
|
676 |
+
|
677 |
+
|
678 |
+
def get_foot_info():
|
679 |
+
info = []
|
680 |
+
if Global.version:
|
681 |
+
info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
|
682 |
+
info.append(f"Base model: `{Global.base_model_name}`")
|
683 |
+
if Global.ui_show_sys_info:
|
684 |
+
info.append(f"Data dir: `{Global.data_dir}`")
|
685 |
+
return f"""\
|
686 |
+
<small>{" · ".join(info)}</small>
|
687 |
+
"""
|
llama_lora/ui/tokenizer_ui.py
CHANGED
@@ -7,7 +7,7 @@ from ..models import get_tokenizer


def handle_decode(encoded_tokens_json):
-    base_model_name = Global.
+    base_model_name = Global.base_model_name
    try:
        encoded_tokens = json.loads(encoded_tokens_json)
        if Global.ui_dev_mode:
@@ -20,7 +20,7 @@ def handle_decode(encoded_tokens_json):


def handle_encode(decoded_tokens):
-    base_model_name = Global.
+    base_model_name = Global.base_model_name
    try:
        if Global.ui_dev_mode:
            return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
llama_lora/utils/data.py
CHANGED
@@ -30,19 +30,22 @@ def copy_sample_data_if_not_exists(source, destination):
def get_available_template_names():
    templates_directory_path = os.path.join(Global.data_dir, "templates")
    all_files = os.listdir(templates_directory_path)
-
+    names = [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
+    return sorted(names)


def get_available_dataset_names():
    datasets_directory_path = os.path.join(Global.data_dir, "datasets")
    all_files = os.listdir(datasets_directory_path)
-
+    names = [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
+    return sorted(names)


def get_available_lora_model_names():
-
-    all_items = os.listdir(
-
+    lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+    all_items = os.listdir(lora_models_directory_path)
+    names = [item for item in all_items if os.path.isdir(os.path.join(lora_models_directory_path, item))]
+    return sorted(names)


def get_path_of_available_lora_model(name):
llama_lora/utils/prompter.py
CHANGED
@@ -5,13 +5,15 @@ From https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
|
|
5 |
|
6 |
import json
|
7 |
import os.path as osp
|
|
|
|
|
8 |
from typing import Union, List
|
9 |
|
10 |
from ..globals import Global
|
11 |
|
12 |
|
13 |
class Prompter(object):
|
14 |
-
__slots__ = ("template_name", "template", "_verbose")
|
15 |
|
16 |
def __init__(self, template_name: str = "", verbose: bool = False):
|
17 |
self._verbose = verbose
|
@@ -21,12 +23,41 @@ class Prompter(object):
|
|
21 |
self.template_name = "None"
|
22 |
return
|
23 |
self.template_name = template_name
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
30 |
self.template = json.load(fp)
|
31 |
if self._verbose:
|
32 |
print(
|
@@ -47,23 +78,31 @@ class Prompter(object):
|
|
47 |
res = variables.get("prompt", "")
|
48 |
elif "variables" in self.template:
|
49 |
variable_names = self.template.get("variables")
|
50 |
-
if
|
51 |
-
variables
|
52 |
-
|
53 |
-
|
54 |
-
raise ValueError(
|
55 |
-
f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
|
56 |
-
default_prompt_name = self.template.get("default")
|
57 |
-
if default_prompt_name not in self.template:
|
58 |
-
raise ValueError(
|
59 |
-
f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
|
60 |
-
prompt_name = get_prompt_name(variables, variable_names)
|
61 |
-
prompt_template = self.template.get(default_prompt_name)
|
62 |
-
if prompt_name in self.template:
|
63 |
-
prompt_template = self.template.get(prompt_name)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
|
68 |
else:
|
69 |
if type(variables) == dict:
|
@@ -92,18 +131,50 @@ class Prompter(object):
|
|
92 |
def get_response(self, output: str) -> str:
|
93 |
if self.template_name == "None":
|
94 |
return output
|
|
|
|
|
|
|
|
|
|
|
95 |
return self.template["response_split"].join(
|
96 |
-
|
97 |
).strip()
|
98 |
|
99 |
def get_variable_names(self) -> List[str]:
|
100 |
if self.template_name == "None":
|
101 |
return ["prompt"]
|
102 |
elif "variables" in self.template:
|
103 |
-
return self.template
|
104 |
else:
|
105 |
return ["instruction", "input"]
|
106 |
|
107 |
|
108 |
def get_val(arr, index, default=None):
|
109 |
return arr[index] if -len(arr) <= index < len(arr) else default
|
@@ -116,4 +187,62 @@ def get_prompt_name(variables, variable_names):
|
|
116 |
|
117 |
|
118 |
def variables_to_dict(variables, variable_names):
|
119 |
-
return {
|
5 |
|
6 |
import json
|
7 |
import os.path as osp
|
8 |
+
import importlib
|
9 |
+
import itertools
|
10 |
from typing import Union, List
|
11 |
|
12 |
from ..globals import Global
|
13 |
|
14 |
|
15 |
class Prompter(object):
|
16 |
+
__slots__ = ("template_name", "template", "template_module", "_verbose")
|
17 |
|
18 |
def __init__(self, template_name: str = "", verbose: bool = False):
|
19 |
self._verbose = verbose
|
|
|
23 |
self.template_name = "None"
|
24 |
return
|
25 |
self.template_name = template_name
|
26 |
+
self.template_module = None
|
27 |
|
28 |
+
base_filename, ext = osp.splitext(template_name)
|
29 |
+
if ext == "":
|
30 |
+
filename = base_filename + ".json"
|
31 |
+
else:
|
32 |
+
filename = base_filename + ext
|
33 |
+
|
34 |
+
file_path = osp.join(Global.data_dir, "templates", filename)
|
35 |
+
|
36 |
+
if not osp.exists(file_path):
|
37 |
+
raise ValueError(f"Can't read {file_path}")
|
38 |
+
|
39 |
+
if ext == ".py":
|
40 |
+
template_module_spec = importlib.util.spec_from_file_location(
|
41 |
+
"template_module", file_path)
|
42 |
+
template_module = importlib.util.module_from_spec(
|
43 |
+
template_module_spec)
|
44 |
+
template_module_spec.loader.exec_module(template_module)
|
45 |
+
self.template_module = template_module
|
46 |
+
|
47 |
+
if not hasattr(template_module, "variables"):
|
48 |
+
raise ValueError(
|
49 |
+
"The template module does not have a \"variables\" attribute.")
|
50 |
+
|
51 |
+
self.template = {
|
52 |
+
'variables': template_module.variables
|
53 |
+
}
|
54 |
+
|
55 |
+
if hasattr(template_module, "response_split"):
|
56 |
+
self.template["response_split"] = template_module.response_split
|
57 |
+
|
58 |
+
return
|
59 |
+
|
60 |
+
with open(file_path) as fp:
|
61 |
self.template = json.load(fp)
|
62 |
if self._verbose:
|
63 |
print(
|
|
|
78 |
res = variables.get("prompt", "")
|
79 |
elif "variables" in self.template:
|
80 |
variable_names = self.template.get("variables")
|
81 |
+
if self.template_module:
|
82 |
+
if type(variables) == list:
|
83 |
+
variables = {k: v for k, v in zip(
|
84 |
+
variable_names, variables)}
|
85 |
|
86 |
+
res = self.template_module.get_prompt(variables)
|
87 |
+
else:
|
88 |
+
if type(variables) == dict:
|
89 |
+
variables = [variables.get(name, None)
|
90 |
+
for name in variable_names]
|
91 |
+
|
92 |
+
if "default" not in self.template:
|
93 |
+
raise ValueError(
|
94 |
+
f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
|
95 |
+
default_prompt_name = self.template.get("default")
|
96 |
+
if default_prompt_name not in self.template:
|
97 |
+
raise ValueError(
|
98 |
+
f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
|
99 |
+
prompt_name = get_prompt_name(variables, variable_names)
|
100 |
+
prompt_template = self.template.get(default_prompt_name)
|
101 |
+
if prompt_name in self.template:
|
102 |
+
prompt_template = self.template.get(prompt_name)
|
103 |
+
|
104 |
+
res = prompt_template.format(
|
105 |
+
**variables_to_dict(variables, variable_names))
|
106 |
|
107 |
else:
|
108 |
if type(variables) == dict:
|
|
|
131 |
def get_response(self, output: str) -> str:
|
132 |
if self.template_name == "None":
|
133 |
return output
|
134 |
+
|
135 |
+
splitted_output = output.split(self.template["response_split"])
|
136 |
+
# if len(splitted_output) <= 1:
|
137 |
+
# return output.strip()
|
138 |
+
|
139 |
return self.template["response_split"].join(
|
140 |
+
splitted_output[1:]
|
141 |
).strip()
|
142 |
|
143 |
def get_variable_names(self) -> List[str]:
|
144 |
if self.template_name == "None":
|
145 |
return ["prompt"]
|
146 |
elif "variables" in self.template:
|
147 |
+
return self.template['variables']
|
148 |
else:
|
149 |
return ["instruction", "input"]
|
150 |
|
151 |
+
def get_train_data_from_dataset(self, data, only_first_n_items=None):
|
152 |
+
if self.template_module:
|
153 |
+
if hasattr(self.template_module,
|
154 |
+
"get_train_data_list_from_dataset"):
|
155 |
+
data = self.template_module.get_train_data_list_from_dataset(
|
156 |
+
data)
|
157 |
+
if only_first_n_items:
|
158 |
+
data = data[:only_first_n_items]
|
159 |
+
return list(itertools.chain(*list(
|
160 |
+
map(self.template_module.get_train_data, data)
|
161 |
+
)))
|
162 |
+
|
163 |
+
if only_first_n_items:
|
164 |
+
data = data[:only_first_n_items]
|
165 |
+
|
166 |
+
data = process_json_dataset(data)
|
167 |
+
|
168 |
+
train_data = [
|
169 |
+
{
|
170 |
+
'prompt': self.generate_prompt(d['variables']),
|
171 |
+
'completion': d['output'],
|
172 |
+
**{"_var_" + k: v for k, v in d['variables'].items()}
|
173 |
+
}
|
174 |
+
for d in data]
|
175 |
+
|
176 |
+
return train_data
|
177 |
+
|
178 |
|
179 |
def get_val(arr, index, default=None):
|
180 |
return arr[index] if -len(arr) <= index < len(arr) else default
|
|
|
187 |
|
188 |
|
189 |
def variables_to_dict(variables, variable_names):
|
190 |
+
return {
|
191 |
+
key: (variables[i] if i < len(variables)
|
192 |
+
and variables[i] is not None else '')
|
193 |
+
for i, key in enumerate(variable_names)
|
194 |
+
}
|
195 |
+
|
196 |
+
|
197 |
+
def process_json_dataset(data):
|
198 |
+
if not isinstance(data, list):
|
199 |
+
raise ValueError("The dataset is not an array of objects.")
|
200 |
+
|
201 |
+
first_item = get_val_from_arr(data, 0, None)
|
202 |
+
|
203 |
+
if first_item is None:
|
204 |
+
raise ValueError("The dataset is empty.")
|
205 |
+
if not isinstance(first_item, dict):
|
206 |
+
raise ValueError("The dataset is not an array of objects.")
|
207 |
+
|
208 |
+
# Convert OpenAI fine-tuning dataset to LLaMA LoRA style
|
209 |
+
if "completion" in first_item and "output" not in first_item:
|
210 |
+
data = [
|
211 |
+
{"output" if k == "completion" else k: v for k, v in d.items()}
|
212 |
+
for d in data]
|
213 |
+
first_item = get_val_from_arr(data, 0, None)
|
214 |
+
|
215 |
+
# Flatten Stanford Alpaca style instances
|
216 |
+
if "instances" in first_item and isinstance(first_item["instances"], list):
|
217 |
+
data = [
|
218 |
+
{"output" if k == "completion" else k: v for k, v in d.items()}
|
219 |
+
for d in data]
|
220 |
+
flattened_data = []
|
221 |
+
for item in data:
|
222 |
+
for instance in item["instances"]:
|
223 |
+
d = {k: v for k, v in item.items() if k != "instances"}
|
224 |
+
d.update(instance)
|
225 |
+
flattened_data.append(d)
|
226 |
+
data = flattened_data
|
227 |
+
first_item = get_val_from_arr(data, 0, None)
|
228 |
+
|
229 |
+
if "output" not in first_item:
|
230 |
+
raise ValueError(
|
231 |
+
"The data does not contains an \"output\" or \"completion\".")
|
232 |
+
|
233 |
+
# Put all variables under the "variables" key if it does not exists
|
234 |
+
if "variables" not in first_item:
|
235 |
+
data = [
|
236 |
+
{
|
237 |
+
"variables":
|
238 |
+
{k: v for k, v in d.items() if k != "output"},
|
239 |
+
"output":
|
240 |
+
d["output"]
|
241 |
+
}
|
242 |
+
for d in data
|
243 |
+
]
|
244 |
+
return data
|
245 |
+
|
246 |
+
|
247 |
+
def get_val_from_arr(arr, index, default=None):
|
248 |
+
return arr[index] if -len(arr) <= index < len(arr) else default
|
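For reference, the process_json_dataset() helper added above normalizes OpenAI-style records before prompt generation; a small illustrative before/after (the sample record below is made up):

# Illustrative input/output for the process_json_dataset() normalization above.
openai_style = [
    {"prompt": "Translate to French: Hello", "completion": "Bonjour"},
]

# After renaming "completion" -> "output" and grouping the remaining keys
# under "variables", the record becomes:
normalized = [
    {"variables": {"prompt": "Translate to French: Hello"}, "output": "Bonjour"},
]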
lora_models/alpaca-lora-7b/finetune_params.json
ADDED
@@ -0,0 +1,19 @@
+{
+  "num_train_epochs": 10,
+  "learning_rate": 0.0003,
+  "cutoff_len": 512,
+  "lora_r": 16,
+  "lora_alpha": 16,
+  "lora_dropout": 0.05,
+  "lora_target_modules": [
+    "q_proj",
+    "v_proj",
+    "k_proj",
+    "o_proj"
+  ],
+  "train_on_inputs": true,
+  "group_by_length": false,
+  "save_steps": 2000,
+  "save_total_limit": 10,
+  "logging_steps": 10
+}
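These bundled parameters are what the new "Load training parameters from selected model" button restores; a hedged sketch of that lookup, mirroring handle_load_params_from_model above (the data directory path is an example, not taken from the diff):

# Sketch of how finetune_params.json / finetune_args.json are picked up for a
# LoRA model directory; path values are examples only.
import json
import os

data_dir = "./data"              # example
model_name = "alpaca-lora-7b"    # example

params = {}
for file in ["finetune_params.json", "finetune_args.json"]:
    try:
        with open(os.path.join(data_dir, "lora_models", model_name, file), "r") as f:
            params = json.load(f)
    except FileNotFoundError:
        pass

print(params.get("num_train_epochs"))  # -> 10 for the file above, if found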
lora_models/alpaca-lora-7b/info.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "hf_model_name": "tloen/alpaca-lora-7b",
+  "load_from_hf": true,
+  "base_model": "decapoda-research/llama-7b-hf",
+  "prompt_template": "alpaca"
+}
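The info.json metadata drives the base-model mismatch warning shown on the inference tab; a minimal sketch of that check, following handle_prompt_template_change above (the file path and selected base model are examples):

# Sketch of the base-model mismatch warning driven by info.json; values are examples.
import json

with open("lora_models/alpaca-lora-7b/info.json", "r") as f:
    info = json.load(f)

selected_base_model = "decapoda-research/llama-7b-hf"  # example

if info.get("base_model") and info["base_model"] != selected_base_model:
    print(f"This LoRA was trained on top of {info['base_model']}, "
          f"it might not work properly with {selected_base_model}.")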
lora_models/unhelpful-ai-v01/finetune_params.json
ADDED
@@ -0,0 +1,19 @@
+{
+  "num_train_epochs": 16,
+  "learning_rate": 0.0003,
+  "cutoff_len": 512,
+  "lora_r": 12,
+  "lora_alpha": 32,
+  "lora_dropout": 0.05,
+  "lora_target_modules": [
+    "q_proj",
+    "v_proj",
+    "k_proj",
+    "o_proj"
+  ],
+  "train_on_inputs": false,
+  "group_by_length": false,
+  "save_steps": 500,
+  "save_total_limit": 5,
+  "logging_steps": 10
+}