JimmyK300 committed
Commit 6f43411 · verified · 1 Parent(s): baa7d24

Update app.py

Files changed (1)
app.py +26 -30
app.py CHANGED
@@ -4,13 +4,12 @@ import gradio as gr
 import tempfile
 import secrets
 from pathlib import Path
-from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor, Qwen2VLForConditionalGeneration
-from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
 from PIL import Image
 
 # Load Vision-Language Model
 vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
+    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto"
 )
 vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
@@ -34,18 +33,17 @@ def process_image(image, shouldConvert=False):
         new_img.paste(image, (0, 0), mask=image)
         image = new_img
 
-    # Convert the image to tensor
-    inputs = vl_processor(images=image, return_tensors="pt").to(device)
-    generated_ids = vl_model.generate(**inputs)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output = vl_processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-    description = output[0] if output else ""
-
-    return f"Math-related content detected: {description}"
+    try:
+        inputs = vl_processor(images=image, return_tensors="pt").to(device)
+        if inputs is None:
+            return "Error processing image."
+
+        generated_ids = vl_model.generate(**inputs)
+        output = vl_processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        description = output[0] if output else ""
+        return f"Math-related content detected: {description}"
+    except Exception as e:
+        return f"Error in processing image: {str(e)}"
 
 def get_math_response(image_description, user_question):
     global math_messages
@@ -55,6 +53,7 @@ def get_math_response(image_description, user_question):
     content = f'Image description: {image_description}\n\n' if image_description else ''
     query = f"{content}User question: {user_question}"
     math_messages.append({'role': 'user', 'content': query})
+
     model_inputs = tokenizer(query, return_tensors="pt").to(device)
     output = model.generate(**model_inputs, max_new_tokens=512)
     answer = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -62,14 +61,12 @@ def get_math_response(image_description, user_question):
     math_messages.append({'role': 'assistant', 'content': answer})
 
 def math_chat_bot(image, sketchpad, question, state):
-    current_tab_index = state["tab_index"]
+    current_tab_index = state.get("tab_index", 0)
     image_description = None
-    if current_tab_index == 0:
-        if image is not None:
-            image_description = process_image(image)
-    elif current_tab_index == 1:
-        if sketchpad and sketchpad.get("composite"):
-            image_description = process_image(sketchpad["composite"], True)
+    if current_tab_index == 0 and image is not None:
+        image_description = process_image(image)
+    elif current_tab_index == 1 and sketchpad and sketchpad.get("composite"):
+        image_description = process_image(sketchpad["composite"], True)
     yield from get_math_response(image_description, question)
 
 css = """
@@ -81,13 +78,12 @@ css = """
 def tabs_select(e: gr.SelectData, _state):
     _state["tab_index"] = e.index
 
-# Create the Gradio interface
 with gr.Blocks(css=css) as demo:
     gr.HTML("""
-    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/><p>"""
-    """<center><font size=8>📖 Qwen2-Math Demo</center>"""
-    """
-    <center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</center>""")
+    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
+    <center><font size=8>📖 Qwen2-Math Demo</font></center>
+    <center><font size=3>This WebUI uses Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning.</font></center>
+    """)
     state = gr.State({"tab_index": 0})
     with gr.Row():
         with gr.Column():
@@ -97,18 +93,18 @@ with gr.Blocks(css=css) as demo:
             with gr.Tab("Sketch"):
                 input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
             input_tabs.select(fn=tabs_select, inputs=[state])
-            input_text = gr.Textbox(label="input your question")
+            input_text = gr.Textbox(label="Enter your question")
            with gr.Row():
                 with gr.Column():
                     clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                 with gr.Column():
                     submit_btn = gr.Button("Submit", variant="primary")
         with gr.Column():
-            output_md = gr.Markdown(label="answer", elem_id="qwen-md")
+            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
     submit_btn.click(
         fn=math_chat_bot,
         inputs=[input_image, input_sketchpad, input_text, state],
         outputs=output_md)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
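
Note: the added `torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32` argument references `torch`, which none of the hunks above import; unless `import torch` already appears in the unchanged top of app.py, the model-loading line would raise a NameError. Likewise, the `device` used by `process_image` and `get_math_response` is assumed to be defined elsewhere in the file. A minimal sketch of the setup these hunks appear to rely on (both names are assumptions, not shown in this commit):

import torch

# Assumed to exist elsewhere in app.py: the device that processor and
# tokenizer outputs are moved to via .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirrors the dtype choice added in this commit: half precision on GPU,
# full precision on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32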