# Doc-VLMs-OCR / app.py
# Hugging Face Space by prithivMLmods — commit 4699b28 ("Update app.py"), 19 kB.
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
Qwen2VLForConditionalGeneration,
VisionEncoderDecoderModel,
AutoModelForVision2Seq,
AutoProcessor,
TextIteratorStreamer,
)
from transformers.image_utils import load_image
from docling_core.types.doc import DoclingDocument, DocTagsDocument
import re
import ast
import html
# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048  # upper bound of the "Max new tokens" slider
DEFAULT_MAX_NEW_TOKENS = 1024  # initial value of that slider
# Overridable through the environment; appears unused in the code below.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the first CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def _load_processor_and_model(model_id, model_cls):
    """Load a Hub checkpoint's processor and its float16 model on `device`.

    Both are fetched with ``trust_remote_code=True``; the model is moved to
    the module-level ``device`` and switched to eval mode.

    NOTE(review): float16 weights are requested even when ``device`` is the
    CPU, where half-precision kernels may be slow or unsupported — confirm
    this app only runs with a GPU.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = model_cls.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    return processor, model.to(device).eval()


# olmOCR-7B-0225-preview (loaded with the Qwen2-VL model class)
MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
processor_m, model_m = _load_processor_and_model(MODEL_ID_M, Qwen2VLForConditionalGeneration)

# ByteDance's Dolphin (vision encoder-decoder)
MODEL_ID_K = "ByteDance/Dolphin"
processor_k, model_k = _load_processor_and_model(MODEL_ID_K, VisionEncoderDecoderModel)

# SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x, model_x = _load_processor_and_model(MODEL_ID_X, AutoModelForVision2Seq)
# Preprocessing functions for SmolDocling-256M
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Return an RGB copy of *image* surrounded by a solid border.

    The left/right border width and the top/bottom border height are drawn
    independently from ``random.uniform(min_percent, max_percent)`` as a
    fraction of the image width and height (truncated to whole pixels).
    With the default bounds (both 0.1) the border is always 10%; the random
    draw only matters when a caller widens the range.

    The border is filled with the colour of the top-left pixel, presumably
    so it matches the page background.
    """
    rgb = image.convert("RGB")
    w, h = rgb.size
    frac_x = random.uniform(min_percent, max_percent)
    frac_y = random.uniform(min_percent, max_percent)
    dx, dy = int(w * frac_x), int(h * frac_y)
    background = rgb.getpixel((0, 0))
    return ImageOps.expand(rgb, border=(dx, dy, dx, dy), fill=background)
def normalize_values(text, target_max=500):
    """Rewrite bracketed number lists in *text* as ``<loc_N>`` tags.

    Every ``[...]`` run made only of digits, dots, whitespace and commas is
    parsed as a Python literal and rescaled so its largest value maps to
    *target_max*; each rescaled (rounded) value is emitted as ``<loc_N>``.
    For example, ``"[10, 20, 40, 50]"`` becomes
    ``"<loc_100><loc_200><loc_400><loc_500>"`` with the default target.

    Bracketed runs that are not valid number literals (e.g. ``"[1.2.3]"``
    or ``"[01]"``) are left exactly as written instead of raising.

    Args:
        text: Prompt text that may contain bracketed coordinate lists.
        target_max: Value the largest number in each list is scaled to.

    Returns:
        *text* with every parseable bracketed list replaced by tags.
    """

    def normalize_list(values):
        # An empty or all-zero list has no positive maximum; dividing by 0
        # would raise ZeroDivisionError, so fall back to 1 (emits zeros).
        max_value = max(values, default=0) or 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        try:
            num_list = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError):
            # The character class also admits non-literals such as
            # "[1.2.3]", "[,]" or "[01]"; leave those untouched.
            return match.group(0)
        return "".join(f"<loc_{num}>" for num in normalize_list(num_list))

    pattern = r"\[([\d\.\s,]+)\]"
    return re.sub(pattern, process_match, text)
