import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
DESCRIPTION = """
# Qwen2.5-VL-3B/7B-Instruct
"""
css = '''
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''
# Define an animated progress bar HTML snippet
def progress_bar_html(label: str) -> str:
    return f'''
    <div style="display: flex; align-items: center;">
        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
        <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
        </div>
    </div>
    <style>
    @keyframes loading {{
        0% {{ transform: translateX(-100%); }}
        100% {{ transform: translateX(100%); }}
    }}
    </style>
    '''
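# The snippet above is yielded as an intermediate chat message. Gradio's
# Chatbot renders messages as Markdown with inline HTML, so (assuming HTML
# is not stripped by sanitization in your Gradio version) the bar animates
# until the next yield replaces it.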
# Model IDs for 3B and 7B variants
MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
# Load the processor and models for both versions
processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_3B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_7B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]
    # Route the request to a model based on the prefix tag
    if text.lower().startswith("@3b"):
        yield progress_bar_html("Processing with Qwen2.5-VL-3B-Instruct")
        selected_model = model_3b
        selected_processor = processor_3b
        text = text[len("@3b"):].strip()
    elif text.lower().startswith("@7b"):
        yield progress_bar_html("Processing with Qwen2.5-VL-7B-Instruct")
        selected_model = model_7b
        selected_processor = processor_7b
        text = text[len("@7b"):].strip()
    else:
        yield "Error: Please prefix your query with @3b or @7b to select a model."
        return
    # Load images if provided (the multimodal textbox supplies a list of file paths)
    if files:
        if isinstance(files, list):
            images = [load_image(image) for image in files]
        else:
            images = [load_image(files)]
    else:
        images = []
    # Validate input: a text query is required
    if text == "":
        yield "Error: Please enter a text query along with the image(s), if any."
        return
    # Prepare messages in the chat format the processor expects
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    # Apply the chat template and process the inputs
    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = selected_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    # Set up a streamer for real-time text generation
    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    # Start generation in a separate thread so partial output can be yielded
    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
    thread.start()
    # Show an animated progress message until the first tokens arrive
    yield progress_bar_html("Thinking...")
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()  # make sure the generation thread has finished
# Example inputs with model prefixes
examples = [
    [{"text": "@3b Describe the document?", "files": ["example_images/document.jpg"]}],
    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]
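# The example prompts assume the referenced files exist under an
# example_images/ directory checked into the Space repository.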
demo = gr.ChatInterface(
    fn=model_inference,
    description=DESCRIPTION,
    css=css,
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image"],
        file_count="multiple",
        placeholder="Use Tags @3b / @7b to trigger the models",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)
demo.launch(debug=True)
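# Assumption: this script is the app.py of a Hugging Face Space running on
# ZeroGPU (hence `import spaces` and the @spaces.GPU decorator). The spaces
# package is documented to be a no-op outside Spaces, so `python app.py` on
# a local CUDA machine should also work, provided both models fit in GPU
# memory.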