# JimmyK300's picture
# Update app.py
# 534e4f4 verified
import os
import torch
import gradio as gr
import tempfile
import secrets
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
# Longest image edge, in pixels.  Used for thumbnailing in resize_image and,
# squared, as the vision processor's max_pixels budget.
max_size = 240

# Choose the device once and derive the dtype from it, so both models follow
# the same rule: half precision only when CUDA is available, float32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
_torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load Vision-Language Model
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=_torch_dtype, device_map="auto"
)
vl_processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", max_pixels=max_size * max_size
)

# Load Text Model
model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=_torch_dtype, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Conversation history for the math model; reset by process_image and
# rebuilt by get_math_response.
math_messages = []
def resize_image(image):
    """Shrink an image so it fits within ``max_size`` x ``max_size`` pixels.

    Args:
        image: A PIL image, a ``str`` file path to open with ``Image.open``,
            or ``None``.

    Returns:
        The image object after ``thumbnail`` has been applied to it (the
        same object that was passed in or opened), or ``None`` when
        ``image`` is ``None`` or when opening/resizing fails.
    """
    if image is None:
        return None
    try:
        # Open inside the try so a missing or unreadable path is reported
        # like any other resize failure instead of propagating to the caller.
        if isinstance(image, str):  # Handle file paths
            image = Image.open(image)
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        # Deliberately broad: this is a best-effort helper whose callers
        # check for None rather than catching exceptions.
        print(f"Error resizing image: {e}")
        return None
def process_image(image, shouldConvert=False):
    """Describe the math content of an image with the vision-language model.

    Every call clears the module-level ``math_messages`` history, because a
    new image starts a new conversation.

    Args:
        image: PIL image to describe, or ``None``.
        shouldConvert: When true, flatten the image onto a white RGB
            background before processing (the sketchpad path passes True).

    Returns:
        One of these strings:
        ``"No image provided."`` when ``image`` is ``None``;
        ``"Error processing image."`` when the image could not be resized;
        ``"Math-related content detected: <description>"`` on success;
        ``"Error in processing image: <message>"`` if anything raises.
    """
    global math_messages
    math_messages = []  # Reset when uploading an image
    if image is None:
        return "No image provided."

    instruction = (
        "Please describe the math-related content in this image, ensuring "
        "that any LaTeX formulas are correctly transcribed. "
        "Non-mathematical details do not need to be described."
    )
    try:
        if shouldConvert:
            # paste() needs a mask with an alpha band; converting first keeps
            # images without one (e.g. plain RGB) from raising.
            rgba = image.convert("RGBA")
            flattened = Image.new('RGB', size=rgba.size, color=(255, 255, 255))
            flattened.paste(rgba, (0, 0), mask=rgba)
            image = flattened

        resized = resize_image(image)
        if resized is None:
            return "Error processing image."

        # The Qwen2-VL processor expects a text prompt containing the image
        # placeholder, which the chat template produces.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        prompt = vl_processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        inputs = vl_processor(
            text=[prompt], images=[resized], padding=True, return_tensors="pt"
        ).to(vl_model.device)

        generated_ids = vl_model.generate(**inputs, max_new_tokens=256)
        # generate() returns prompt + continuation; keep only the new tokens
        # so the description does not echo the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        output = vl_processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        description = output[0] if output else ""
        return f"Math-related content detected: {description}"
    except Exception as e:
        return f"Error in processing image: {str(e)}"
def get_math_response(image_description, user_question):
    """Generate the math model's answer to a question, as a generator.

    Rebuilds the module-level ``math_messages`` as a system message plus one
    user message, renders it through the tokenizer's chat template, and
    generates up to 512 new tokens.

    Args:
        image_description: Text describing the image, or a falsy value to
            omit the "Image description" prefix from the query.
        user_question: The question text typed by the user.

    Yields:
        The answer with every backslash doubled.
    """
    global math_messages
    if not math_messages:
        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
    # Keep only the first (system) message: each request is single-turn.
    math_messages = math_messages[:1]

    content = f'Image description: {image_description}\n\n' if image_description else ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})

    # Feed the whole conversation (including the system prompt) through the
    # chat template; tokenizing the bare query would bypass both.
    prompt = tokenizer.apply_chat_template(
        math_messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    output = model.generate(**model_inputs, max_new_tokens=512)

    # generate() returns prompt + continuation; decode only the continuation
    # so the answer does not repeat the prompt.
    prompt_length = model_inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    yield answer.replace("\\", "\\\\")
    math_messages.append({'role': 'assistant', 'content': answer})
def math_chat_bot(image, sketchpad, question, state):
    """Submit handler: describe the active tab's image, then answer.

    Reads ``state["tab_index"]`` (default 0).  On tab 0 the uploaded
    ``image`` is described when present; on tab 1 the sketchpad's
    ``"composite"`` entry is described (with conversion enabled) when
    present.  The description, or ``None`` if there was no image, is passed
    with ``question`` to ``get_math_response``, whose output is yielded.
    """
    active_tab = state.get("tab_index", 0)
    description = None

    if active_tab == 0:
        if image is not None:
            description = process_image(image)
    elif active_tab == 1:
        composite = sketchpad.get("composite") if sketchpad else None
        if composite:
            description = process_image(composite, True)

    yield from get_math_response(description, question)
# Stylesheet passed to gr.Blocks: within the element with id "qwen-md", set
# each nesting level of the KaTeX display container to `display: inline`.
css = "\n" + "".join(
    f"#qwen-md {selector} {{ display: inline; }}\n"
    for selector in (
        ".katex-display",
        ".katex-display>.katex",
        ".katex-display>.katex>.katex-html",
    )
)
def tabs_select(e: gr.SelectData, _state):
    """Record the selected tab's index in the shared state mapping.

    Stores ``e.index`` under the ``"tab_index"`` key of ``_state`` in place;
    nothing is returned.
    """
    _state.update(tab_index=e.index)
# Build the web UI: inputs (Upload / Sketch tabs plus a question box) in one
# column, the rendered answer in the other.
with gr.Blocks(css=css) as demo:
    gr.HTML("""
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
<center><font size=8>📖 Qwen2-Math Demo</font></center>
<center><font size=3>This WebUI uses Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning.</font></center>
""")
    # Tracks the active input tab (0 = Upload, 1 = Sketch); written by
    # tabs_select and read by math_chat_bot.
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Enter your question")
            with gr.Row():
                with gr.Column():
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
    submit_btn.click(
        fn=math_chat_bot,
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md)

if __name__ == "__main__":
    demo.launch()