Refactor; fix model/lora loading/reloading in inference. Fixes #10, #6
Files changed:
- .gitignore +2 -1
- Inference.ipynb +174 -0
- main.py +90 -117
.gitignore
CHANGED
@@ -6,4 +6,5 @@ lora-*
 checkpoint**
 minimal-llama**
 upload.py
-models/
+models/
+.ipynb_checkpoints/
Inference.ipynb
ADDED
@@ -0,0 +1,174 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "26eca0b2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "===================================BUG REPORT===================================\n",
+      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
+      "================================================================================\n",
+      "CUDA SETUP: CUDA runtime path found: /root/miniconda3/envs/llama/lib/libcudart.so\n",
+      "CUDA SETUP: Highest compute capability among GPUs detected: 8.6\n",
+      "CUDA SETUP: Detected CUDA version 117\n",
+      "CUDA SETUP: Loading binary /root/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import transformers\n",
+    "import peft"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "3c2f7268",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a9779bdda9d54ce8adcfc3cf3c61b6ef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards: 0%| | 0/33 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "model = transformers.LlamaForCausalLM.from_pretrained(\n",
+    "    'decapoda-research/llama-7b-hf', \n",
+    "    load_in_8bit=True,\n",
+    "    torch_dtype=torch.float16,\n",
+    "    device_map='auto'\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "e8a19a75",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
+      "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
+      "The class this function is called from is 'LlamaTokenizer'.\n"
+     ]
+    }
+   ],
+   "source": [
+    "tokenizer = transformers.LlamaTokenizer.from_pretrained('decapoda-research/llama-7b-hf')\n",
+    "tokenizer.pad_token_id = 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "240a9c8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = peft.PeftModel.from_pretrained(\n",
+    "    model,\n",
+    "    'lora-assistant',\n",
+    "    torch_dtype=torch.float16\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4f944f46",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " Human: What does the fox say?\n",
+      "Assistant: The Fox says \\\"la la la\\\"!Human: That's not what it means. It is a song by Ylvis, and they are saying that this particular animal makes noises like these words when trying to communicate with humans in\n"
+     ]
+    }
+   ],
+   "source": [
+    "inputs = tokenizer(\"Human: What does the fox say?\\nAssistant:\", return_tensors=\"pt\")\n",
+    "input_ids = inputs[\"input_ids\"].to('cuda')\n",
+    "\n",
+    "generation_config = transformers.GenerationConfig(\n",
+    "    do_sample = True,\n",
+    "    temperature = 0.3,\n",
+    "    top_p = 0.1,\n",
+    "    top_k = 50,\n",
+    "    repetition_penalty = 1.5,\n",
+    "    max_new_tokens = 50\n",
+    ")\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    generation_output = model.generate(\n",
+    "        input_ids=input_ids,\n",
+    "        attention_mask=torch.ones_like(input_ids),\n",
+    "        generation_config=generation_config,\n",
+    "    )\n",
+    "    \n",
+    "output_text = tokenizer.decode(generation_output[0].cuda())\n",
+    "print(output_text)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "5fc13b1a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "del model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c5f19b3a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
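For convenience, the flow the notebook walks through, condensed into a standalone script. This is a minimal sketch, not part of the commit itself: it assumes a CUDA GPU with bitsandbytes set up and a local 'lora-assistant' adapter directory produced by the finetuner.

# Sketch of the notebook's inference flow as a plain script.
# Assumes: CUDA GPU, bitsandbytes installed, local 'lora-assistant' adapter dir.
import torch
import transformers
import peft

base = 'decapoda-research/llama-7b-hf'

tokenizer = transformers.LlamaTokenizer.from_pretrained(base)
tokenizer.pad_token_id = 0

# Base model in 8-bit to fit a 7B model on a single consumer GPU
model = transformers.LlamaForCausalLM.from_pretrained(
    base,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

# Wrap the base model with the LoRA adapter weights
model = peft.PeftModel.from_pretrained(model, 'lora-assistant', torch_dtype=torch.float16)

prompt = "Human: What does the fox say?\nAssistant:"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

generation_config = transformers.GenerationConfig(
    do_sample=True,
    temperature=0.3,
    top_p=0.1,
    top_k=50,
    repetition_penalty=1.5,
    max_new_tokens=50,
)

with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        generation_config=generation_config,
    )

print(tokenizer.decode(output[0], skip_special_tokens=True))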
main.py
CHANGED
@@ -2,134 +2,106 @@ import os
 import argparse
 import random
 import torch
-import gradio as gr
 import transformers
-
-
-
-from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, PeftModel
+import peft
+import datasets
+import gradio as gr

 model = None
 tokenizer = None
-
-
-def random_hyphenated_word():
-    word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
-    word1 = random.choice(word_list)
-    word2 = random.choice(word_list)
-    return word1 + '-' + word2
+current_peft_model = None

-def
+def load_base_model():
     global model
-
+    print('Loading base model...')
+    model = transformers.LlamaForCausalLM.from_pretrained(
+        'decapoda-research/llama-7b-hf',
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map='auto'
+    )

-
-
-
-
-
-
-    )
+def load_tokenizer():
+    global tokenizer
+    print('Loading tokenizer...')
+    tokenizer = transformers.LlamaTokenizer.from_pretrained(
+        'decapoda-research/llama-7b-hf',
+    )

-
-
-
-
+def load_peft_model(model_name):
+    global model
+    print('Loading peft model ' + model_name + '...')
+    model = peft.PeftModel.from_pretrained(
+        model, model_name,
+        torch_dtype=torch.float16
+    )

-def reset_models():
+def reset_model():
     global model
     global tokenizer
+    global current_peft_model

     del model
     del tokenizer

     model = None
     tokenizer = None
+    current_peft_model = None

 def generate_text(
-
+    peft_model,
     text,
     temperature,
     top_p,
     top_k,
-
+    repetition_penalty,
     max_new_tokens,
     progress=gr.Progress(track_tqdm=True)
 ):
     global model
     global tokenizer
+    global current_peft_model

-
+    if (peft_model == 'None'): peft_model = None

-
+    if (current_peft_model != peft_model):
+        if (current_peft_model is None):
+            if (model is None): load_base_model()
+        else:
+            reset_model()
+            load_base_model()
+            load_tokenizer()

-
-
-
-
-
+        current_peft_model = peft_model
+        if (peft_model is not None):
+            load_peft_model(peft_model)
+
+    if (model is None): load_base_model()
+    if (tokenizer is None): load_tokenizer()
+
+    assert model is not None
+    assert tokenizer is not None

     inputs = tokenizer(text, return_tensors="pt")
     input_ids = inputs["input_ids"].to(model.device)

-
-
-
-
-
-
+    generation_config = transformers.GenerationConfig(
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
         do_sample=True,
-
-        # Controls the 'temperature' of the softmax distribution during sampling.
-        # Higher values (e.g., 1.0) make the model generate more diverse and random outputs,
-        # while lower values (e.g., 0.1) make it more deterministic and
-        # focused on the highest probability tokens.
-        temperature=temperature,
-
-        # Sets the nucleus sampling threshold. In nucleus sampling,
-        # only the tokens whose cumulative probability exceeds 'top_p' are considered
-        # for sampling. This technique helps to reduce the number of low probability
-        # tokens considered during sampling, which can lead to more diverse and coherent outputs.
-        top_p=top_p,
-
-        # Sets the number of top tokens to consider during sampling.
-        # In top-k sampling, only the 'top_k' tokens with the highest probabilities
-        # are considered for sampling. This method can lead to more focused and coherent
-        # outputs by reducing the impact of low probability tokens.
-        top_k=top_k,
-
-        # Applies a penalty to the probability of tokens that have already been generated,
-        # discouraging the model from repeating the same words or phrases. The penalty is
-        # applied by dividing the token probability by a factor based on the number of times
-        # the token has appeared in the generated text.
-        repeat_penalty=repeat_penalty,
-
-        # Limits the maximum number of tokens generated in a single iteration.
-        # This can be useful to control the length of generated text, especially in tasks
-        # like text summarization or translation, where the output should not be excessively long.
-        max_new_tokens=max_new_tokens,
-
-        # typical_p=1,
-        # stopping_criteria=stopping_criteria_list,
-        # eos_token_id=llama_config.eos_token_id,
-        # pad_token_id=llama_config.eos_token_id
+        num_beams=1,
     )

-
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            attention_mask=torch.ones_like(input_ids),
-            generation_config=generation_config,
-            # return_dict_in_generate=True,
-            # output_scores=True,
-            # eos_token_id=[tokenizer.eos_token_id],
-            use_cache=True,
-        )[0].cuda()
+    output = model.generate(  # type: ignore
+        input_ids=input_ids,
+        attention_mask=torch.ones_like(input_ids),
+        generation_config=generation_config
+    )[0].cuda()

-
-    output_text = tokenizer.decode(generation_output)
-    return output_text.strip()
+    return tokenizer.decode(output, skip_special_tokens=True).strip()

 def tokenize_and_train(
     training_text,
@@ -147,8 +119,11 @@ def tokenize_and_train(
     global model
     global tokenizer

-
-
+    if (model is None): load_base_model()
+    if (tokenizer is None): load_tokenizer()
+
+    assert model is not None
+    assert tokenizer is not None

     tokenizer.pad_token_id = 0

@@ -156,6 +131,7 @@ def tokenize_and_train(
     print("Number of samples: " + str(len(paragraphs)))

     def tokenize(item):
+        assert tokenizer is not None
         result = tokenizer(
             item["text"],
             truncation=True,
@@ -171,12 +147,12 @@ def tokenize_and_train(
         return {"text": text}

     paragraphs = [to_dict(x) for x in paragraphs]
-    data = Dataset.from_list(paragraphs)
+    data = datasets.Dataset.from_list(paragraphs)
     data = data.shuffle().map(lambda x: tokenize(x))

-    model = prepare_model_for_int8_training(model)
+    model = peft.prepare_model_for_int8_training(model)

-    model = get_peft_model(model, LoraConfig(
+    model = peft.get_peft_model(model, peft.LoraConfig(
         r=lora_r,
         lora_alpha=lora_alpha,
         target_modules=["q_proj", "v_proj"],
@@ -261,22 +237,22 @@ def tokenize_and_train(
     )

     result = trainer.train(resume_from_checkpoint=False)
-
     model.save_pretrained(output_dir)
-
-    reset_models()
+    reset_model()

     return result

+def random_hyphenated_word():
+    word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
+    word1 = random.choice(word_list)
+    word2 = random.choice(word_list)
+    return word1 + '-' + word2

-with gr.Blocks(
-    css="#refresh-button { max-width: 32px }",
-    title="Simple LLaMA Finetuner") as demo:
-
+def training_tab():
     with gr.Tab("Finetuning"):

         with gr.Column():
-            training_text = gr.Textbox(lines=12, label="Training Data", info="Each sequence must be separated by
+            training_text = gr.Textbox(lines=12, label="Training Data", info="Each sequence must be separated by 2 blank lines")

             max_seq_length = gr.Slider(
                 minimum=1, maximum=4096, value=512,
@@ -363,6 +339,7 @@ with gr.Blocks(

         abort_button.click(None, None, None, cancels=[train_progress])

+def inference_tab():
     with gr.Tab("Inference"):
         with gr.Row():
             with gr.Column():
@@ -380,13 +357,13 @@ with gr.Blocks(
             with gr.Column():
                 # temperature, top_p, top_k, repeat_penalty, max_new_tokens
                 temperature = gr.Slider(
-                    minimum=0, maximum=1.99, value=0.
+                    minimum=0, maximum=1.99, value=0.4, step=0.01,
                     label="Temperature",
                     info="Controls the 'temperature' of the softmax distribution during sampling. Higher values (e.g., 1.0) make the model generate more diverse and random outputs, while lower values (e.g., 0.1) make it more deterministic and focused on the highest probability tokens."
                 )

                 top_p = gr.Slider(
-                    minimum=0, maximum=1, value=0.
+                    minimum=0, maximum=1, value=0.3, step=0.01,
                     label="Top P",
                     info="Sets the nucleus sampling threshold. In nucleus sampling, only the tokens whose cumulative probability exceeds 'top_p' are considered for sampling. This technique helps to reduce the number of low probability tokens considered during sampling, which can lead to more diverse and coherent outputs."
                 )
@@ -398,7 +375,7 @@ with gr.Blocks(
                 )

                 repeat_penalty = gr.Slider(
-                    minimum=0, maximum=
+                    minimum=0, maximum=2.5, value=1.0, step=0.01,
                     label="Repeat Penalty",
                     info="Applies a penalty to the probability of tokens that have already been generated, discouraging the model from repeating the same words or phrases. The penalty is applied by dividing the token probability by a factor based on the number of times the token has appeared in the generated text."
                 )
@@ -413,12 +390,8 @@ with gr.Blocks(
                 generate_btn = gr.Button(
                     "Generate", variant="primary", label="Generate",
                 )
-
-                inference_abort_button = gr.Button(
-                    "Abort", label="Abort",
-                )

-
+                generate_btn.click(
                     fn=generate_text,
                     inputs=[
                         lora_model,
@@ -432,10 +405,6 @@ with gr.Blocks(
                     outputs=inference_output,
                 )

-                lora_model.change(
-                    fn=reset_models
-                )
-
                 def update_models_list():
                     return gr.Dropdown.update(choices=["None"] + [
                         d for d in os.listdir() if os.path.isdir(d) and d.startswith('lora-')
@@ -447,11 +416,15 @@ with gr.Blocks(
                     outputs=lora_model,
                 )

-
+with gr.Blocks(
+    css="#refresh-button { max-width: 32px }",
+    title="Simple LLaMA Finetuner") as demo:
+    training_tab()
+    inference_tab()

-if __name__ ==
+if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Simple LLaMA Finetuner")
     parser.add_argument("-s", "--share", action="store_true", help="Enable sharing of the Gradio interface")
     args = parser.parse_args()

-    demo.queue().launch(share=args.share)
+    demo.queue().launch(share=args.share)
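The slider 'info' strings above describe temperature, top-p, top-k, and the repetition penalty in prose. The sketch below applies the same ideas by hand to a toy logits vector, purely for illustration; in the app the actual sampling is done by transformers' GenerationConfig and model.generate(), and the numbers here are arbitrary.

# Illustrative only: how the sampling knobs exposed by the sliders act on a
# toy logits vector. Not how transformers implements them internally.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # unnormalized scores for 4 tokens
generated = torch.tensor([2])                 # pretend token 2 was already generated

# Repetition penalty: make already-generated tokens less likely by scaling
# their scores down (divide positive scores, multiply negative ones).
penalty = 1.5
penalized = logits.clone()
penalized[generated] = torch.where(
    penalized[generated] > 0,
    penalized[generated] / penalty,
    penalized[generated] * penalty,
)

# Temperature: divide logits before softmax; <1.0 sharpens, >1.0 flattens.
temperature = 0.4
probs = torch.softmax(penalized / temperature, dim=-1)

# Top-k: restrict sampling to the k highest-probability tokens.
top_k = 3
topk_probs, topk_indices = probs.topk(top_k)

# Top-p (nucleus): within that set, keep the smallest prefix whose cumulative
# probability reaches p, renormalize, and sample the next token from it.
top_p = 0.9
keep = topk_probs.cumsum(dim=-1) <= top_p
keep[0] = True  # always keep at least the most likely token
nucleus = topk_probs[keep] / topk_probs[keep].sum()
next_token = topk_indices[keep][torch.multinomial(nucleus, 1)]
print("sampled token id:", next_token.item())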