def downsample_video(video_path, num_frames=10):
    """Sample up to *num_frames* evenly spaced frames from a video file.

    Args:
        video_path: Path to a video file opened with ``cv2.VideoCapture``.
        num_frames: Number of evenly spaced positions to sample. Defaults to
            10; values below 1 return an empty list.

    Returns:
        A list of ``(PIL.Image.Image, timestamp_seconds)`` tuples in frame
        order. Frames OpenCV fails to decode are skipped, so the list can be
        shorter than *num_frames* (or empty). Clips with fewer frames than
        *num_frames* yield each frame once. Timestamps are rounded to two
        decimals and are 0.0 when the container reports no frame rate.
    """
    if num_frames <= 0:
        return []
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        # A reported frame count of 0 (or less) would otherwise produce a
        # negative index; clamp so frame 0 is still attempted.
        last_index = max(total_frames - 1, 0)
        # For clips shorter than num_frames linspace repeats indices;
        # np.unique drops the repeats and keeps them in ascending order.
        frame_indices = np.unique(np.linspace(0, last_index, num_frames, dtype=int))
        frames = []
        for index in frame_indices:
            index = int(index)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # A 0 fps report would otherwise raise ZeroDivisionError.
            timestamp = round(index / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
        return frames
    finally:
        # Release the capture handle even if decoding raises.
        vidcap.release()
# Dolphin-specific functions
def model_chat(prompt, image, is_batch=False):
    """Run the Dolphin model (``model_k`` / ``processor_k``) on image(s).

    Args:
        prompt: Instruction text. In batch mode this may be a list with one
            prompt per image; a single string is repeated for every image.
        image: One image, or a list of images when ``is_batch`` is true.
        is_batch: Selects the list-in / list-out form.

    Returns:
        The decoded answer for a single input, or a list of answers (one per
        image) in batch mode. Each answer has its prompt text, ``<pad>`` and
        ``</s>`` removed and surrounding whitespace stripped.
    """
    tokenizer = processor_k.tokenizer
    target = "cuda" if torch.cuda.is_available() else "cpu"

    if is_batch:
        image_list = image
        prompt_list = prompt if isinstance(prompt, list) else [prompt] * len(image_list)
    else:
        image_list, prompt_list = [image], [prompt]

    # Encoder inputs; cast to half to match the float16 model weights.
    encoded_images = processor_k(image_list, return_tensors="pt", padding=True).to(target)
    pixel_values = encoded_images.pixel_values.half()

    # Decoder prompts spell out "<s>" themselves, so the tokenizer is told not
    # to add its own special tokens.
    wrapped_prompts = [f"<s>{text} <Answer/>" for text in prompt_list]
    decoder_inputs = tokenizer(
        wrapped_prompts,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    ).to(target)

    generated = model_k.generate(
        pixel_values=pixel_values,
        decoder_input_ids=decoder_inputs.input_ids,
        decoder_attention_mask=decoder_inputs.attention_mask,
        min_length=1,
        max_length=4096,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        do_sample=False,
        num_beams=1,
        repetition_penalty=1.1,
    )

    decoded = tokenizer.batch_decode(generated.sequences, skip_special_tokens=False)
    # Decoding keeps special tokens, so strip the echoed prompt and markers.
    answers = [
        text.replace(wrapped_prompts[idx], "").replace("<pad>", "").replace("</s>", "").strip()
        for idx, text in enumerate(decoded)
    ]
    if is_batch:
        return answers
    return answers[0]
def process_element_batch(elements, prompt, max_batch_size=16):
    """Recognise cropped page elements with Dolphin, a batch at a time.

    Args:
        elements: Dicts with at least ``"crop"``, ``"label"``, ``"bbox"`` and
            ``"reading_order"`` keys (as built by ``process_elements``).
        prompt: Instruction sent alongside every crop.
        max_batch_size: Maximum crops per ``model_chat`` call; values below
            1 are treated as 1.

    Returns:
        One dict per input element, in input order, with ``"label"``,
        ``"bbox"`` and ``"reading_order"`` copied over and ``"text"`` set to
        the stripped model answer. Empty input returns ``[]``.
    """
    if not elements:
        # batch_size would be 0 and range() rejects a zero step.
        return []
    batch_size = max(1, min(len(elements), max_batch_size))
    results = []
    for start in range(0, len(elements), batch_size):
        batch_elements = elements[start:start + batch_size]
        crops = [elem["crop"] for elem in batch_elements]
        answers = model_chat([prompt] * len(crops), crops, is_batch=True)
        results.extend(
            {
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": answer.strip(),
                "reading_order": elem["reading_order"],
            }
            for elem, answer in zip(batch_elements, answers)
        )
    return results
def process_elements(layout_results, image):
    """Turn Dolphin's layout answer into recognised page elements.

    *layout_results* is parsed with ``ast.literal_eval`` and is expected to
    be a sequence of ``(bbox, label)`` pairs, where ``bbox`` holds
    ``x1, y1, x2, y2`` crop coordinates for *image*. Each non-empty crop is
    routed by label:

    - ``"text"`` and ``"table"`` crops are recognised in batches via
      ``process_element_batch``.
    - ``"figure"`` crops become a ``"[Figure]"`` placeholder.
    - Any other label is dropped (it still consumes a reading-order slot).

    Malformed entries are logged and skipped.

    NOTE(review): this assumes the reading-order answer is a Python literal.
    If the model emits another textual format, parsing fails and the page
    yields no elements — confirm against real model output.

    Args:
        layout_results: Model output string for the reading-order prompt.
        image: PIL image the bounding boxes refer to.

    Returns:
        Element dicts (``label``, ``bbox``, ``text``, ``reading_order``)
        sorted by ``reading_order``; empty if the layout cannot be parsed.
    """
    try:
        elements = ast.literal_eval(layout_results)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Unparseable layout answer: treat the page as having no elements.
        elements = []
    if not isinstance(elements, (list, tuple)):
        # A valid literal that is not a sequence (e.g. "42") cannot be
        # iterated as (bbox, label) pairs.
        elements = []

    text_elements = []
    table_elements = []
    figure_results = []
    reading_order = 0
    for element in elements:
        try:
            # Unpack inside the try so one malformed entry is skipped instead
            # of aborting the whole page.
            bbox, label = element
            x1, y1, x2, y2 = map(int, bbox)
            cropped = image.crop((x1, y1, x2, y2))
            if cropped.size[0] > 0 and cropped.size[1] > 0:
                element_info = {
                    "crop": cropped,
                    "label": label,
                    "bbox": [x1, y1, x2, y2],
                    "reading_order": reading_order,
                }
                if label == "text":
                    text_elements.append(element_info)
                elif label == "table":
                    table_elements.append(element_info)
                elif label == "figure":
                    # Figures are not recognised; emit a placeholder instead.
                    figure_results.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "text": "[Figure]",
                        "reading_order": reading_order,
                    })
                # Only non-empty crops consume a reading-order slot.
                reading_order += 1
        except Exception as e:
            # Best effort: log and skip this element, keep the rest of the page.
            print(f"Error processing element: {e}")

    recognition_results = list(figure_results)
    if text_elements:
        recognition_results.extend(
            process_element_batch(text_elements, "Read text in the image.")
        )
    if table_elements:
        recognition_results.extend(
            process_element_batch(table_elements, "Parse the table in the image.")
        )
    # Restore page order across the figure/text/table groups.
    recognition_results.sort(key=lambda x: x["reading_order"])
    return recognition_results
