Tongbo committed on
Commit
b0b2441
·
verified ·
1 Parent(s): 4fc316b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -47
app.py CHANGED
@@ -1,64 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
27
 
28
- response = ""
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
 
 
 
 
 
 
34
  temperature=temperature,
35
  top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
 
 
 
 
38
 
39
- response += token
40
- yield response
 
41
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
+ import os
2
+ import torch
3
+ from flashsloth.constants import (
4
+ IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
5
+ DEFAULT_IM_END_TOKEN, LEARNABLE_TOKEN, LEARNABLE_TOKEN_INDEX
6
+ )
7
+ from flashsloth.conversation import conv_templates, SeparatorStyle
8
+ from flashsloth.model.builder import load_pretrained_model
9
+ from flashsloth.utils import disable_torch_init
10
+ from flashsloth.mm_utils import (
11
+ tokenizer_image_token, process_images, process_images_hd_inference,
12
+ get_model_name_from_path, KeywordsStoppingCriteria
13
+ )
14
+ from PIL import Image
15
  import gradio as gr
 
16
 
 
 
 
 
17
 
18
+ from transformers import TextIteratorStreamer
19
+ from threading import Thread
20
+
21
+
22
+ disable_torch_init()
23
+
24
+ MODEL_PATH = "Tongbo/FlashSloth_HD-3.2B"
25
+
26
+ model_name = get_model_name_from_path(MODEL_PATH)
27
+ tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None, model_name)
28
+ model.to('cuda')
29
+ model.eval()
30
+
31
+ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
32
+
33
+ keywords = ['</s>']
34
+
35
+
36
+ text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
37
+ text = text + LEARNABLE_TOKEN
38
+
39
+
40
+ image = image.convert('RGB')
41
+ if model.config.image_hd:
42
+ image_tensor = process_images_hd_inference([image], image_processor, model.config)[0]
43
+ else:
44
+ image_tensor = process_images([image], image_processor, model.config)[0]
45
+ image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)
46
+
47
 
48
+ conv = conv_templates["phi2"].copy()
49
+ conv.append_message(conv.roles[0], text)
50
+ conv.append_message(conv.roles[1], None)
51
+ prompt = conv.get_prompt()
 
 
 
 
 
52
 
 
 
 
 
 
53
 
54
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
55
+ input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
56
 
57
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
58
 
59
+ streamer = TextIteratorStreamer(
60
+ tokenizer=tokenizer,
61
+ skip_prompt=True,
62
+ skip_special_tokens=True
63
+ )
64
+
65
+ generation_kwargs = dict(
66
+ inputs=input_ids,
67
+ images=image_tensor,
68
+ do_sample=True,
69
  temperature=temperature,
70
  top_p=top_p,
71
+ max_new_tokens=int(max_tokens),
72
+ use_cache=True,
73
+ eos_token_id=tokenizer.eos_token_id,
74
+ stopping_criteria=[stopping_criteria],
75
+ streamer=streamer
76
+ )
77
 
78
+ def _generate():
79
+ with torch.inference_mode():
80
+ model.generate(**generation_kwargs)
81
 
82
+ # 在单独线程中运行生成,防止阻塞
83
+ generation_thread = Thread(target=_generate)
84
+ generation_thread.start()
85
 
86
+ # 边生成边yield输出
87
+ partial_text = ""
88
+ for new_text in streamer:
89
+ partial_text += new_text
90
+ yield partial_text
91
+
92
+ generation_thread.join()
93
+
94
+ # 自定义CSS样式,用于增大字体和美化界面
95
+ custom_css = """
96
+ <style>
97
+ /* 增大标题字体 */
98
+ #title {
99
+ font-size: 80px !important;
100
+ text-align: center;
101
+ margin-bottom: 20px;
102
+ }
103
+
104
+ /* 增大描述文字字体 */
105
+ #description {
106
+ font-size: 24px !important;
107
+ text-align: center;
108
+ margin-bottom: 40px;
109
+ }
110
+
111
+ /* 增大标签和输入框的字体 */
112
+ .gradio-container * {
113
+ font-size: 18px !important;
114
+ }
115
+
116
+ /* 增大按钮字体 */
117
+ button {
118
+ font-size: 20px !important;
119
+ padding: 10px 20px;
120
+ }
121
+
122
+ /* 增大输出文本的字体 */
123
+ .output_text {
124
+ font-size: 20px !important;
125
+ }
126
+ </style>
127
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ with gr.Blocks(css=custom_css) as demo:
130
+ gr.HTML(custom_css)
131
+ gr.HTML("<h1 style='font-size:70px; text-align:center;'>FlashSloth 多模态大模型 Demo</h1>")
132
+
133
+ with gr.Row():
134
+ with gr.Column(scale=1):
135
+ image_input = gr.Image(type="pil", label="上传图片")
136
+
137
+ temperature_slider = gr.Slider(
138
+ minimum=0.01,
139
+ maximum=1.0,
140
+ step=0.05,
141
+ value=0.7,
142
+ label="Temperature"
143
+ )
144
+ topp_slider = gr.Slider(
145
+ minimum=0.01,
146
+ maximum=1.0,
147
+ step=0.05,
148
+ value=0.9,
149
+ label="Top-p"
150
+ )
151
+ maxtoken_slider = gr.Slider(
152
+ minimum=64,
153
+ maximum=3072,
154
+ step=1,
155
+ value=512,
156
+ label="Max Tokens"
157
+ )
158
+
159
+ with gr.Column(scale=1):
160
+ prompt_input = gr.Textbox(
161
+ lines=3,
162
+ placeholder="Describe this photo in detail.",
163
+ label="问题提示"
164
+ )
165
+ submit_button = gr.Button("生成答案", variant="primary")
166
+
167
+ output_text = gr.Textbox(
168
+ label="生成的答案",
169
+ interactive=False,
170
+ lines=15,
171
+ elem_classes=["output_text"]
172
+ )
173
+
174
+ submit_button.click(
175
+ fn=generate_description,
176
+ inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
177
+ outputs=output_text,
178
+ show_progress=True
179
+ )
180
 
181
  if __name__ == "__main__":
182
+ demo.queue().launch()