import gradio as gr from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer from transformers.image_utils import load_image from threading import Thread import re import time import torch import spaces import ast import html import random import cv2 import numpy as np import uuid from PIL import Image, ImageOps from docling_core.types.doc import DoclingDocument from docling_core.types.doc.document import DocTagsDocument # --------------------------- # Helper Functions # --------------------------- def progress_bar_html(label: str) -> str: return f'''
{label}
''' def downsample_video(video_path, num_frames=10): """Downsamples a video to a fixed number of evenly spaced frames.""" vidcap = cv2.VideoCapture(video_path) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = vidcap.get(cv2.CAP_PROP_FPS) frames = [] if total_frames <= 0 or fps <= 0: vidcap.release() return frames # Get indices for num_frames evenly spaced frames. frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) for i in frame_indices: vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: # Convert from BGR (OpenCV) to RGB (PIL) and then to PIL Image. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) vidcap.release() return frames def add_random_padding(image, min_percent=0.1, max_percent=0.10): image = image.convert("RGB") width, height = image.size pad_w_percent = random.uniform(min_percent, max_percent) pad_h_percent = random.uniform(min_percent, max_percent) pad_w = int(width * pad_w_percent) pad_h = int(height * pad_h_percent) corner_pixel = image.getpixel((0, 0)) # Top-left corner for padding color padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel) return padded_image def normalize_values(text, target_max=500): def normalize_list(values): max_value = max(values) if values else 1 return [round((v / max_value) * target_max) for v in values] def process_match(match): num_list = ast.literal_eval(match.group(0)) normalized = normalize_list(num_list) return "".join([f"" for num in normalized]) pattern = r"\[([\d\.\s,]+)\]" normalized_text = re.sub(pattern, process_match, text) return normalized_text # --------------------------- # Model & Processor Setup # --------------------------- processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview") model = AutoModelForVision2Seq.from_pretrained( "ds4sd/SmolDocling-256M-preview", torch_dtype=torch.bfloat16, ).to("cuda") # --------------------------- # Main Inference Function # --------------------------- @spaces.GPU def model_inference(input_dict, history): text = input_dict["text"] files = input_dict.get("files", []) # If there are files, check if any is a video video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm") if files and any(str(f).lower().endswith(video_extensions) for f in files): # -------- Video Inference Branch -------- video_file = files[0] # Assume first file is a video frames = downsample_video(video_file) if not frames: yield "Could not process video file." return images = [frame[0] for frame in frames] timestamps = [frame[1] for frame in frames] # Append frame timestamps to the query text. text_with_timestamps = text + " " + " ".join([f"Frame at {ts} seconds." for ts in timestamps]) resulting_messages = [{ "role": "user", "content": [{"type": "image"} for _ in range(len(images))] + [{"type": "text", "text": text_with_timestamps}] }] prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True) inputs = processor(text=prompt, images=[images], return_tensors="pt").to("cuda") yield progress_bar_html("Processing video with SmolDocling") streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False) generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192) thread = Thread(target=model.generate, kwargs=generation_args) thread.start() buffer = "" full_output = "" for new_text in streamer: full_output += new_text buffer += html.escape(new_text) yield buffer cleaned_output = full_output.replace("", "").strip() if cleaned_output: doctag_output = cleaned_output yield cleaned_output if any(tag in doctag_output for tag in ["", "", "", "", ""]): doc = DoclingDocument(name="Document") if "" in doctag_output: doctag_output = doctag_output.replace("", "").replace("", "") doctag_output = re.sub(r'()(?!.*)<[^>]+>', r'\1', doctag_output) doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctag_output], images) doc.load_from_doctags(doctags_doc) yield f"**MD Output:**\n\n{doc.export_to_markdown()}" return elif files: # -------- Image Inference Branch -------- if len(files) > 1: if "OTSL" in text or "code" in text: images = [add_random_padding(load_image(image)) for image in files] else: images = [load_image(image) for image in files] elif len(files) == 1: if "OTSL" in text or "code" in text: images = [add_random_padding(load_image(files[0]))] else: images = [load_image(files[0])] resulting_messages = [{ "role": "user", "content": [{"type": "image"} for _ in range(len(images))] + [{"type": "text", "text": text}] }] prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True) inputs = processor(text=prompt, images=[images], return_tensors="pt").to("cuda") yield progress_bar_html("Processing with SmolDocling") streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False) generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192) thread = Thread(target=model.generate, kwargs=generation_args) thread.start() yield "..." buffer = "" full_output = "" for new_text in streamer: full_output += new_text buffer += html.escape(new_text) yield buffer cleaned_output = full_output.replace("", "").strip() if cleaned_output: doctag_output = cleaned_output yield cleaned_output if any(tag in doctag_output for tag in ["", "", "", "", ""]): doc = DoclingDocument(name="Document") if "" in doctag_output: doctag_output = doctag_output.replace("", "").replace("", "") doctag_output = re.sub(r'()(?!.*)<[^>]+>', r'\1', doctag_output) doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctag_output], images) doc.load_from_doctags(doctags_doc) yield f"**MD Output:**\n\n{doc.export_to_markdown()}" return else: # -------- Text-Only Inference Branch -------- if text == "": gr.Error("Please input a query and optionally image(s).") resulting_messages = [{ "role": "user", "content": [{"type": "text", "text": text}] }] prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True) inputs = processor(text=prompt, return_tensors="pt").to("cuda") yield progress_bar_html("Processing text with SmolDocling") streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False) generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192) thread = Thread(target=model.generate, kwargs=generation_args) thread.start() yield "..." buffer = "" full_output = "" for new_text in streamer: full_output += new_text buffer += html.escape(new_text) yield buffer cleaned_output = full_output.replace("", "").strip() if cleaned_output: yield cleaned_output return # --------------------------- # Gradio Interface Setup # --------------------------- examples = [ [{"text": "Convert this page to docling.", "files": ["example_images/2d0fbcc50e88065a040a537b717620e964fb4453314b71d83f3ed3425addcef6.png"]}], [{"text": "Convert this table to OTSL.", "files": ["example_images/image-2.jpg"]}], [{"text": "Convert code to text.", "files": ["example_images/7666.jpg"]}], [{"text": "Convert formula to latex.", "files": ["example_images/2433.jpg"]}], [{"text": "Convert chart to OTSL.", "files": ["example_images/06236926002285.png"]}], [{"text": "OCR the text in location [47, 531, 167, 565]", "files": ["example_images/s2w_example.png"]}], [{"text": "Extract all section header elements on the page.", "files": ["example_images/paper_3.png"]}], [{"text": "Identify element at location [123, 413, 1059, 1061]", "files": ["example_images/redhat.png"]}], [{"text": "Convert this page to docling.", "files": ["example_images/gazette_de_france.jpg"]}], # Example video file (if available) [{"text": "Describe the events in this video.", "files": ["example_videos/sample_video.mp4"]}], ] demo = gr.ChatInterface( fn=model_inference, title="SmolDocling-256M: Ultra-compact VLM for Document Conversion 💫", description=( "Play with [ds4sd/SmolDocling-256M-preview](https://huggingface.co/ds4sd/SmolDocling-256M-preview) in this demo. " "Upload an image, video, and text query or try one of the examples. Each chat starts a new conversation." ), examples=examples, textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True, cache_examples=False ) if __name__ == "__main__": demo.launch(debug=True)