File size: 6,658 Bytes
b8a0d2d
f5ecaf8
8716c2f
f5ecaf8
8081540
f5ecaf8
 
 
9dc7658
 
 
 
 
d9dde0d
9dc7658
 
 
 
 
 
8716c2f
9dc7658
f5ecaf8
86a82e4
 
 
 
 
8081540
9dc7658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dec51b2
9dc7658
 
86a82e4
9dc7658
 
 
86a82e4
 
 
 
9dc7658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86a82e4
9dc7658
 
 
 
d9dde0d
9dc7658
 
 
 
 
 
 
 
 
 
 
 
 
d9dde0d
9dc7658
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ecaf8
0b5bfb4
 
 
 
9dc7658
 
 
 
 
 
 
 
8716c2f
9dc7658
d9dde0d
f5ecaf8
8716c2f
f5ecaf8
8716c2f
f5ecaf8
 
9dc7658
 
 
 
 
 
 
 
 
86a82e4
 
9dc7658
 
 
 
 
 
 
 
 
86a82e4
 
 
 
 
b8a0d2d
9dc7658
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import re
import subprocess
import sys
import tempfile
import time
import uuid
from io import BytesIO
from threading import Thread

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

# Install flash-attn at startup. This is best effort: the return code is not
# checked (as before), so a failed install would presumably only surface when
# the model below requests the flash_attention_2 implementation.
# The inherited environment is *extended* rather than replaced. Passing a bare
# dict as ``env`` would drop PATH, HOME, the virtualenv variables, etc.
# FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's setup to skip compiling its
# CUDA extension. ``sys.executable -m pip`` installs into the interpreter that
# runs this app, and the argument list avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Load the processor and model once at import time; every request reuses them.
MODEL_ID = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda:0")

