import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2

from transformers import (
    Qwen2VLForConditionalGeneration,
    VisionEncoderDecoderModel,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

from docling_core.types.doc import DoclingDocument, DocTagsDocument

import re
import ast
import html

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load olmOCR-7B-0225-preview
MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load ByteDance's Dolphin
MODEL_ID_K = "ByteDance/Dolphin"
processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
model_k = VisionEncoderDecoderModel.from_pretrained(
    MODEL_ID_K,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Names shown in the model selector; they double as dispatch keys below.
_DOLPHIN_NAME = "ByteDance-s-Dolphin"
_OLMOCR_NAME = "olmOCR-7B-0225-preview"
_SMOLDOCLING_NAME = "SmolDocling-256M-preview"

# Chat-template models that are driven through `generate` with a streamer.
_STREAMING_MODELS = {
    _OLMOCR_NAME: (processor_m, model_m),
    _SMOLDOCLING_NAME: (processor_x, model_x),
}

# Bracketed numeric lists such as "[12, 34.5, 56]" inside a user prompt.
_BRACKETED_NUMBERS_RE = re.compile(r"\[([\d\.\s,]+)\]")

# DocTags markers whose presence means the output should be rendered through
# docling instead of being shown verbatim.
_DOCTAGS_MARKERS = ("<doctag>", "<otsl>", "<code>", "<chart>", "<formula>")

# Drops the single tag that directly follows the last <loc_500> token
# (applied to chart output after it has been rewritten as an OTSL table).
_TAG_AFTER_LAST_LOC_RE = re.compile(r"(<loc_500>)(?!.*<loc_500>)<[^>]+>")


# Preprocessing functions for SmolDocling-256M
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Return an RGB copy of *image* surrounded by a randomly sized border.

    The horizontal and vertical border widths are drawn independently from
    ``[min_percent, max_percent]`` of the image width / height. The border
    is filled with the colour of the top-left pixel.
    """
    image = image.convert("RGB")
    width, height = image.size
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    return ImageOps.expand(
        image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel
    )


def normalize_values(text, target_max=500):
    """Rewrite bracketed number lists in *text* as ``<loc_N>`` tokens.

    Each list such as ``[10, 20, 40]`` is scaled so that its largest value
    becomes *target_max*, rounded, and emitted as a run of ``<loc_N>``
    tokens. Bracketed text that is not a valid Python list literal is left
    unchanged.
    """

    def normalize_list(values):
        # `or 1` keeps an all-zero list from dividing by zero.
        max_value = (max(values) if values else 1) or 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        try:
            num_list = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError):
            # e.g. "[1,,2]" or "[1.2.3]" satisfy the regex but not the parser.
            return match.group(0)
        return "".join(f"<loc_{num}>" for num in normalize_list(num_list))

    return _BRACKETED_NUMBERS_RE.sub(process_match, text)


def downsample_video(video_path, num_frames=10):
    """Sample evenly spaced frames from a video.

    Returns a list of ``(PIL.Image, timestamp_seconds)`` tuples with at most
    *num_frames* entries. Frames that fail to decode are skipped. An empty
    list is returned when the container reports no frames.
    """
    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0:
            return frames
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for index in frame_indices:
            index = int(index)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, bgr = vidcap.read()
            if not success:
                continue
            pil_image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            # Some containers report an FPS of 0; avoid dividing by it.
            timestamp = round(index / fps, 2) if fps and fps > 0 else 0.0
            frames.append((pil_image, timestamp))
    finally:
        # Release the capture even if decoding raises.
        vidcap.release()
    return frames


# Dolphin-specific functions
def model_chat(prompt, image, is_batch=False):
    """Run the Dolphin model on one image or a batch of images.

    With ``is_batch=False``, *prompt* and *image* are single items and a
    single string is returned. With ``is_batch=True``, *image* is a list and
    *prompt* is either a list of the same length or one prompt applied to
    every image; a list of strings is returned.
    """
    processor = processor_k
    model = model_k
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if not is_batch:
        images = [image]
        prompts = [prompt]
    else:
        images = image
        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

    inputs = processor(images, return_tensors="pt", padding=True).to(device)
    pixel_values = inputs.pixel_values.half()

    # Wrap each instruction in the decoder prompt format; the same string is
    # removed from the decoded sequence afterwards.
    prompts = [f"<s>{p} <Answer/>" for p in prompts]
    prompt_inputs = processor.tokenizer(
        prompts,
        add_special_tokens=False,  # Explicitly set to False
        return_tensors="pt",
        padding=True
    ).to(device)

    outputs = model.generate(
        pixel_values=pixel_values,
        decoder_input_ids=prompt_inputs.input_ids,
        decoder_attention_mask=prompt_inputs.attention_mask,
        min_length=1,
        max_length=4096,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        do_sample=False,
        num_beams=1,
        repetition_penalty=1.1
    )

    sequences = processor.tokenizer.batch_decode(
        outputs.sequences, skip_special_tokens=False
    )
    # Special tokens are kept while decoding so the prompt can be matched
    # exactly; padding and end-of-sequence markers are stripped by hand.
    results = [
        sequence.replace(sent_prompt, "")
        .replace("<pad>", "")
        .replace("</s>", "")
        .strip()
        for sent_prompt, sequence in zip(prompts, sequences)
    ]
    return results[0] if not is_batch else results


def process_element_batch(elements, prompt, max_batch_size=16):
    """Recognise a list of cropped elements that share one prompt.

    *elements* are dicts with ``crop``, ``label``, ``bbox`` and
    ``reading_order`` keys. Returns one result dict per element with the
    crop replaced by the recognised ``text``. An empty input yields ``[]``.
    """
    if not elements:
        return []

    results = []
    # max(1, ...) keeps the range step valid even for a non-positive limit.
    batch_size = max(1, min(len(elements), max_batch_size))
    for start in range(0, len(elements), batch_size):
        batch_elements = elements[start:start + batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]
        prompts_list = [prompt] * len(crops_list)
        batch_results = model_chat(prompts_list, crops_list, is_batch=True)
        for elem, result in zip(batch_elements, batch_results):
            results.append({
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": result.strip(),
                "reading_order": elem["reading_order"],
            })
    return results


def process_elements(layout_results, image):
    """Parse a layout string and recognise each element it describes.

    *layout_results* is expected to be the literal representation of a
    sequence of ``(bbox, label)`` pairs. Text and table crops are sent to
    the model in batches; figures get a ``"[Figure]"`` placeholder. Results
    are returned sorted by reading order. Unparsable layouts yield ``[]``
    and malformed individual elements are logged and skipped.
    """
    try:
        elements = ast.literal_eval(layout_results)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        elements = []
    if not isinstance(elements, (list, tuple)):
        elements = []

    text_elements = []
    table_elements = []
    figure_results = []
    reading_order = 0

    for element in elements:
        try:
            bbox, label = element
            x1, y1, x2, y2 = map(int, bbox)
            cropped = image.crop((x1, y1, x2, y2))
            if cropped.size[0] > 0 and cropped.size[1] > 0:
                element_info = {
                    "crop": cropped,
                    "label": label,
                    "bbox": [x1, y1, x2, y2],
                    "reading_order": reading_order,
                }
                if label == "text":
                    text_elements.append(element_info)
                elif label == "table":
                    table_elements.append(element_info)
                elif label == "figure":
                    figure_results.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "text": "[Figure]",
                        "reading_order": reading_order
                    })
                reading_order += 1
        except Exception as e:
            # Best effort: one bad element must not abort the whole page.
            print(f"Error processing element: {e}")
            continue

    recognition_results = figure_results.copy()
    if text_elements:
        recognition_results.extend(
            process_element_batch(text_elements, "Read text in the image.")
        )
    if table_elements:
        recognition_results.extend(
            process_element_batch(table_elements, "Parse the table in the image.")
        )

    recognition_results.sort(key=lambda x: x["reading_order"])
    return recognition_results


def generate_markdown(recognition_results):
    """Render recognised elements as markdown, one paragraph per element.

    Tables are prefixed with a bold ``Table:`` heading; elements with any
    label other than ``text``, ``table`` or ``figure`` are omitted.
    """
    parts = []
    for element in recognition_results:
        label = element["label"]
        if label in ("text", "figure"):
            parts.append(f"{element['text']}\n\n")
        elif label == "table":
            parts.append(f"**Table:**\n{element['text']}\n\n")
    return "".join(parts).strip()


def process_image_with_dolphin(image):
    """Run Dolphin's two-stage pipeline (layout, then recognition) on *image*.

    Returns the page content as a markdown string.
    """
    layout_output = model_chat("Parse the reading order of this document.", image)
    elements = process_elements(layout_output, image)
    return generate_markdown(elements)


def _doctags_to_markdown(full_output, images):
    """Post-process raw SmolDocling output.

    Returns docling-rendered markdown when the output contains DocTags
    markers, otherwise the cleaned raw text.
    """
    cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
    if not any(tag in cleaned_output for tag in _DOCTAGS_MARKERS):
        return cleaned_output

    if "<chart>" in cleaned_output:
        # Charts are emitted with table content; relabel them as OTSL tables.
        cleaned_output = (
            cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
        )
        cleaned_output = _TAG_AFTER_LAST_LOC_RE.sub(r"\1", cleaned_output)

    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    markdown_output = doc.export_to_markdown()
    return f"**MD Output:**\n\n{markdown_output}"


def _stream_vlm_response(model_name, text, images, max_new_tokens, temperature,
                         top_p, top_k, repetition_penalty):
    """Stream a chat-template model's answer for *text* over *images*.

    Yields the accumulated text after every streamed chunk. For SmolDocling
    a final, post-processed value is yielded once generation has finished.
    *model_name* must be a key of ``_STREAMING_MODELS``.
    """
    processor, model = _STREAMING_MODELS[model_name]
    is_smoldocling = model_name == _SMOLDOCLING_NAME

    if is_smoldocling:
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
            text = normalize_values(text, target_max=500)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [
                {"type": "text", "text": text}
            ]
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs on a worker thread so tokens can be yielded as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer
    thread.join()

    if is_smoldocling:
        yield _doctags_to_markdown(full_output, images)


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for image input using the selected model.

    This is a generator: it yields progressively longer output strings for
    the streaming models, a single markdown string for Dolphin, or a single
    user-facing message when the model name or the image is missing.
    """
    if model_name == _DOLPHIN_NAME:
        if image is None:
            yield "Please upload an image."
            return
        yield process_image_with_dolphin(image)
        return

    if model_name not in _STREAMING_MODELS:
        yield "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image."
        return

    yield from _stream_vlm_response(
        model_name, text, [image],
        max_new_tokens, temperature, top_p, top_k, repetition_penalty,
    )


@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for video input using the selected model.

    The video is reduced to evenly spaced frames. Dolphin processes every
    frame separately and the per-frame markdown is joined with blank lines;
    the streaming models receive all frames in one prompt. This is a
    generator with the same yield contract as ``generate_image``.
    """
    if model_name == _DOLPHIN_NAME:
        if video_path is None:
            yield "Please upload a video."
            return
        frames = downsample_video(video_path)
        markdown_contents = [
            process_image_with_dolphin(frame) for frame, _ in frames
        ]
        yield "\n\n".join(markdown_contents)
        return

    if model_name not in _STREAMING_MODELS:
        yield "Invalid model selected."
        return
    if video_path is None:
        yield "Please upload a video."
        return

    images = [frame for frame, _ in downsample_video(video_path)]
    if not images:
        # Nothing decodable: report it instead of calling the processor
        # with an empty image list.
        yield "Could not read any frames from the video."
        return

    yield from _stream_vlm_response(
        model_name, text, images,
        max_new_tokens, temperature, top_p, top_k, repetition_penalty,
    )


# Define examples for image and video inference
image_examples = [
    ["Convert this page to docling", "images/1.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/3.png"],
]

video_examples = [
    ["Explain the ad in detail", "example/1.mp4"],
    ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
]

css = """
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# Create the Gradio Interface
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    image_upload = gr.Image(type="pil", label="Image")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload]
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload]
                    )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column():
            output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
            model_choice = gr.Radio(
                choices=["olmOCR-7B-0225-preview", "SmolDocling-256M-preview", "ByteDance-s-Dolphin"],
                label="Select Model",
                value="olmOCR-7B-0225-preview"
            )

    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)