prithivMLmods's picture
Update app.py
a5d07a8 verified
raw
history blame
5.04 kB
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
# Description text handed to the chat interface (a Markdown heading).
DESCRIPTION = "\n# Qwen2.5-VL-3B/7B-Instruct\n"
# Page-level CSS: centres <h1> and styles the element with id "duplicate-button".
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "display: block;\n"
    "}\n"
    "#duplicate-button {\n"
    "margin: auto;\n"
    "color: #fff;\n"
    "background: #1565c0;\n"
    "border-radius: 100vh;\n"
    "}\n"
)
def progress_bar_html(label: str) -> str:
    """Return an HTML snippet showing *label* next to an animated loading bar.

    The snippet embeds its own ``@keyframes loading`` rule, which the inner
    bar's ``animation`` style refers to.
    """
    # Literal CSS braces are doubled because the template goes through str.format.
    template = '''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
    return template.format(label=label)
# Hugging Face Hub checkpoints for the two selectable model sizes.
MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"


def _load_processor_and_model(model_id):
    """Load the processor and a bfloat16 model for *model_id*, moved to CUDA in eval mode."""
    proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    net = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    return proc, net.to("cuda").eval()


# Both variants are loaded eagerly at import time (3B first, then 7B).
processor_3b, model_3b = _load_processor_and_model(MODEL_ID_3B)
processor_7b, model_7b = _load_processor_and_model(MODEL_ID_7B)
@spaces.GPU
def model_inference(input_dict, history):
    """Stream a reply from the Qwen2.5-VL variant chosen by the query's tag.

    Args:
        input_dict: Multimodal textbox payload. ``input_dict["text"]`` must
            start with ``@3b`` or ``@7b`` (case-insensitive) to pick a model;
            ``input_dict["files"]`` holds zero or more image references that
            are passed to ``load_image``.
        history: Prior chat turns supplied by the chat interface. Not used:
            each query is answered without earlier context.

    Yields:
        HTML progress snippets or error strings, then the accumulated reply
        text each time the streamer produces more output.

    Raises:
        Exception: Whatever ``generate`` raised in the worker thread is
            re-raised here after streaming stops.
    """
    text = input_dict["text"]
    files = input_dict["files"]

    # Route on the tag prefix, then strip the tag from the prompt text.
    if text.lower().startswith("@3b"):
        yield progress_bar_html("processing with Qwen2.5-VL-3B-Instruct")
        selected_model = model_3b
        selected_processor = processor_3b
        text = text[len("@3b"):].strip()
    elif text.lower().startswith("@7b"):
        yield progress_bar_html("processing with Qwen2.5-VL-7B-Instruct")
        selected_model = model_7b
        selected_processor = processor_7b
        text = text[len("@7b"):].strip()
    else:
        yield "Error: Please prefix your query with @3b or @7b to select the model."
        return

    # Validate before decoding any images so an empty query fails fast.
    if text == "":
        yield "Error: Please input a text query along with the image(s) if any."
        return

    # ``files`` may be empty/absent, a list of references, or a single reference.
    if not files:
        images = []
    elif isinstance(files, list):
        images = [load_image(image) for image in files]
    else:
        images = [load_image(files)]

    # Single user turn: every image first, then the text query.
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]

    prompt = selected_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = selected_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    streamer = TextIteratorStreamer(
        selected_processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # generate() runs in a worker thread and feeds the streamer. If it raises,
    # the streamer never receives its end-of-stream signal and the loop below
    # would wait forever, so record the error and close the stream explicitly.
    errors = []

    def _generate():
        try:
            selected_model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    yield progress_bar_html("Thinking...")
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # short pause between successive UI updates
        yield buffer

    thread.join()
    if errors:
        raise errors[0]
# Clickable examples for the chat UI: each is a one-element list holding a
# multimodal payload whose text starts with a model tag and names one image.
examples = [
    [{"text": query, "files": [image_path]}]
    for query, image_path in (
        ("@3b Describe the document?", "example_images/document.jpg"),
        ("@7b What does this say?", "example_images/math.jpg"),
        ("@3b What is this UI about?", "example_images/s2w_example.png"),
        ("@7b Where do the severe droughts happen according to this diagram?", "example_images/examples_weather_events.png"),
    )
]
# Input box: text plus any number of images; the placeholder reminds users
# of the @3b / @7b routing tags.
query_textbox = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image"],
    file_count="multiple",
    placeholder="Use Tags @3b / @7b to trigger the models",
)
chat_options = dict(
    description=DESCRIPTION,
    css=css,
    examples=examples,
    textbox=query_textbox,
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)
demo = gr.ChatInterface(fn=model_inference, **chat_options)
demo.launch(debug=True)