def downsample_video(video_path, num_frames=10):
    """
    Sample up to ``num_frames`` evenly spaced frames from a video file.

    Args:
        video_path: Path of a video file readable by OpenCV.
        num_frames: Number of frames to sample (default 10). For videos with
            fewer frames than this, each frame is returned once instead of
            being repeated.

    Returns:
        A list of ``(PIL.Image.Image, timestamp)`` tuples in playback order.
        Each image is RGB. ``timestamp`` is the frame position in seconds,
        rounded to two decimals. Frames that fail to decode are skipped.
        An empty list is returned when the file cannot be opened, reports
        no frames or no FPS, or ``num_frames`` is not positive.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not vidcap.isOpened():
            return frames
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0 or fps <= 0 or num_frames <= 0:
            return frames
        # linspace repeats indices when total_frames < num_frames; np.unique
        # drops the duplicates and keeps the indices sorted.
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )
        for raw_index in frame_indices:
            index = int(raw_index)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                continue
            # OpenCV decodes frames as BGR; PIL expects RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(rgb), round(index / fps, 2)))
    finally:
        # Release the capture handle even if decoding raises.
        vidcap.release()
    return frames

_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".flv")


def _has_extension(path, extensions):
    """Return True if ``path`` ends with one of ``extensions``, ignoring case."""
    return path.lower().endswith(extensions)


def _build_user_content(text, media):
    """Build the content list of one user turn from its text and media items.

    Each "<image>" placeholder in ``text`` is replaced by the next media item.
    Media left over after the placeholders are appended after the text, so no
    upload is silently dropped. The same applies to all media when the text
    has no placeholder. Empty text fragments are skipped. ``media`` is not
    mutated.
    """
    text = text.strip()
    pending = list(media)
    content = []
    if "<image>" in text:
        for part in re.split(r"(<image>)", text):
            if part == "<image>":
                if pending:
                    content.append(pending.pop(0))
            elif part.strip():
                content.append({"type": "text", "text": part.strip()})
    elif text:
        content.append({"type": "text", "text": text})
    content.extend(pending)
    return content


def _history_to_messages(history):
    """Convert Gradio "messages"-format history into chat-template messages.

    User entries whose content is a tuple are treated as uploaded files. Only
    image files are re-attached; other files, such as videos, are ignored.
    User entries whose content is a string are the turn's text. Each
    assistant entry closes the user turn that precedes it. A trailing user
    turn without an assistant reply is not emitted.
    """
    messages = []
    turn_texts, turn_media = [], []
    for entry in history:
        role, content = entry["role"], entry["content"]
        if role == "user":
            if isinstance(content, tuple):
                if content and _has_extension(content[0], _IMAGE_EXTENSIONS):
                    turn_media.append({"type": "image", "path": content[0]})
            elif isinstance(content, str):
                turn_texts.append(content)
        elif role == "assistant":
            messages.append({
                "role": "user",
                "content": _build_user_content(" ".join(turn_texts), turn_media),
            })
            messages.append({
                "role": "assistant",
                "content": [{"type": "text", "text": content}],
            })
            turn_texts, turn_media = [], []
    return messages


def _media_from_files(files, frame_dir):
    """Turn uploaded file paths into image content items, in upload order.

    Images are referenced in place. Videos are sampled with
    downsample_video, and each frame is written as a PNG into ``frame_dir``.
    Files with any other extension are ignored.
    """
    media = []
    for path in files:
        if _has_extension(path, _IMAGE_EXTENSIONS):
            media.append({"type": "image", "path": path})
        elif _has_extension(path, _VIDEO_EXTENSIONS):
            for frame, _timestamp in downsample_video(path):
                frame_path = os.path.join(
                    frame_dir, f"video_frame_{uuid.uuid4().hex}.png"
                )
                frame.save(frame_path)
                media.append({"type": "image", "path": frame_path})
    return media


@spaces.GPU
def model_inference(input_dict, history, max_tokens):
    """Stream the model's answer to the current multimodal chat message.

    Args:
        input_dict: MultimodalTextbox value. ``"text"`` is the query and
            ``"files"`` is a list of uploaded file paths: images, or videos
            that are sampled into frames.
        history: Earlier turns in Gradio "messages" format. They are replayed
            before the current message (see _history_to_messages).
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        ``"..."`` as a placeholder, then the accumulated response text after
        each chunk produced by the streamer.

    Raises:
        gr.Error: If the query text is empty or whitespace only.
    """
    text = input_dict.get("text") or ""
    if not text.strip():
        # gr.Error only reaches the UI when it is raised.
        raise gr.Error("Please input a query and optionally image(s).")

    # Video frames only need to exist until the processor has loaded them in
    # apply_chat_template. The temporary directory deletes them afterwards
    # instead of leaving PNGs in the working directory on every request.
    with tempfile.TemporaryDirectory() as frame_dir:
        media = _media_from_files(input_dict.get("files") or [], frame_dir)
        resulting_messages = _history_to_messages(history or [])
        # The current turn is always sent, with or without history.
        resulting_messages.append(
            {"role": "user", "content": _build_user_content(text, media)}
        )

        print("resulting_messages", resulting_messages)
        inputs = processor.apply_chat_template(
            resulting_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
    inputs = inputs.to(model.device)

    # generate() blocks, so it runs in a worker thread while this generator
    # drains the streamer and pushes the partial text to the UI.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=int(max_tokens))
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()

# (query, file) pairs offered as clickable examples under the chat box.
# ChatInterface expects each example as a one-element list holding the
# MultimodalTextbox value: {"text": ..., "files": [...]}.
_EXAMPLE_QUERIES = [
    ("Where do the severe droughts happen according to this diagram?", "example_images/examples_weather_events.png"),
    ("What art era does this artpiece <image> belong to?", "example_images/rococo.jpg"),
    ("Describe this image.", "example_images/mosque.jpg"),
    ("When was this purchase made and how much did it cost?", "example_images/fiche.jpg"),
    ("What is the date in this document?", "example_images/document.jpg"),
    ("What is happening in the video?", "example_images/short.mp4"),
]
examples = [[{"text": query, "files": [path]}] for query, path in _EXAMPLE_QUERIES]

# Chat UI wiring. The multimodal textbox value becomes model_inference's
# input_dict, and the slider (an additional input) arrives as its third
# argument, max_tokens. Components are created in the same order as before:
# textbox first, then slider.
query_box = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image", "video"],
    file_count="multiple",
)
max_tokens_slider = gr.Slider(
    minimum=100,
    maximum=500,
    step=50,
    value=200,
    label="Max Tokens",
)
demo_description = (
    "Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. "
    "To get started, upload an image and text or try one of the examples. "
    "This demo doesn't use history for the chat, so every chat you start is a new conversation."
)

demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2: The Smollest Video Model Ever 📺",
    description=demo_description,
    examples=examples,
    textbox=query_box,
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[max_tokens_slider],
    type="messages",
)

demo.launch(debug=True)