Reload models better
Browse files
main.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import argparse
|
|
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
import transformers
|
@@ -12,6 +13,12 @@ model = None
|
|
12 |
tokenizer = None
|
13 |
peft_model = None
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def maybe_load_models():
|
16 |
global model
|
17 |
global tokenizer
|
@@ -29,8 +36,6 @@ def maybe_load_models():
|
|
29 |
"decapoda-research/llama-7b-hf",
|
30 |
)
|
31 |
|
32 |
-
return model, tokenizer
|
33 |
-
|
34 |
def reset_models():
|
35 |
global model
|
36 |
global tokenizer
|
@@ -51,7 +56,10 @@ def generate_text(
|
|
51 |
max_new_tokens,
|
52 |
progress=gr.Progress(track_tqdm=True)
|
53 |
):
|
54 |
-
model
|
|
|
|
|
|
|
55 |
|
56 |
if model_name and model_name != "None":
|
57 |
model = PeftModel.from_pretrained(
|
@@ -123,7 +131,11 @@ def tokenize_and_train(
|
|
123 |
model_name,
|
124 |
progress=gr.Progress(track_tqdm=True)
|
125 |
):
|
126 |
-
model
|
|
|
|
|
|
|
|
|
127 |
|
128 |
tokenizer.pad_token_id = 0
|
129 |
|
@@ -302,7 +314,7 @@ with gr.Blocks(css="#refresh-button { max-width: 32px }") as demo:
|
|
302 |
|
303 |
with gr.Column():
|
304 |
model_name = gr.Textbox(
|
305 |
-
lines=1, label="LoRA Model Name", value=
|
306 |
)
|
307 |
|
308 |
with gr.Row():
|
|
|
1 |
import os
|
2 |
import argparse
|
3 |
+
import random
|
4 |
import torch
|
5 |
import gradio as gr
|
6 |
import transformers
|
|
|
13 |
tokenizer = None
|
14 |
peft_model = None
|
15 |
|
16 |
+
def random_hyphenated_word():
|
17 |
+
word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
|
18 |
+
word1 = random.choice(word_list)
|
19 |
+
word2 = random.choice(word_list)
|
20 |
+
return word1 + '-' + word2
|
21 |
+
|
22 |
def maybe_load_models():
|
23 |
global model
|
24 |
global tokenizer
|
|
|
36 |
"decapoda-research/llama-7b-hf",
|
37 |
)
|
38 |
|
|
|
|
|
39 |
def reset_models():
|
40 |
global model
|
41 |
global tokenizer
|
|
|
56 |
max_new_tokens,
|
57 |
progress=gr.Progress(track_tqdm=True)
|
58 |
):
|
59 |
+
global model
|
60 |
+
global tokenizer
|
61 |
+
|
62 |
+
maybe_load_models()
|
63 |
|
64 |
if model_name and model_name != "None":
|
65 |
model = PeftModel.from_pretrained(
|
|
|
131 |
model_name,
|
132 |
progress=gr.Progress(track_tqdm=True)
|
133 |
):
|
134 |
+
global model
|
135 |
+
global tokenizer
|
136 |
+
|
137 |
+
reset_models()
|
138 |
+
maybe_load_models()
|
139 |
|
140 |
tokenizer.pad_token_id = 0
|
141 |
|
|
|
314 |
|
315 |
with gr.Column():
|
316 |
model_name = gr.Textbox(
|
317 |
+
lines=1, label="LoRA Model Name", value=random_hyphenated_word()
|
318 |
)
|
319 |
|
320 |
with gr.Row():
|