Update app.py

app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 import torch
 from flashsloth.constants import (
     IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
@@ -14,28 +15,39 @@ from flashsloth.mm_utils import (
 from PIL import Image
 import gradio as gr
 
-
 from transformers import TextIteratorStreamer
 from threading import Thread
 
-
 disable_torch_init()
 
-
+MODEL_PATH_HD = "Tongbo/FlashSloth_HD-3.2B"
+MODEL_PATH_NEW = "Tongbo/FlashSloth-3.2B"
 
-
-
-model.to('cuda')
-model.eval()
+model_name_hd = get_model_name_from_path(MODEL_PATH_HD)
+model_name_new = get_model_name_from_path(MODEL_PATH_NEW)
 
+models = {
+    "FlashSloth HD": load_pretrained_model(MODEL_PATH_HD, None, model_name_hd),
+    "FlashSloth": load_pretrained_model(MODEL_PATH_NEW, None, model_name_new)
+}
 
-def generate_description(image, prompt_text, temperature, top_p, max_tokens):
+for key in models:
+    tokenizer, model, image_processor, context_len = models[key]
+    model.to('cuda')
+    model.eval()
+
+def generate_description(image, prompt_text, temperature, top_p, max_tokens, selected_model):
+    """
+    Generate a description for the image, with streaming output, using the selected model.
+    New parameter:
+    - selected_model: the name of the model chosen by the user
+    """
     keywords = ['</s>']
 
+    tokenizer, model, image_processor, context_len = models[selected_model]
 
     text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
     text = text + LEARNABLE_TOKEN
-
 
     image = image.convert('RGB')
     if model.config.image_hd:
@@ -43,14 +55,12 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
     else:
         image_tensor = process_images([image], image_processor, model.config)[0]
     image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)
-
 
     conv = conv_templates["phi2"].copy()
     conv.append_message(conv.roles[0], text)
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
-
 
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
     input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
 
@@ -79,11 +89,9 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
         with torch.inference_mode():
             model.generate(**generation_kwargs)
 
-    # Run generation in a separate thread to avoid blocking
    generation_thread = Thread(target=_generate)
    generation_thread.start()
 
-    # Yield output incrementally as it is generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
@@ -91,7 +99,6 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
 
    generation_thread.join()
 
-# Custom CSS to enlarge fonts and polish the interface
 custom_css = """
 <style>
 /* Enlarge the title font */
@@ -152,10 +159,17 @@ with gr.Blocks(css=custom_css) as demo:
                     minimum=64,
                     maximum=3072,
                     step=1,
-                    value=
+                    value=3072,
                     label="Max Tokens"
                 )
 
+                model_dropdown = gr.Dropdown(
+                    choices=list(models.keys()),
+                    value=list(models.keys())[0],
+                    label="Select Model",
+                    type="value"
+                )
+
             with gr.Column(scale=1):
                 prompt_input = gr.Textbox(
                     lines=3,
@@ -173,10 +187,10 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
 
     submit_button.click(
         fn=generate_description,
-        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
+        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider, model_dropdown],
         outputs=output_text,
         show_progress=True
     )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch(server_name="0.0.0.0", server_port=8888)
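The core of this change is the model registry: both FlashSloth checkpoints are loaded once at startup into a dict, and the dropdown's string value selects the `(tokenizer, model, image_processor, context_len)` tuple per request. Below is a minimal, self-contained sketch of that dispatch pattern; the lambdas are hypothetical stubs standing in for the real `load_pretrained_model(...)` results so it runs without the flashsloth package.

```python
import gradio as gr

# Hypothetical stubs standing in for the two loaded FlashSloth models.
models = {
    "FlashSloth HD": lambda prompt: f"[HD model] {prompt}",
    "FlashSloth": lambda prompt: f"[base model] {prompt}",
}

def respond(prompt, selected_model):
    # The dropdown passes its value (a key of `models`) as a plain string.
    return models[selected_model](prompt)

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    model_dropdown = gr.Dropdown(
        choices=list(models.keys()),
        value=list(models.keys())[0],
        label="Select Model",
    )
    output_box = gr.Textbox(label="Output")
    gr.Button("Submit").click(
        fn=respond,
        inputs=[prompt_box, model_dropdown],
        outputs=output_box,
    )

if __name__ == "__main__":
    demo.queue().launch()
```

Loading both checkpoints up front trades GPU memory for latency: switching models in the UI costs a dict lookup instead of a reload.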
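The streaming machinery in `generate_description` is unchanged by this commit apart from the model lookup: `model.generate` runs on a worker thread while the callback iterates a `TextIteratorStreamer` and yields partial text. A minimal sketch of that pattern, using a small stand-in checkpoint ("gpt2") rather than FlashSloth:

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Stand-in model; the app uses the cached FlashSloth checkpoints instead.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def _generate():
        # generate() pushes decoded text into the streamer as it samples.
        with torch.inference_mode():
            model.generate(**inputs, streamer=streamer, max_new_tokens=64)

    thread = Thread(target=_generate)
    thread.start()

    partial_text = ""
    for new_text in streamer:   # blocks until the worker emits more text
        partial_text += new_text
        yield partial_text      # Gradio re-renders the output on each yield

    thread.join()
```

Because the Gradio callback is a generator, `demo.queue()` is required so the server can push each yielded partial string to the client.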