import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers import Qwen2_5_VLForConditionalGeneration

# Helper Functions
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Returns an HTML snippet for a thin animated progress bar with a label.
    Colors can be customized; default colors are used for Qwen2VL/Aya‑Vision.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
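
# A hypothetical standalone use of the helper above (the label text is illustrative):
#
#     html = progress_bar_html("Processing with Qwen2VL OCR")
#     yield html  # from a streaming handler, so the bar shows in the chat window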

def downsample_video(video_path, num_frames=25):
    """
    Downsamples a video file by extracting `num_frames` evenly spaced frames
    (25 by default). Returns a list of (PIL.Image, timestamp-in-seconds) tuples.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))  # cast: OpenCV expects a plain Python number
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
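
# A minimal sketch of exercising downsample_video on its own; the path reuses
# the demo's example clip and the frame count is an illustrative choice:
#
#     frames = downsample_video("examples/videoplayback.mp4", num_frames=10)
#     for pil_image, timestamp in frames:
#         print(f"frame at {timestamp:.2f}s, size={pil_image.size}")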

# Model and Processor Setup
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

ROLMOCR_MODEL_ID = "reducto/RolmOCR"
rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    ROLMOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

# Main Inference Function
@spaces.GPU
def model_inference(input_dict, history, use_rolmocr=False):
    text = input_dict["text"].strip()
    files = input_dict.get("files", [])

    if not text and not files:
        yield "Error: Please input a text query or provide files (images or videos)."
        return

    # Process files: images and videos
    image_list = []
    for idx, file in enumerate(files):
        if file.lower().endswith((".mp4", ".avi", ".mov")):
            frames = downsample_video(file)
            if not frames:
                yield "Error: Could not extract frames from the video."
                return
            for frame, timestamp in frames:
                label = f"Video {idx+1} Frame {timestamp}:"
                image_list.append((label, frame))
        else:
            try:
                img = load_image(file)
                label = f"Image {idx+1}:"
                image_list.append((label, img))
            except Exception as e:
                yield f"Error loading image: {str(e)}"
                return

    # Build content list
    content = [{"type": "text", "text": text}]
    for label, img in image_list:
        content.append({"type": "text", "text": label})
        content.append({"type": "image", "image": img})

    messages = [{"role": "user", "content": content}]

    # Select processor and model
    processor = rolmocr_processor if use_rolmocr else qwen_processor
    model = rolmocr_model if use_rolmocr else qwen_model
    model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"

    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    all_images = [item["image"] for item in content if item["type"] == "image"]
    inputs = processor(
        text=[prompt_full],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html(f"Processing with {model_name}")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer

# Gradio Interface
examples = [
    [{"text": "OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
    [{"text": "Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
]

demo = gr.ChatInterface(
    fn=model_inference,
    description="# *Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`*",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Input your query and optionally upload image(s) or video(s). Select the model using the checkbox."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[gr.Checkbox(label="Use RolmOCR", value=False, info="Check to use RolmOCR, uncheck to use Qwen2VL OCR")],
)

demo.launch(debug=True, ssr_mode=False)