Multimodal-OCR / app.py
prithivMLmods's picture
Update app.py
5d63d59 verified
raw
history blame
4.68 kB
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
# Selectable checkpoints: dropdown label -> Hugging Face Hub repository id.
# Insertion order is the order in which the labels are listed in the UI.
MODEL_OPTIONS = {
    "Qwen2VL Base": "Qwen/Qwen2-VL-2B-Instruct",
    "Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
    "Math Prase": "prithivMLmods/Qwen2-VL-Math-Prase-2B-Instruct",
    "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct",
}

# Currently loaded model and its processor. Both stay None until
# load_model() has been called.
model = None
processor = None
# Function to load the selected model
def load_model(model_name):
    """Load the checkpoint registered under ``model_name`` into the globals.

    Args:
        model_name: A key of ``MODEL_OPTIONS``.

    Returns:
        A short status string naming the model that was loaded.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    global model, processor

    model_id = MODEL_OPTIONS[model_name]
    print(f"Loading model: {model_id}")

    # NOTE(review): trust_remote_code=True allows code shipped with the Hub
    # repository to run locally; the ids come only from MODEL_OPTIONS.
    new_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    new_model = (
        Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        .to("cuda")
        .eval()
    )

    # Publish both objects only after both loads succeeded, so a failure while
    # loading the model never leaves a new processor paired with an old model.
    model, processor = new_model, new_processor

    print(f"Model {model_id} loaded successfully!")
    return f"Model {model_name} loaded!"
@spaces.GPU
def model_inference(input_dict, history, model_choice):
    """Stream the model's answer for one multimodal chat message.

    Args:
        input_dict: Message payload; its ``"text"`` entry is the query and its
            ``"files"`` entry is a list of image references.
        history: Previous chat turns as passed by the chat interface. Not used.
        model_choice: Key of ``MODEL_OPTIONS`` selecting the model to answer with.

    Yields:
        First the placeholder ``"Thinking..."``, then the accumulated response
        text after every streamed chunk.

    Raises:
        gr.Error: If the text query is empty.
    """
    global model, processor

    text = input_dict["text"]
    files = input_dict["files"]

    # Validate before any expensive work. The error has to be raised: merely
    # constructing gr.Error does not surface anything to the user.
    if text == "":
        if files:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # Load on first use, and reload when the selected checkpoint differs from
    # the one in memory so the dropdown choice is actually honoured.
    # NOTE(review): relies on transformers recording the id passed to
    # from_pretrained() as `name_or_path`; when that attribute is missing or
    # empty, the already loaded model is kept, as before.
    wanted_id = MODEL_OPTIONS.get(model_choice)
    loaded_id = getattr(model, "name_or_path", None)
    selection_changed = bool(wanted_id and loaded_id and loaded_id != wanted_id)
    if model is None or processor is None or selection_changed:
        load_model(model_choice)

    images = [load_image(image) for image in files]

    # Single user turn: every image first, then the text query.
    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # generate() runs in a worker thread so this generator can forward chunks
    # from the streamer while generation is still in progress.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield "Thinking..."
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
# Example inputs: one row per example, each row holding a single multimodal
# message made of a text prompt and one image path.
examples = [
    [{"text": prompt, "files": [image_path]}]
    for prompt, image_path in (
        ("Describe the document?", "example_images/document.jpg"),
        ("Describe this image.", "example_images/campeones.jpg"),
        ("What does this say?", "example_images/math.jpg"),
        ("What is this UI about?", "example_images/s2w_example.png"),
        ("Can you describe this image?", "example_images/newyork.jpg"),
        ("Can you describe this image?", "example_images/dogs.jpg"),
        (
            "Where do the severe droughts happen according to this diagram?",
            "example_images/examples_weather_events.png",
        ),
    )
]
# Gradio interface: model picker, explicit "load" button, and a streaming
# multimodal chat that receives the picked model as an extra input.
with gr.Blocks() as demo:
    # Heading matches the model family offered in MODEL_OPTIONS (all entries
    # are Qwen2-VL 2B checkpoints); it previously named an unrelated model.
    gr.Markdown("# **Qwen2-VL-2B-Instruct**")

    # Model selection dropdown
    model_choice = gr.Dropdown(
        label="Model Selection",
        choices=list(MODEL_OPTIONS.keys()),
        value="Latex OCR",
    )

    # Load model button and the textbox that shows load_model()'s status string
    load_model_btn = gr.Button("Load Model")
    load_model_output = gr.Textbox(label="Model Load Status")

    # Chat interface
    chat_interface = gr.ChatInterface(
        fn=model_inference,
        description="Interact with the selected Qwen2-VL model.",
        examples=examples,
        textbox=gr.MultimodalTextbox(
            label="Query Input", file_types=["image"], file_count="multiple"
        ),
        stop_btn="Stop Generation",
        multimodal=True,
        cache_examples=False,
        additional_inputs=[model_choice],  # forwarded to fn as model_choice
    )

    # Link the load model button to the load_model function
    load_model_btn.click(load_model, inputs=model_choice, outputs=load_model_output)

# Launch the demo
demo.launch(debug=True)