import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
# Cap the pixel budget fed to the vision model to keep OCR latency low
max_size = 240

# Load the vision-language model used to extract math content from images.
# device_map="auto" lets accelerate place the weights on GPU when available.
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", max_pixels=max_size * max_size)

# Load the text model used for mathematical reasoning
model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
math_messages = []  # Module-level chat history shared across handlers
def resize_image(image):
    if isinstance(image, str):  # Handle file paths
        image = Image.open(image)
    try:
        # Downscale in place, preserving aspect ratio
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        print(f"Error resizing image: {e}")
        return None
def process_image(image, shouldConvert=False):
    global math_messages
    math_messages = []  # Reset the chat history whenever a new image comes in
    if image is None:
        return "No image provided."
    if shouldConvert:
        # Flatten the RGBA sketchpad image onto a white background
        new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
        new_img.paste(image, (0, 0), mask=image)
        image = new_img
    image = resize_image(image)
    if image is None:
        return "Error processing image."
    try:
        # Qwen2-VL's processor requires a text prompt containing an image
        # placeholder; the prompt wording below is a reasonable default.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the math-related content in this image."},
            ],
        }]
        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = vl_processor(text=[text], images=[image], return_tensors="pt").to(vl_model.device)
        generated_ids = vl_model.generate(**inputs, max_new_tokens=256)
        # Decode only the newly generated tokens, not the echoed prompt
        generated_ids = generated_ids[:, inputs.input_ids.shape[1]:]
        output = vl_processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        description = output[0] if output else ""
        return f"Math-related content detected: {description}"
    except Exception as e:
        return f"Error in processing image: {str(e)}"
def get_math_response(image_description, user_question):
    global math_messages
    if not math_messages:
        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
    math_messages = math_messages[:1]  # Keep only the system prompt (single-turn chat)
    content = f'Image description: {image_description}\n\n' if image_description else ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})
    # Format the conversation with the model's chat template before tokenizing
    text = tokenizer.apply_chat_template(math_messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    output = model.generate(**model_inputs, max_new_tokens=512)
    # Decode only the newly generated tokens, not the echoed prompt
    answer = tokenizer.decode(output[0][model_inputs.input_ids.shape[1]:], skip_special_tokens=True)
    yield answer.replace("\\", "\\\\")  # Escape backslashes so LaTeX survives Markdown rendering
    math_messages.append({'role': 'assistant', 'content': answer})
def math_chat_bot(image, sketchpad, question, state):
    current_tab_index = state.get("tab_index", 0)
    image_description = None
    if current_tab_index == 0 and image is not None:  # "Upload" tab
        image_description = process_image(image)
    elif current_tab_index == 1 and sketchpad and sketchpad.get("composite"):  # "Sketch" tab
        # Sketchpad composites are RGBA, so flatten them before OCR
        image_description = process_image(sketchpad["composite"], True)
    yield from get_math_response(image_description, question)
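# KaTeX renders display-mode math as block elements by default; this CSS
# forces it inline so formulas flow with the surrounding answer text.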
css = """
#qwen-md .katex-display { display: inline; }
#qwen-md .katex-display>.katex { display: inline; }
#qwen-md .katex-display>.katex>.katex-html { display: inline; }
"""
def tabs_select(e: gr.SelectData, _state):
    # Mutate the session state in place to record which input tab is active
    _state["tab_index"] = e.index
with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
    <center><font size=8>📖 Qwen2-Math Demo</font></center>
    <center><font size=3>This WebUI uses Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning.</font></center>
    """)
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Enter your question")
            with gr.Row():
                with gr.Column():
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
    submit_btn.click(
        fn=math_chat_bot,
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md,
    )
if __name__ == "__main__":
    demo.launch()
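# Minimal smoke test (assumes both Qwen checkpoints can be downloaded and,
# ideally, a CUDA GPU for acceptable latency):
#   python app.py
# Then open the printed URL (http://127.0.0.1:7860 by default), upload or
# sketch an equation, type a question, and press Submit.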