import os import random import uuid import json import time import asyncio from threading import Thread import gradio as gr import spaces import torch import numpy as np from PIL import Image, ImageDraw import cv2 import re from transformers import ( Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer, ) from transformers.image_utils import load_image # Constants for text generation MAX_MAX_NEW_TOKENS = 2048 DEFAULT_MAX_NEW_TOKENS = 1024 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Load Camel-Doc-OCR-062825 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825" processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True) model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() # Load Qwen2.5-VL-7B-Instruct MODEL_ID_X = "Qwen/Qwen2.5-VL-7B-Instruct" processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True) model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() # Load Qwen2.5-VL-3B-Instruct MODEL_ID_T = "Qwen/Qwen2.5-VL-3B-Instruct" processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True) model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() def downsample_video(video_path): """ Downsamples the video to evenly spaced frames. Each frame is returned as a PIL image along with its timestamp. """ vidcap = cv2.VideoCapture(video_path) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = vidcap.get(cv2.CAP_PROP_FPS) frames = [] frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int) for i in frame_indices: vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) vidcap.release() return frames def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2): draw = ImageDraw.Draw(image) for box in bounding_boxes: xmin, ymin, xmax, ymax = box draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width) return image def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000): x_scale = original_width / scaled_width y_scale = original_height / scaled_height rescaled_boxes = [] for box in bounding_boxes: xmin, ymin, xmax, ymax = box rescaled_box = [ xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale ] rescaled_boxes.append(rescaled_box) return rescaled_boxes @spaces.GPU def generate_image(model_name: str, text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2): """ Generates responses using the selected model for image input. """ if model_name == "Camel-Doc-OCR-062825": processor = processor_m model = model_m elif model_name == "Qwen2.5-VL-7B-Instruct": processor = processor_x model = model_x elif model_name == "Qwen2.5-VL-3B-Instruct": processor = processor_t model = model_t else: yield "Invalid model selected.", "Invalid model selected." return if image is None: yield "Please upload an image.", "Please upload an image." return messages = [{ "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": text}, ] }] prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=False, max_length=MAX_INPUT_TOKEN_LENGTH ).to(device) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens} thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() buffer = "" for new_text in streamer: buffer += new_text time.sleep(0.01) yield buffer, buffer @spaces.GPU def generate_video(model_name: str, text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2): """ Generates responses using the selected model for video input. """ if model_name == "Camel-Doc-OCR-062825": processor = processor_m model = model_m elif model_name == "Qwen2.5-VL-7B-Instruct": processor = processor_x model = model_x elif model_name == "Qwen2.5-VL-3B-Instruct": processor = processor_t model = model_t else: yield "Invalid model selected.", "Invalid model selected." return if video_path is None: yield "Please upload a video.", "Please upload a video." return frames = downsample_video(video_path) messages = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, {"role": "user", "content": [{"type": "text", "text": text}]} ] for frame in frames: image, timestamp = frame messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"}) messages[1]["content"].append({"type": "image", "image": image}) inputs = processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", truncation=False, max_length=MAX_INPUT_TOKEN_LENGTH ).to(device) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = { **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, } thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() buffer = "" for new_text in streamer: buffer += new_text buffer = buffer.replace("<|im_end|>", "") time.sleep(0.01) yield buffer, buffer @spaces.GPU def run_object_detection(model_name: str, image: Image.Image, text_input: str, system_prompt: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2): if model_name == "Camel-Doc-OCR-062825": processor = processor_m model = model_m elif model_name == "Qwen2.5-VL-7B-Instruct": processor = processor_x model = model_x elif model_name == "Qwen2.5-VL-3B-Instruct": processor = processor_t model = model_t else: return "Invalid model selected.", "", image if image is None: return "Please upload an image.", "", image messages = [ { "role": "user", "content": [ {"type": "text", "text": system_prompt}, {"type": "text", "text": text_input}, {"type": "image", "image": image}, ], } ] prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=False, max_length=MAX_INPUT_TOKEN_LENGTH ).to(device) generation_kwargs = { "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, } generated_ids = model.generate(**inputs, **generation_kwargs) generated_ids_trimmed = generated_ids[:, inputs["input_ids"].shape[1]:] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]' matches = re.findall(pattern, output_text) parsed_boxes = [[int(num) for num in match] for match in matches] original_width, original_height = image.size scaled_boxes = rescale_bounding_boxes(parsed_boxes, original_width, original_height) annotated_image = draw_bounding_boxes(image.copy(), scaled_boxes) return output_text, str(parsed_boxes), annotated_image # Define examples for image and video inference image_examples = [ ["Convert this page to doc [text] precisely.", "images/3.png"], ["Convert this page to doc [text] precisely.", "images/4.png"], ["Convert this page to doc [text] precisely.", "images/1.png"], ["Convert chart to OTSL.", "images/2.png"] ] video_examples = [ ["Explain the video in detail.", "videos/2.mp4"], ["Explain the ad in detail.", "videos/1.mp4"] ] # Define examples for object detection default_system_prompt = "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled to 1000 by 1000 pixels. When there are more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]." object_detection_examples = [ ["images/3.png", "Detect all text blocks", default_system_prompt], ["images/4.png", "Find all images", default_system_prompt], ["images/1.png", "Locate the headers", default_system_prompt], ["images/2.png", "Detect the chart", default_system_prompt], ] # Added CSS to style the output area as a "Canvas" css = """ .submit-btn { background-color: #2980b9 !important; color: white !important; } .submit-btn:hover { background-color: #3498db !important; } .canvas-output { border: 2px solid #4682B4; border-radius: 10px; padding: 20px; } """ # Create the Gradio Interface with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo: gr.Markdown("# **[Doc-VLMs-v2-Localization](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**") with gr.Row(): with gr.Column(): model_choice = gr.Radio( choices=["Camel-Doc-OCR-062825", "Qwen2.5-VL-7B-Instruct", "Qwen2.5-VL-3B-Instruct"], label="Select Model", value="Camel-Doc-OCR-062825" ) with gr.Tabs(): with gr.TabItem("Image Inference"): with gr.Row(): with gr.Column(): image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...") image_upload = gr.Image(type="pil", label="Image") image_submit = gr.Button("Submit", elem_classes="submit-btn") gr.Examples( examples=image_examples, inputs=[image_query, image_upload] ) with gr.Column(): output_image = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2) markdown_output_image = gr.Markdown(label="Formatted Result (Result.Md)") with gr.TabItem("Video Inference"): with gr.Row(): with gr.Column(): video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...") video_upload = gr.Video(label="Video") video_submit = gr.Button("Submit", elem_classes="submit-btn") gr.Examples( examples=video_examples, inputs=[video_query, video_upload] ) with gr.Column(): output_video = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2) markdown_output_video = gr.Markdown(label="Formatted Result (Result.Md)") with gr.TabItem("Object Detection"): with gr.Row(): with gr.Column(): input_img = gr.Image(label="Input Image", type="pil") system_prompt = gr.Textbox(label="System Prompt", value=default_system_prompt) text_input = gr.Textbox(label="User Prompt") object_detection_submit = gr.Button("Submit", elem_classes="submit-btn") gr.Examples( examples=object_detection_examples, inputs=[input_img, text_input, system_prompt] ) with gr.Column(): model_output_text = gr.Textbox(label="Model Output Text") parsed_boxes = gr.Textbox(label="Parsed Boxes") annotated_image = gr.Image(label="Annotated Image") with gr.Accordion("Advanced options", open=False): max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS) temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6) top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9) top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50) repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2) image_submit.click( fn=generate_image, inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output_image, markdown_output_image] ) video_submit.click( fn=generate_video, inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output_video, markdown_output_video] ) object_detection_submit.click( fn=run_object_detection, inputs=[model_choice, input_img, text_input, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[model_output_text, parsed_boxes, annotated_image] ) if __name__ == "__main__": demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)