Tongbo committed on
Commit b973fba · verified · 1 Parent(s): 0f8fc82

Update app.py

Files changed (1)
  1. app.py +31 -17
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 import torch
 from flashsloth.constants import (
     IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
@@ -14,28 +15,39 @@ from flashsloth.mm_utils import (
 from PIL import Image
 import gradio as gr
 
-
 from transformers import TextIteratorStreamer
 from threading import Thread
 
-
 disable_torch_init()
 
-MODEL_PATH = "Tongbo/FlashSloth_HD-3.2B"
+MODEL_PATH_HD = "Tongbo/FlashSloth_HD-3.2B"
+MODEL_PATH_NEW = "Tongbo/FlashSloth-3.2B"
 
-model_name = get_model_name_from_path(MODEL_PATH)
-tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None, model_name)
-model.to('cuda')
-model.eval()
+model_name_hd = get_model_name_from_path(MODEL_PATH_HD)
+model_name_new = get_model_name_from_path(MODEL_PATH_NEW)
 
-def generate_description(image, prompt_text, temperature, top_p, max_tokens):
+models = {
+    "FlashSloth HD": load_pretrained_model(MODEL_PATH_HD, None, model_name_hd),
+    "FlashSloth": load_pretrained_model(MODEL_PATH_NEW, None, model_name_new)
+}
 
+for key in models:
+    tokenizer, model, image_processor, context_len = models[key]
+    model.to('cuda')
+    model.eval()
+
+def generate_description(image, prompt_text, temperature, top_p, max_tokens, selected_model):
+    """
+    Generate an image description with streaming output, using the selected model.
+    New parameter:
+    - selected_model: the model name chosen by the user
+    """
     keywords = ['</s>']
 
+    tokenizer, model, image_processor, context_len = models[selected_model]
 
     text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
     text = text + LEARNABLE_TOKEN
-
 
     image = image.convert('RGB')
     if model.config.image_hd:
@@ -43,14 +55,12 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
     else:
         image_tensor = process_images([image], image_processor, model.config)[0]
     image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)
-
 
     conv = conv_templates["phi2"].copy()
     conv.append_message(conv.roles[0], text)
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
 
-
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
     input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
 
@@ -79,11 +89,9 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
         with torch.inference_mode():
             model.generate(**generation_kwargs)
 
-    # Run generation in a separate thread so it does not block
     generation_thread = Thread(target=_generate)
     generation_thread.start()
 
-    # Yield the output as it is being generated
     partial_text = ""
     for new_text in streamer:
         partial_text += new_text
@@ -91,7 +99,6 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
 
     generation_thread.join()
 
-# Custom CSS to enlarge fonts and polish the interface
 custom_css = """
 <style>
 /* Enlarge the title font */
@@ -152,10 +159,17 @@ with gr.Blocks(css=custom_css) as demo:
                 minimum=64,
                 maximum=3072,
                 step=1,
-                value=512,
+                value=3072,
                 label="Max Tokens"
             )
 
+            model_dropdown = gr.Dropdown(
+                choices=list(models.keys()),
+                value=list(models.keys())[0],
+                label="Select Model",
+                type="value"
+            )
+
         with gr.Column(scale=1):
             prompt_input = gr.Textbox(
                 lines=3,
@@ -173,10 +187,10 @@ with gr.Blocks(css=custom_css) as demo:
 
     submit_button.click(
         fn=generate_description,
-        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
+        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider, model_dropdown],
        outputs=output_text,
        show_progress=True
    )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch(server_name="0.0.0.0", server_port=8888)
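
For context, the commit threads the dropdown choice into generate_description as the new selected_model argument. A minimal sketch of exercising that signature outside Gradio, assuming the updated app.py is importable as a module named app, that generate_description yields the accumulating partial text (as the streaming click handler implies), and that a local example.jpg exists; these names are illustrative and not part of the commit:

from PIL import Image
import app  # hypothetical module name for the updated app.py; importing it loads both checkpoints onto the GPU

img = Image.open("example.jpg")  # hypothetical local test image
result = ""
# "FlashSloth HD" is one of the keys registered in app.models and mirrors what the
# new model_dropdown passes to generate_description as selected_model.
for result in app.generate_description(
        img,
        "Describe this image.",
        temperature=0.2,
        top_p=0.9,
        max_tokens=512,
        selected_model="FlashSloth HD",
):
    pass  # each iteration yields the text accumulated so far
print(result)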