JimmyK300 committed
Commit b4c6c06 · verified · 1 Parent(s): 1079f52

Update app.py

Files changed (1): app.py (+35, -73)
app.py CHANGED
@@ -1,10 +1,7 @@
 import os
 import torch
 import gradio as gr
-import tempfile
-import secrets
-from pathlib import Path
-from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor, Qwen2VLForConditionalGeneration
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
 from qwen_vl_utils import process_vision_info
 from PIL import Image
 
@@ -12,9 +9,6 @@ from PIL import Image
 vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
 )
-Qwen2VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
-)
 vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
 # Load Text Model
@@ -25,27 +19,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 math_messages = []
 
-def process_image(image, shouldConvert=False):
+def process_image(image, should_convert=False):
     global math_messages
     math_messages = [] # Reset when uploading an image
 
-    if shouldConvert:
+    if should_convert:
         new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
         new_img.paste(image, (0, 0), mask=image)
         image = new_img
 
-    # Convert the image to tensor
-    inputs = vl_processor(images=image, return_tensors="pt")
+    inputs = vl_processor(images=image, return_tensors="pt").to(device)
     generated_ids = vl_model.generate(**inputs)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-    description = vl_processor.batch_decode(output, skip_special_tokens=True)[0]
-
-    return f"Math-related content detected: {description}"
+    output = vl_processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return f"Math-related content detected: {output[0]}"
 
 def get_math_response(image_description, user_question):
     global math_messages
@@ -55,21 +41,23 @@ def get_math_response(image_description, user_question):
     content = f'Image description: {image_description}\n\n' if image_description else ''
     query = f"{content}User question: {user_question}"
     math_messages.append({'role': 'user', 'content': query})
+
     model_inputs = tokenizer(query, return_tensors="pt").to(device)
     output = model.generate(**model_inputs, max_new_tokens=512)
     answer = tokenizer.decode(output[0], skip_special_tokens=True)
+
     yield answer.replace("\\", "\\\\")
     math_messages.append({'role': 'assistant', 'content': answer})
 
 def math_chat_bot(image, sketchpad, question, state):
     current_tab_index = state["tab_index"]
     image_description = None
-    if current_tab_index == 0:
-        if image is not None:
-            image_description = process_image(image)
-    elif current_tab_index == 1:
-        if sketchpad and sketchpad["composite"]:
-            image_description = process_image(sketchpad["composite"], True)
+
+    if current_tab_index == 0 and image is not None:
+        image_description = process_image(image)
+    elif current_tab_index == 1 and sketchpad and sketchpad["composite"]:
+        image_description = process_image(sketchpad["composite"], True)
+
     yield from get_math_response(image_description, question)
 
 css = """
@@ -81,67 +69,41 @@ css = """
 def tabs_select(e: gr.SelectData, _state):
     _state["tab_index"] = e.index
 
-
-# 创建Gradio接口
+# Create Gradio UI
 with gr.Blocks(css=css) as demo:
-    gr.HTML("""\
-        <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/><p>"""
-        """<center><font size=8>📖 Qwen2-Math Demo</center>"""
-        """\
-        <center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</center>"""
-    )
+    gr.HTML("""
+    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
+    <center><font size=8>📖 Qwen2-Math Demo</font></center>
+    <center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</font></center>
+    """)
+
     state = gr.State({"tab_index": 0})
+
     with gr.Row():
         with gr.Column():
             with gr.Tabs() as input_tabs:
                 with gr.Tab("Upload"):
-                    input_image = gr.Image(type="pil", label="Upload"),
+                    input_image = gr.Image(type="pil", label="Upload")
                 with gr.Tab("Sketch"):
-                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
+                    input_sketchpad = gr.Sketchpad(label="Sketch", layers=False)
+
             input_tabs.select(fn=tabs_select, inputs=[state])
-            input_text = gr.Textbox(label="input your question")
+            input_text = gr.Textbox(label="Input your question")
+
             with gr.Row():
                 with gr.Column():
-                    clear_btn = gr.ClearButton(
-                        [*input_image, input_sketchpad, input_text])
+                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                 with gr.Column():
                     submit_btn = gr.Button("Submit", variant="primary")
+
         with gr.Column():
-            output_md = gr.Markdown(label="answer",
-                                    latex_delimiters=[{
-                                        "left": "\\(",
-                                        "right": "\\)",
-                                        "display": True
-                                    }, {
-                                        "left": "\\begin\{equation\}",
-                                        "right": "\\end\{equation\}",
-                                        "display": True
-                                    }, {
-                                        "left": "\\begin\{align\}",
-                                        "right": "\\end\{align\}",
-                                        "display": True
-                                    }, {
-                                        "left": "\\begin\{alignat\}",
-                                        "right": "\\end\{alignat\}",
-                                        "display": True
-                                    }, {
-                                        "left": "\\begin\{gather\}",
-                                        "right": "\\end\{gather\}",
-                                        "display": True
-                                    }, {
-                                        "left": "\\begin\{CD\}",
-                                        "right": "\\end\{CD\}",
-                                        "display": True
-                                    }, {
-                                        "left": "\\[",
-                                        "right": "\\]",
-                                        "display": True
-                                    }],
-                                    elem_id="qwen-md")
+            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
+
             submit_btn.click(
                 fn=math_chat_bot,
-                inputs=[*input_image, input_sketchpad, input_text, state],
-                outputs=output_md)
-
+                inputs=[input_image, input_sketchpad, input_text, state],
+                outputs=output_md
+            )
+
 if __name__ == "__main__":
     demo.launch()
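
Note: the unchanged context around "# Load Text Model" (elided by the diff) defines model_name, model, and device, which the updated process_image and get_math_response both rely on via .to(device). A minimal sketch of what that block plausibly contains, useful only for reading the hunks in context; the checkpoint name and device logic below are assumptions, not taken from this commit:

# Hypothetical reconstruction of the unchanged "# Load Text Model" block (not part of this diff)
model_name = "Qwen/Qwen2-Math-1.5B-Instruct"  # assumed checkpoint; the actual name is elided by the diff context
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; `device` is used in the diff but never shown
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)  # this line appears verbatim in the hunk context above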