File size: 4,972 Bytes
e7fe941
21adf57
353582b
baa7d24
 
 
6f43411
e7fe941
 
534e4f4
 
e7fe941
cfee34b
6f43411
1079f52
534e4f4
e7fe941
 
 
21adf57
e7fe941
 
 
 
 
5bbbd7d
 
 
534e4f4
 
 
 
 
 
 
baa7d24
e7fe941
 
 
baa7d24
 
5bbbd7d
baa7d24
e7fe941
 
 
 
6f43411
5bbbd7d
6f43411
 
 
 
 
 
 
 
 
e7fe941
 
 
 
 
 
 
 
 
6f43411
e7fe941
 
 
 
 
 
 
6f43411
e7fe941
6f43411
 
 
 
e7fe941
353582b
a869700
 
 
 
 
353582b
a869700
 
 
 
b4c6c06
6f43411
 
 
 
a869700
 
 
 
 
534e4f4
a869700
baa7d24
a869700
6f43411
a869700
 
b4c6c06
a869700
 
 
6f43411
a869700
 
b4c6c06
baa7d24
b4c6c06
353582b
6f43411
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import torch
import gradio as gr
import tempfile
import secrets
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image

# Upper bound (in pixels) for each image edge when thumbnailing; its square is
# also passed to the processor as the max_pixels budget.
max_size = 240

# Resolve device and dtype once so both models are loaded consistently.
# Half precision is only used on GPU: float16 on CPU is slow and, on some
# PyTorch versions, unsupported for certain ops.
device = "cuda" if torch.cuda.is_available() else "cpu"
_model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load Vision-Language Model
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=_model_dtype, device_map="auto"
)
vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", max_pixels=max_size*max_size)

# Load Text Model
model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=_model_dtype, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Chat history for the math model (system / user / assistant dicts); reset
# whenever a new image is processed.
math_messages = []

def resize_image(image):
    """Return a copy of *image* shrunk to fit within max_size x max_size.

    Args:
        image: A PIL image, or a filesystem path (str / os.PathLike) to one.

    Returns:
        The resized PIL image (aspect ratio preserved), or None if the image
        could not be opened or resized. The caller's image object is left
        untouched.
    """
    try:
        if isinstance(image, (str, os.PathLike)):  # Handle file paths
            # Opening inside the try block so a bad path yields None like
            # every other failure instead of raising.
            image = Image.open(image)
        else:
            # thumbnail() works in place; copy so the caller's image is not
            # silently modified.
            image = image.copy()
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        print(f"Error resizing image: {e}")
        return None

# Instruction given to the vision-language model together with the image.
_IMAGE_PROMPT = (
    "Please describe the math-related content in this image, ensuring that "
    "any LaTeX formulas are correctly transcribed. Non-mathematical details "
    "do not need to be described."
)


