import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
from PIL import Image
import uuid
import io
import os
# Fine-tuned for OCR tasks from Qwen's Qwen2-VL-2B-Instruct base model
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()
# Supported media extensions
image_extensions = Image.registered_extensions()
video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "gif", "webm", "m4v", "3gp")
def identify_and_save_blob(blob_path):
    """Identifies whether the blob is an image or a video and saves it accordingly."""
    try:
        with open(blob_path, 'rb') as file:
            blob_content = file.read()

        # Try to identify the blob as an image
        try:
            Image.open(io.BytesIO(blob_content)).verify()  # Raises if not a valid image
            extension = ".png"  # Default extension for saving images
            media_type = "image"
        except (IOError, SyntaxError):
            # Not a valid image, so assume it is a video
            extension = ".mp4"  # Default extension for saving videos
            media_type = "video"

        # Save the blob under a unique filename
        filename = f"temp_{uuid.uuid4()}_media{extension}"
        with open(filename, "wb") as f:
            f.write(blob_content)

        return filename, media_type
    except FileNotFoundError:
        raise ValueError(f"The file {blob_path} was not found.")
    except Exception as e:
        raise ValueError(f"An error occurred while processing the file: {e}")
def process_vision_info(messages):
    """Collects vision inputs (images and videos) from chat messages."""
    image_inputs = []
    video_inputs = []
    for message in messages:
        for content in message["content"]:
            if content["type"] == "image":
                image_inputs.append(load_image(content["image"]))
            elif content["type"] == "video":
                video_inputs.append(content["video"])
    return image_inputs, video_inputs
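# transformers' load_image accepts a local path, URL, or PIL image and returns a
# PIL.Image; video entries are passed through as file paths for the processor.

# On Hugging Face Spaces, @spaces.GPU requests a ZeroGPU device for the duration
# of the decorated call; outside Spaces the decorator is effectively a no-op.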
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    # Classify each uploaded file as an image or a video
    media_paths = []
    media_types = []
    for file in files:
        if file.lower().endswith(tuple(image_extensions.keys())):
            media_type = "image"
        elif file.lower().endswith(tuple(f".{ext}" for ext in video_extensions)):
            media_type = "video"
        else:
            # Extensionless blob: sniff the content to decide
            try:
                file, media_type = identify_and_save_blob(file)
            except Exception as e:
                raise gr.Error(f"Unsupported media type: {e}")
        media_paths.append(file)
        media_types.append(media_type)
    # Validate input: a text query is required in both cases
    if text == "" and not media_paths:
        raise gr.Error("Please input a query and optionally image(s) or video(s).")
    if text == "" and media_paths:
        raise gr.Error("Please input a text query along with the image(s) or video(s).")
    # Prepare the chat messages for the model
    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": media_type, media_type: media_path}
                  for media_path, media_type in zip(media_paths, media_types)],
                {"type": "text", "text": text},
            ],
        }
    ]
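    # For a single image upload, messages resembles (paths are illustrative):
    # [{"role": "user", "content": [{"type": "image", "image": "photo.png"},
    #                               {"type": "text", "text": "Describe the photo"}]}]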
    # Apply the chat template to build the generation prompt
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Collect vision inputs (images and videos)
    image_inputs, video_inputs = process_vision_info(messages)

    # Pass None instead of empty lists so the processor skips absent modalities
    if not video_inputs:
        video_inputs = None

    inputs = processor(
        text=[prompt],
        images=image_inputs if image_inputs else None,
        videos=video_inputs,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    # Set up a streamer for real-time output
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Run generation in a background thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream the output incrementally to the UI
    buffer = ""
    yield "Thinking..."
    for new_text in streamer:
        buffer += new_text
        # Strip any leftover end-of-turn tokens such as <|im_end|>
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
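# Note: TextIteratorStreamer decodes with tokenizer.decode(); passing the
# processor appears to work because Qwen2-VL's processor forwards decode() and
# batch_decode() to its underlying tokenizer.
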
# Example inputs
examples = [
    [{"text": "Describe the video.", "files": ["examples/demo.mp4"]}],
    [{"text": "Extract JSON from the image", "files": ["example_images/document.jpg"]}],
    [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
    [{"text": "Describe the photo", "files": ["examples/3.png"]}],
    [{"text": "Extract the table as JSON", "files": ["examples/4.jpg"]}],
    [{"text": "Summarize the full image in detail", "files": ["examples/2.jpg"]}],
    [{"text": "Describe this image.", "files": ["example_images/campeones.jpg"]}],
    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]
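# Note: the example media files above are assumed to exist in the Space
# repository; selecting an example with a missing file fails at runtime.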
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **Multimodal OCR**",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True, share=True)