import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces

DESCRIPTION = """
# Qwen2.5-VL-3B/7B-Instruct
Prefix your query with `@3b` or `@7b` to choose which model responds.
"""

css = '''
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''

# Define an animated progress bar HTML snippet
def progress_bar_html(label: str) -> str:
    return f'''
    <div style="display: flex; align-items: center;">
        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
        <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
        </div>
    </div>
    <style>
    @keyframes loading {{
        0% {{ transform: translateX(-100%); }}
        100% {{ transform: translateX(100%); }}
    }}
    </style>
    '''

# Model IDs for 3B and 7B variants
MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the processor and models for both versions
processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_3B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_7B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
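# Note: keeping both models resident in bf16 requires roughly 24 GB of GPU memory in total.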

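# On Hugging Face Spaces (ZeroGPU), the @spaces.GPU decorator requests a GPU for the duration of each call.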
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    # Determine which model to use based on the prefix tag
    if text.lower().startswith("@3b"):
        yield progress_bar_html("processing with Qwen2.5-VL-3B-Instruct")
        selected_model = model_3b
        selected_processor = processor_3b
        text = text[len("@3b"):].strip()
    elif text.lower().startswith("@7b"):
        yield progress_bar_html("processing with Qwen2.5-VL-7B-Instruct")
        selected_model = model_7b
        selected_processor = processor_7b
        text = text[len("@7b"):].strip()
    else:
        yield "Error: Please prefix your query with @3b or @7b to select the model."
        return

    # Load images if provided (files may be a single path or a list of paths)
    if not files:
        images = []
    elif isinstance(files, list):
        images = [load_image(image) for image in files]
    else:
        images = [load_image(files)]

    # Validate input: a text query is required
    if text == "":
        yield "Error: Please enter a text query along with any image(s)."
        return

    # Prepare messages for the model
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ]
    }]

    # Apply the chat template and process the inputs
    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = selected_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Set up a streamer for real-time text generation
    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
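    # `inputs` is a dict-like BatchFeature, so its tensors can be merged directly into the generation kwargs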
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Start generation in a separate thread
    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield an animated progress message
    yield progress_bar_html("Thinking...")

    # Stream the response; each yield replaces the progress bar with the text generated so far
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

# Example inputs with model prefixes
examples = [
    [{"text": "@3b Describe the document?", "files": ["example_images/document.jpg"]}],
    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]

demo = gr.ChatInterface(
    fn=model_inference,
    description=DESCRIPTION,
    css=css,
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="Use tags @3b / @7b to select a model"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)