def process_image(image, shouldConvert=False):
    """Describe the math content of an image with the vision-language model.

    Resets the global ``math_messages`` history, since a new image starts a
    new conversation.

    Args:
        image: PIL image to describe, or None.
        shouldConvert: When True, flatten the image onto a white background
            first (used for sketchpad composites, which carry transparency).

    Returns:
        A string: the description prefixed with
        "Math-related content detected: ", or a human-readable error message.
        This function does not raise for processing failures.
    """
    global math_messages
    math_messages = []  # Reset when uploading an image

    if image is None:
        return "No image provided."

    try:
        if shouldConvert:
            # paste() needs a mask with an alpha band; converting to RGBA
            # first makes this work for any input mode, not just RGBA.
            rgba = image.convert("RGBA")
            flattened = Image.new('RGB', size=rgba.size, color=(255, 255, 255))
            flattened.paste(rgba, (0, 0), mask=rgba)
            image = flattened

        # resize_image signals failure by returning None, so check its result
        # rather than the processor output.
        resized = resize_image(image)
        if resized is None:
            return "Error processing image."

        # The processor needs a chat-formatted prompt containing an image
        # placeholder; passing the image alone gives it no text to expand.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": _IMAGE_PROMPT},
            ],
        }]
        prompt = vl_processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = vl_processor(
            text=[prompt], images=[resized], padding=True, return_tensors="pt"
        ).to(vl_model.device)  # follow the model's device (device_map="auto")

        generated_ids = vl_model.generate(**inputs, max_new_tokens=512)
        # generate() returns prompt + completion; keep only the new tokens so
        # the prompt is not echoed into the description.
        prompt_length = inputs["input_ids"].shape[1]
        new_token_ids = generated_ids[:, prompt_length:]
        output = vl_processor.batch_decode(new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        description = output[0] if output else ""
        return f"Math-related content detected: {description}"
    except Exception as e:
        return f"Error in processing image: {str(e)}"

def get_math_response(image_description, user_question):
    """Generate the math model's answer to a question, optionally with image context.

    Args:
        image_description: Text describing the image, or a falsy value when
            there is no image.
        user_question: The question typed by the user.

    Yields:
        The model's answer as a single string with every backslash doubled
        (presumably so LaTeX survives Markdown rendering).
    """
    global math_messages
    if not math_messages:
        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
    # Keep only the system message: each question is answered independently.
    math_messages = math_messages[:1]
    content = f'Image description: {image_description}\n\n' if image_description else ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})

    # Instruct models expect the chat template; tokenizing the bare query
    # would drop the system message and the role markers.
    prompt = tokenizer.apply_chat_template(
        math_messages, tokenize=False, add_generation_prompt=True
    )
    # Follow the model's own device (it was loaded with device_map="auto").
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    output = model.generate(**model_inputs, max_new_tokens=512)
    # generate() returns prompt + completion; decode only the completion.
    prompt_length = model_inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # Record the reply before yielding so it is stored even if the consumer
    # never resumes the generator.
    math_messages.append({'role': 'assistant', 'content': answer})
    yield answer.replace("\\", "\\\\")

def math_chat_bot(image, sketchpad, question, state):
    """Answer *question*, using whichever image input belongs to the active tab.

    Tab index 0 reads the uploaded image; tab index 1 reads the sketchpad's
    "composite" entry. When the active tab has no image, the question is
    answered without an image description.

    Yields:
        Whatever get_math_response yields for the question.
    """
    active_tab = state.get("tab_index", 0)
    description = None

    if active_tab == 0:
        if image is not None:
            description = process_image(image)
    elif active_tab == 1:
        if sketchpad and sketchpad.get("composite"):
            # Sketches are flattened onto a white background first.
            description = process_image(sketchpad["composite"], True)

    yield from get_math_response(description, question)

# Extra CSS passed to gr.Blocks. The rules are scoped to the element with id
# "qwen-md" (the answer Markdown component) and force KaTeX display-math
# containers to render inline.
css = """
#qwen-md .katex-display { display: inline; }
#qwen-md .katex-display>.katex { display: inline; }
#qwen-md .katex-display>.katex>.katex-html { display: inline; }
"""

def tabs_select(e: gr.SelectData, _state):
    """Store the index of the newly selected tab under "tab_index" in *_state*.

    The state dict is modified in place; nothing is returned.
    """
    _state.update(tab_index=e.index)

# Build the UI: a header, an input column (image upload / sketch tabs, question
# box, clear and submit buttons) and an output column showing the answer.
with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
    <center><font size=8>📖 Qwen2-Math Demo</font></center>
    <center><font size=3>This WebUI uses Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning.</font></center>
    """)
    # State dict holding the index of the selected input tab; math_chat_bot
    # treats 0 as the Upload tab and 1 as the Sketch tab.
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            # Keep state["tab_index"] in sync with the tab the user selects.
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Enter your question")
            with gr.Row():
                with gr.Column():
                    # Clears both image inputs and the question box.
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            # elem_id ties this component to the "#qwen-md" rules in css.
            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
        # Submit runs math_chat_bot and writes what it yields to the answer panel.
        submit_btn.click(
            fn=math_chat_bot,
            inputs=[input_image, input_sketchpad, input_text, state],
            outputs=output_md)

# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()