# NOTE: import order is kept exactly as in the original file; in particular
# the `spaces` package is imported before `torch`.
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
import pymupdf
import io

from transformers import (
    Qwen2VLForConditionalGeneration,
    VisionEncoderDecoderModel,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

from docling_core.types.doc import DoclingDocument, DocTagsDocument

import re
import ast
import html

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Global variables for Dolphin model (populated by initialize_models()).
model_k = None
processor_k = None
tokenizer_k = None


# Load models
def initialize_models():
    """Load all three document VLMs onto ``device`` in float16 eval mode.

    Loads, in order: olmOCR-7B-0225-preview, ByteDance Dolphin and
    SmolDocling-256M-preview. Dolphin's processor, model and tokenizer are
    stored in the module globals ``processor_k``, ``model_k`` and
    ``tokenizer_k``. The Dolphin model is only loaded when ``model_k`` is
    still ``None``, but its processor and tokenizer are reloaded on every
    call.

    Returns:
        ``(processor_m, model_m, processor_x, model_x)`` — the olmOCR
        processor/model followed by the SmolDocling processor/model.
    """
    global model_k, processor_k, tokenizer_k

    def load_eval_model(model_cls, repo_id):
        # Every model uses identical loading options.
        loaded = model_cls.from_pretrained(
            repo_id, trust_remote_code=True, torch_dtype=torch.float16
        )
        return loaded.to(device).eval()

    # olmOCR-7B-0225-preview
    olmocr_repo = "allenai/olmOCR-7B-0225-preview"
    olmocr_processor = AutoProcessor.from_pretrained(olmocr_repo, trust_remote_code=True)
    olmocr_model = load_eval_model(Qwen2VLForConditionalGeneration, olmocr_repo)

    # ByteDance's Dolphin (kept in module globals)
    dolphin_repo = "ByteDance/Dolphin"
    processor_k = AutoProcessor.from_pretrained(dolphin_repo, trust_remote_code=True)
    if model_k is None:
        model_k = load_eval_model(VisionEncoderDecoderModel, dolphin_repo)
    tokenizer_k = processor_k.tokenizer

    # SmolDocling-256M-preview
    smoldocling_repo = "ds4sd/SmolDocling-256M-preview"
    smoldocling_processor = AutoProcessor.from_pretrained(smoldocling_repo, trust_remote_code=True)
    smoldocling_model = load_eval_model(AutoModelForVision2Seq, smoldocling_repo)

    return olmocr_processor, olmocr_model, smoldocling_processor, smoldocling_model


processor_m, model_m, processor_x, model_x = initialize_models()
# Preprocessing functions for SmolDocling-256M

# DocTags markers whose presence means SmolDocling produced structured output
# that docling-core can convert to Markdown.
_SMOLDOCLING_DOCTAG_MARKERS = ("<doctag>", "<otsl>", "<code>", "<chart>", "<formula>")


def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image based on its size.

    The horizontal and vertical padding are each a fraction of the image
    width/height drawn uniformly from ``[min_percent, max_percent]``. The
    border is filled with the colour of the top-left pixel. With the default
    bounds (both 0.1) the padding is effectively a fixed 10%.

    Args:
        image: PIL image; converted to RGB first.
        min_percent: Lower bound of the padding fraction.
        max_percent: Upper bound of the padding fraction.

    Returns:
        A new, padded RGB PIL image.
    """
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    corner_pixel = image.getpixel((0, 0))
    return ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)


def normalize_values(text, target_max=500):
    """Rewrite bracketed number lists in *text* as SmolDocling ``<loc_N>`` tokens.

    Each ``[a, b, ...]`` list is scaled so its largest value maps to
    *target_max*. The scaled values are rounded and emitted as consecutive
    ``<loc_N>`` tokens, e.g. ``"[10, 20]"`` -> ``"<loc_250><loc_500>"``.
    Bracketed text that does not parse as a number list is left unchanged.

    Args:
        text: Prompt text that may contain coordinate lists.
        target_max: Value the largest number of each list is scaled to.

    Returns:
        *text* with every parsable list replaced by ``<loc_N>`` tokens.
    """
    def normalize_list(values):
        # ``or 1`` keeps an all-zero (or empty) list from dividing by zero.
        max_value = max(values, default=0) or 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        try:
            num_list = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError):
            # e.g. "[1.2.3]" or "[,]" — not a valid number list.
            return match.group(0)
        return "".join(f"<loc_{num}>" for num in normalize_list(num_list))

    pattern = r"\[([\d\.\s,]+)\]"
    return re.sub(pattern, process_match, text)


def downsample_video(video_path, num_frames=10):
    """Sample evenly spaced frames from a video.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frame positions to sample (default 10).

    Returns:
        A list of ``(PIL.Image, timestamp_seconds)`` tuples. Frames that fail
        to decode are skipped. The list is empty when the frame count is
        unavailable. The timestamp is ``0.0`` when the FPS is unavailable.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0:
            # linspace(0, -1, ...) would otherwise yield a negative index.
            return frames
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            timestamp = round(i / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
    finally:
        # Release the capture even if decoding raises.
        vidcap.release()
    return frames


# Dolphin-specific functions
@spaces.GPU
def model_chat(prompt, image, is_batch=False):
    """Run the Dolphin model on one image or a batch of images.

    Args:
        prompt: Instruction string. In batch mode this may also be a list of
            prompts aligned with *image*; a single string is reused for
            every image.
        image: A PIL image, or a list of PIL images in batch mode.
        is_batch: Whether *prompt*/*image* describe a batch.

    Returns:
        The decoded answer string, or a list of answers in batch mode.
    """
    global model_k, processor_k, tokenizer_k
    if model_k is None:
        initialize_models()

    if is_batch:
        images = image
        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
    else:
        images = [image]
        prompts = [prompt]

    inputs = processor_k(images, return_tensors="pt", padding=True).to(device)
    pixel_values = inputs.pixel_values.half()

    # Dolphin's decoder prefix is "<s>{instruction} <Answer/>"; the answer is
    # generated after the <Answer/> marker.
    prompts = [f"<s>{p} <Answer/>" for p in prompts]
    prompt_inputs = tokenizer_k(
        prompts, add_special_tokens=False, return_tensors="pt", padding=True
    ).to(device)

    outputs = model_k.generate(
        pixel_values=pixel_values,
        decoder_input_ids=prompt_inputs.input_ids,
        decoder_attention_mask=prompt_inputs.attention_mask,
        min_length=1,
        max_length=4096,
        pad_token_id=tokenizer_k.pad_token_id,
        eos_token_id=tokenizer_k.eos_token_id,
        use_cache=True,
        bad_words_ids=[[tokenizer_k.unk_token_id]],
        return_dict_in_generate=True,
        do_sample=False,
        num_beams=1,
        repetition_penalty=1.1,
    )

    # Decoding keeps special tokens, so strip the echoed prompt plus the
    # padding / end-of-sequence markers by hand.
    sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
    results = [
        sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
        for i, sequence in enumerate(sequences)
    ]
    return results if is_batch else results[0]


@spaces.GPU
def process_element_batch(elements, prompt, max_batch_size=16):
    """Recognise a list of cropped elements with Dolphin using one shared prompt.

    Args:
        elements: Dicts with ``crop``, ``label``, ``bbox`` and ``reading_order``.
        prompt: Instruction sent with every crop.
        max_batch_size: Maximum number of crops per ``model_chat`` call.

    Returns:
        One dict per element with ``label``, ``bbox``, ``text`` and
        ``reading_order``, in input order.
    """
    results = []
    if not elements:
        # Avoid range(..., step=0) below.
        return results
    batch_size = max(1, min(len(elements), max_batch_size))
    for start in range(0, len(elements), batch_size):
        batch_elements = elements[start:start + batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]
        prompts_list = [prompt] * len(crops_list)
        batch_results = model_chat(prompts_list, crops_list, is_batch=True)
        for elem, result in zip(batch_elements, batch_results):
            results.append({
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": result.strip(),
                "reading_order": elem["reading_order"],
            })
    return results


def process_elements(layout_results, image):
    """Parse Dolphin's layout output, crop each element and recognise it.

    Args:
        layout_results: Layout string from Dolphin, parsed with
            ``ast.literal_eval`` as a list of ``(bbox, label)`` pairs.
            NOTE(review): assumes bbox values are pixel coordinates of
            *image* — confirm. If Dolphin emits some other format, parsing
            fails here and no elements are produced.
        image: The PIL page image the bboxes refer to.

    Returns:
        Recognition dicts (``label``, ``bbox``, ``text``, ``reading_order``)
        sorted by reading order. Figures get the placeholder text
        ``"[Figure]"``; labels other than text/table/figure are dropped.
    """
    try:
        elements = ast.literal_eval(layout_results)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        elements = []
    if not isinstance(elements, (list, tuple)):
        elements = []

    text_elements = []
    table_elements = []
    figure_results = []
    reading_order = 0

    for element in elements:
        # Best-effort: a malformed element is logged and skipped.
        try:
            bbox, label = element
            x1, y1, x2, y2 = map(int, bbox)
            cropped = image.crop((x1, y1, x2, y2))
        except Exception as e:
            print(f"Error processing element: {e}")
            continue
        if cropped.size[0] <= 0 or cropped.size[1] <= 0:
            continue

        element_info = {
            "crop": cropped,
            "label": label,
            "bbox": [x1, y1, x2, y2],
            "reading_order": reading_order,
        }
        if label == "text":
            text_elements.append(element_info)
        elif label == "table":
            table_elements.append(element_info)
        elif label == "figure":
            figure_results.append({
                "label": label,
                "bbox": [x1, y1, x2, y2],
                "text": "[Figure]",
                "reading_order": reading_order,
            })
        reading_order += 1

    recognition_results = figure_results.copy()
    if text_elements:
        recognition_results.extend(process_element_batch(text_elements, "Read text in the image."))
    if table_elements:
        recognition_results.extend(process_element_batch(table_elements, "Parse the table in the image."))
    recognition_results.sort(key=lambda x: x["reading_order"])
    return recognition_results


def generate_markdown(recognition_results):
    """Join recognised elements into Markdown, one paragraph per element.

    Tables get a ``**Table:**`` heading; text and figure entries are emitted
    as-is; other labels are skipped.
    """
    parts = []
    for element in recognition_results:
        label = element["label"]
        if label == "table":
            parts.append(f"**Table:**\n{element['text']}")
        elif label in ("text", "figure"):
            parts.append(element["text"])
    return "\n\n".join(parts).strip()


def convert_to_image(image):
    """Convert an upload to an RGB PIL image.

    Args:
        image: A PIL image, or a file path. A path ending in ``.pdf`` is
            rendered from its first page.

    Returns:
        An RGB PIL image, or ``None`` for unsupported input or an empty PDF.
    """
    if isinstance(image, str):  # File path from Gradio
        if image.lower().endswith('.pdf'):
            with pymupdf.open(image) as doc:
                if len(doc) == 0:
                    return None
                img_data = doc[0].get_pixmap().tobytes("png")
            return Image.open(io.BytesIO(img_data)).convert("RGB")
        # Context manager closes the underlying file handle.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if isinstance(image, Image.Image):  # Already a PIL Image
        return image.convert("RGB")
    return None


def process_image_with_dolphin(image):
    """Parse one page with Dolphin.

    Runs a layout pass, then per-element recognition, and returns the result
    as Markdown (or an error string).
    """
    pil_image = convert_to_image(image)
    if pil_image is None:
        return "Error: Unable to process the uploaded file."
    layout_output = model_chat("Parse the reading order of this document.", pil_image)
    elements = process_elements(layout_output, pil_image)
    return generate_markdown(elements)


def _resolve_vlm(model_name):
    """Return ``(processor, model)`` for a streaming VLM name, or ``None`` if unknown."""
    if model_name == "olmOCR-7B-0225-preview":
        return processor_m, model_m
    if model_name == "SmolDocling-256M-preview":
        return processor_x, model_x
    return None


def _postprocess_smoldocling(full_output, images):
    """Turn SmolDocling's raw output into Markdown when it contains DocTags.

    Returns plain cleaned text when no DocTags marker is present.
    """
    cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
    if not any(tag in cleaned_output for tag in _SMOLDOCLING_DOCTAG_MARKERS):
        return cleaned_output
    if "<chart>" in cleaned_output:
        # Charts are converted as OTSL tables. The regex drops the tag that
        # immediately follows the last <loc_500> on a line (logic carried
        # over from the SmolDocling demo — TODO confirm intent).
        cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
        cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    markdown_output = doc.export_to_markdown()
    return f"**MD Output:**\n\n{markdown_output}"


def _stream_vlm(model_name, processor, model, text, images, max_new_tokens,
                temperature, top_p, top_k, repetition_penalty):
    """Stream a chat-template generation for *images* + *text*.

    Yields the growing decoded text. For SmolDocling it applies the
    prompt-dependent padding/normalisation first, then yields one final
    post-processed result.
    """
    text = text or ""
    if model_name == "SmolDocling-256M-preview":
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
            text = normalize_values(text, target_max=500)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [
                {"type": "text", "text": text}
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # generate() runs in a worker thread; the streamer hands text back here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer
    thread.join()

    if model_name == "SmolDocling-256M-preview":
        yield _postprocess_smoldocling(full_output, images)


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
    """Generate responses for image input using the selected model.

    Dolphin yields a single Markdown document; olmOCR and SmolDocling stream
    partial text, and SmolDocling finishes with a converted result.
    """
    if model_name == "ByteDance-s-Dolphin":
        if image is None:
            yield "Please upload an image or PDF (first page will be processed)."
            return
        yield process_image_with_dolphin(image)
        return

    selected = _resolve_vlm(model_name)
    if selected is None:
        yield "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image."
        return
    pil_image = convert_to_image(image)
    if pil_image is None:
        yield "Error: Unable to process the uploaded file."
        return

    processor, model = selected
    yield from _stream_vlm(model_name, processor, model, text, [pil_image], max_new_tokens,
                           temperature, top_p, top_k, repetition_penalty)


@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
    """Generate responses for video input using the selected model.

    The video is downsampled to evenly spaced frames. Dolphin parses each
    frame separately. The other models receive all frames in one prompt.
    """
    if model_name == "ByteDance-s-Dolphin":
        if video_path is None:
            yield "Please upload a video."
            return
        frames = downsample_video(video_path)
        if not frames:
            yield "Error: Unable to read frames from the video."
            return
        markdown_contents = [
            f"**Frame {idx}:**\n{process_image_with_dolphin(frame)}"
            for idx, (frame, _) in enumerate(frames, start=1)
        ]
        yield "\n\n---\n\n".join(markdown_contents)
        return

    selected = _resolve_vlm(model_name)
    if selected is None:
        yield "Invalid model selected."
        return
    if video_path is None:
        yield "Please upload a video."
        return
    frames = downsample_video(video_path)
    if not frames:
        yield "Error: Unable to read frames from the video."
        return

    processor, model = selected
    images = [frame for frame, _ in frames]
    yield from _stream_vlm(model_name, processor, model, text, images, max_new_tokens,
                           temperature, top_p, top_k, repetition_penalty)


# Define examples
image_examples = [
    ["Convert this page to docling", "images/1.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/3.png"],
]

video_examples = [
    ["Explain the ad in detail", "example/1.mp4"],
    ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
]

css = """
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# Create Gradio Interface
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    gr.Markdown("**Note:** For Dolphin model, the text query is ignored, and PDFs are processed by parsing the first page.")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    # NOTE(review): with type="pil" this component hands the
                    # callback a PIL image, so convert_to_image's PDF branch
                    # looks unreachable from this tab despite the label —
                    # confirm whether a gr.File input was intended.
                    image_upload = gr.Image(type="pil", label="Image or PDF")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column():
            output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
            model_choice = gr.Radio(
                choices=["olmOCR-7B-0225-preview", "SmolDocling-256M-preview", "ByteDance-s-Dolphin"],
                label="Select Model",
                value="olmOCR-7B-0225-preview"
            )

    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)