def generate_markdown(recognition_results):
    """Render recognised Dolphin elements as one Markdown string.

    Text and figure elements contribute their ``"text"`` verbatim. Tables
    are preceded by a bold ``**Table:**`` line. Every rendered element is
    followed by a blank line, and elements with any other label are skipped.
    Leading and trailing whitespace of the combined result is stripped.
    """
    templates = {
        "text": "{}\n\n",
        "table": "**Table:**\n{}\n\n",
        "figure": "{}\n\n",
    }
    parts = [
        templates[item["label"]].format(item["text"])
        for item in recognition_results
        if item["label"] in templates
    ]
    return "".join(parts).strip()
def process_image_with_dolphin(image):
    """Convert one page image to Markdown using the Dolphin pipeline.

    The model is first asked for the page's reading order. ``process_elements``
    then crops and recognises each element, and ``generate_markdown`` renders
    the result.
    """
    layout = model_chat("Parse the reading order of this document.", image)
    return generate_markdown(process_elements(layout, image))
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream the selected model's answer for one uploaded image.

    Used as a Gradio event handler: each ``yield`` replaces the text shown in
    the output box.

    Args:
        model_name: ``"ByteDance-s-Dolphin"``, ``"olmOCR-7B-0225-preview"``
            or ``"SmolDocling-256M-preview"``. Any other value yields an error
            message.
        text: The user's query.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Passed to ``model.generate`` on the olmOCR / SmolDocling paths.
            The Dolphin path ignores them (``model_chat`` uses fixed settings).

    Yields:
        A status message, or the growing partial answer while tokens stream
        in. For SmolDocling, additionally a final cleaned answer or a
        Markdown rendering of the DocTags output.

    Raises:
        Exception: Whatever ``model.generate`` raised in the background
            thread, re-raised after the stream has been closed.

    NOTE(review): ``temperature``/``top_p``/``top_k`` are passed without
    ``do_sample=True``. Whether they have any effect depends on each model's
    generation config — confirm.

    NOTE(review): the streamer decodes with ``skip_special_tokens=True``. If
    SmolDocling registers its DocTags (``<doctag>``, ``<otsl>``, ...) as
    special tokens, they are stripped and the DocTags-to-Markdown branch
    below never runs — confirm.
    """
    if model_name == "ByteDance-s-Dolphin":
        if image is None:
            yield "Please upload an image."
            return
        yield process_image_with_dolphin(image)
        return

    if model_name == "olmOCR-7B-0225-preview":
        processor, model = processor_m, model_m
    elif model_name == "SmolDocling-256M-preview":
        processor, model = processor_x, model_x
    else:
        yield "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image."
        return

    images = [image]
    if model_name == "SmolDocling-256M-preview":
        # Prompt-dependent preprocessing for SmolDocling tasks.
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
            text = normalize_values(text, target_max=500)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [
                {"type": "text", "text": text}
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    generation_errors = []

    def _generate():
        # If generate() raises, it never signals end-of-stream, and the loop
        # below would block forever on the streamer. Close the stream and
        # hand the exception to the consumer instead.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    buffer = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer
    thread.join()
    if generation_errors:
        raise generation_errors[0]

    if model_name == "SmolDocling-256M-preview":
        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
            if "<chart>" in cleaned_output:
                # Charts are rewritten as OTSL tables before conversion.
                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
            # Drop the tag directly following the last "<loc_500>" on a line.
            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            yield f"**MD Output:**\n\n{markdown_output}"
        else:
            yield cleaned_output
@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream the selected model's answer for frames sampled from a video.

    Used as a Gradio event handler: each ``yield`` replaces the text shown in
    the output box. Frames come from ``downsample_video``.

    Args:
        model_name: ``"ByteDance-s-Dolphin"``, ``"olmOCR-7B-0225-preview"``
            or ``"SmolDocling-256M-preview"``. Any other value yields an error
            message.
        text: The user's query.
        video_path: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Passed to ``model.generate`` on the olmOCR / SmolDocling paths.
            The Dolphin path ignores them.

    Yields:
        A status message, or the growing partial answer while tokens stream
        in. Dolphin converts each sampled frame separately and joins the
        Markdown with blank lines. SmolDocling additionally yields a final
        cleaned answer or Markdown rendering.

    Raises:
        Exception: Whatever ``model.generate`` raised in the background
            thread, re-raised after the stream has been closed.

    NOTE(review): ``temperature``/``top_p``/``top_k`` are passed without
    ``do_sample=True``; confirm they take effect.

    NOTE(review): for SmolDocling, a single DocTags string is paired with all
    sampled frames in ``from_doctags_and_image_pairs``; confirm how docling
    handles the extra images.
    """
    if model_name == "ByteDance-s-Dolphin":
        if video_path is None:
            yield "Please upload a video."
            return
        frames = downsample_video(video_path)
        if not frames:
            yield "Could not read any frames from the video."
            return
        # Each sampled frame is converted as an independent page.
        yield "\n\n".join(process_image_with_dolphin(frame) for frame, _ in frames)
        return

    if model_name == "olmOCR-7B-0225-preview":
        processor, model = processor_m, model_m
    elif model_name == "SmolDocling-256M-preview":
        processor, model = processor_x, model_x
    else:
        yield "Invalid model selected."
        return
    if video_path is None:
        yield "Please upload a video."
        return

    frames = downsample_video(video_path)
    images = [frame for frame, _ in frames]
    if not images:
        # No decodable frames: there is nothing to send to the model.
        yield "Could not read any frames from the video."
        return
    if model_name == "SmolDocling-256M-preview":
        # Prompt-dependent preprocessing for SmolDocling tasks.
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
            text = normalize_values(text, target_max=500)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [
                {"type": "text", "text": text}
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    generation_errors = []

    def _generate():
        # If generate() raises, it never signals end-of-stream, and the loop
        # below would block forever on the streamer. Close the stream and
        # hand the exception to the consumer instead.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    buffer = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer
    thread.join()
    if generation_errors:
        raise generation_errors[0]

    if model_name == "SmolDocling-256M-preview":
        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
            if "<chart>" in cleaned_output:
                # Charts are rewritten as OTSL tables before conversion.
                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
            # Drop the tag directly following the last "<loc_500>" on a line.
            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            yield f"**MD Output:**\n\n{markdown_output}"
        else:
            yield cleaned_output
# Define examples for image and video inference: each entry is
# [query text, example file path], matching the inputs of its tab.
image_examples = [
    [query, path]
    for query, path in (
        ("Convert this page to docling", "images/1.png"),
        ("OCR the image", "images/2.jpg"),
        ("Convert this page to docling", "images/3.png"),
    )
]
video_examples = [
    [query, path]
    for query, path in (
        ("Explain the ad in detail", "example/1.mp4"),
        ("Identify the main actions in the coca cola ad...", "example/2.mp4"),
    )
]
# Colours for buttons created with elem_classes="submit-btn".
css = """
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
"""
# Create the Gradio Interface
# NOTE(review): nesting below was reconstructed from a copy with indentation
# stripped; confirm the Accordion placement against the deployed layout.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        # Left column: image/video input tabs plus generation settings.
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    # type="pil" makes the handler receive a PIL image.
                    image_upload = gr.Image(type="pil", label="Image")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload]
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload]
                    )
            # These sliders are passed to both generate_image and generate_video.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        # Right column: shared output box and the model selector.
        with gr.Column():
            output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
            model_choice = gr.Radio(
                choices=["olmOCR-7B-0225-preview", "SmolDocling-256M-preview", "ByteDance-s-Dolphin"],
                label="Select Model",
                value="olmOCR-7B-0225-preview"
            )
    # Both tabs stream into the same output box; the input order matches the
    # handlers' positional parameters.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
if __name__ == "__main__":
    # Cap the request queue at 30 pending events, then launch with the MCP
    # server enabled, server-side rendering off and errors shown in the UI.
    queued_app = demo.queue(max_size=30)
    queued_app.launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)