# Hugging Face Space: Qwen2.5-VL-3B/7B-Instruct demo (running on ZeroGPU)
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
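# `spaces` supplies the @spaces.GPU decorator that ZeroGPU Spaces use to request
# a GPU for the duration of a call; it is applied to model_inference below.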
DESCRIPTION = """
# Qwen2.5-VL-3B/7B-Instruct
"""
css = '''
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''
# Define an animated progress bar HTML snippet
def progress_bar_html(label: str) -> str:
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
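# The snippet is yielded as an intermediate chat message; Gradio renders the HTML,
# so the bar keeps animating in the chat window while generation is in flight.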
# Model IDs for the 3B and 7B variants
MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the processor and model for both versions
processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_3B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_7B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
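# Both checkpoints stay resident on the GPU at once, in bfloat16; .eval() turns
# off dropout and other train-time behavior for inference.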
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    # Route the query to the 3B or 7B model based on the prefix tag
    if text.lower().startswith("@3b"):
        yield progress_bar_html("Processing with Qwen2.5-VL-3B-Instruct")
        selected_model = model_3b
        selected_processor = processor_3b
        text = text[len("@3b"):].strip()
    elif text.lower().startswith("@7b"):
        yield progress_bar_html("Processing with Qwen2.5-VL-7B-Instruct")
        selected_model = model_7b
        selected_processor = processor_7b
        text = text[len("@7b"):].strip()
    else:
        yield "Error: Please prefix your query with @3b or @7b to select a model."
        return
    # Normalize the uploaded files into a list of images
    if isinstance(files, list):
        images = [load_image(f) for f in files]
    elif files:
        images = [load_image(files)]
    else:
        images = []
    # Validate input: a text query is required
    if text == "":
        yield "Error: Please enter a text query along with the image(s), if any."
        return
    # Prepare messages for the model
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ]
    }]
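    # Qwen2.5-VL chat "content" is a list of typed parts; here the image
    # entries come first, followed by the text query.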
    # Apply the chat template and process the inputs
    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = selected_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    # Set up a streamer for real-time text generation
    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Start generation in a separate thread
    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
    thread.start()
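    # generate() blocks until completion, so it runs in the worker thread while
    # the loop below drains partial text from the streamer as tokens arrive.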
    # Yield an animated progress message, then stream the growing response
    yield progress_bar_html("Thinking...")
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
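    # The streamer is exhausted only once generate() has finished, so this join
    # returns promptly and just cleans up the worker thread.
    thread.join()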
# Example inputs with model prefixes
examples = [
    [{"text": "@3b Describe the document?", "files": ["example_images/document.jpg"]}],
    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]
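# Note: the example paths assume the matching files are checked into an
# example_images/ directory in the Space repository.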
demo = gr.ChatInterface(
    fn=model_inference,
    description=DESCRIPTION,
    css=css,
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image"],
        file_count="multiple",
        placeholder="Use the @3b / @7b tags to select a model",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)