Spaces:
Running
Running
import gradio as gr | |
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter | |
from langchain.schema import Document | |
from typing import List, Dict | |
import logging | |
import re | |
from pathlib import Path | |
import requests | |
import base64 | |
import io | |
from PIL import Image | |
from datasets import Dataset | |
from huggingface_hub import HfApi | |
import os | |
from mistralai import Mistral | |
# Logging: configure the root logger at INFO level and get this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Mistral OCR client. The API key is read from the environment and checked at
# import time, so the app refuses to load without it instead of failing on the
# first OCR request. An empty value is treated the same as a missing one.
api_key = os.getenv("MISTRAL_API_KEY") or None
if api_key is None:
    raise ValueError("MISTRAL_API_KEY environment variable not set")
client = Mistral(api_key=api_key)
def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 UTF-8 string.

    This function never raises. On failure it returns an error message
    instead of base64 data:

    - ``"Error: The file was not found."`` if the file does not exist;
    - ``"Error: <exception text>"`` for any other exception.

    Callers must check for the ``"Error:"`` prefix before using the result.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except FileNotFoundError:
        return "Error: The file was not found."
    except Exception as exc:
        return f"Error: {exc}"
def replace_images_in_markdown(markdown_str: str, images_dict: Dict[str, str]) -> str:
    """Inline image data into markdown image references.

    Every reference of the exact form ``![<id>](<id>)`` whose id is a key of
    ``images_dict`` is rewritten to ``![<id>](<value>)``. All other text is
    left unchanged.

    Args:
        markdown_str: Markdown text, e.g. one OCR page. Assumes the OCR output
            refers to extracted images as ``![<id>](<id>)`` -- verify against
            the Mistral OCR response format.
        images_dict: Image id -> data to use as the new link target
            (presumably the ``image_base64`` value from the OCR response; it
            only renders if it is a ``data:`` URI -- TODO confirm).

    Returns:
        The markdown with the matching references replaced.
    """
    for img_name, base64_str in images_dict.items():
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str
def get_combined_markdown(ocr_response) -> tuple[str, str, Dict[str, str]]:
    """Merge the per-page markdown of an OCR response into single documents.

    Args:
        ocr_response: Object exposing ``.pages``. Each page provides
            ``.markdown`` and ``.images``, whose items have ``.id`` and
            ``.image_base64``.

    Returns:
        A 3-tuple ``(combined, raw, image_data)``:

        - combined: pages joined by blank lines, with each page's image
          references rewritten by ``replace_images_in_markdown``;
        - raw: the untouched page markdown, joined the same way;
        - image_data: every image id seen, mapped to its ``image_base64``.
          If an id repeats across pages, the later page's value wins.
    """
    markdowns: List[str] = []
    raw_markdowns: List[str] = []
    image_data: Dict[str, str] = {}  # accumulated across all pages
    for page in ocr_response.pages:
        for img in page.images:
            image_data[img.id] = img.image_base64
        # The dict is cumulative, so references to images from earlier pages
        # are resolved as well as the current page's.
        markdowns.append(replace_images_in_markdown(page.markdown, image_data))
        raw_markdowns.append(page.markdown)
    return "\n\n".join(markdowns), "\n\n".join(raw_markdowns), image_data
def _ocr_pdf(path: str):
    """Upload a PDF to Mistral, OCR it through a signed URL, then delete the upload."""
    # Close the local handle as soon as the upload call has consumed it.
    with open(path, "rb") as pdf_file:
        uploaded_pdf = client.files.upload(
            file={
                # Send only the base name, not the full local (temp) path.
                "file_name": Path(path).name,
                "content": pdf_file,
            },
            purpose="ocr",
        )
    try:
        signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
        return client.ocr.process(
            model="mistral-ocr-latest",
            document={
                "type": "document_url",
                "document_url": signed_url.url,
            },
            include_image_base64=True,
        )
    finally:
        # Best-effort cleanup: this runs even when OCR fails, and a failed
        # delete must not discard an OCR result that was already obtained.
        try:
            client.files.delete(file_id=uploaded_pdf.id)
        except Exception:
            logger.warning("Could not delete uploaded file %s", uploaded_pdf.id, exc_info=True)


def _ocr_image(path: str, mime_type: str):
    """OCR a single image by sending it inline as a base64 data URL."""
    base64_image = encode_image(path)
    # encode_image reports failures as an "Error: ..." string instead of
    # raising. Base64 output never contains ":", so this prefix is unambiguous.
    if base64_image.startswith("Error:"):
        raise OSError(base64_image.removeprefix("Error: "))
    return client.ocr.process(
        model="mistral-ocr-latest",
        document={
            "type": "image_url",
            "image_url": f"data:{mime_type};base64,{base64_image}",
        },
        include_image_base64=True,
    )


def perform_ocr_file(file):
    """Run Mistral OCR on an uploaded PDF or image (png, jpg, jpeg).

    Args:
        file: The upload to process. Either an object with a ``.name``
            attribute holding the file path (e.g. a Gradio upload), or a
            path given as ``str`` / ``os.PathLike``. ``None`` means nothing
            was uploaded.

    Returns:
        On success, ``(combined_markdown, raw_markdown, image_data)`` as
        produced by ``get_combined_markdown``. On failure (no file,
        unsupported extension, or any exception), ``(message, "", {})``.
        Nothing is raised to the caller.
    """
    if file is None:
        return "Error: no file was uploaded.", "", {}
    try:
        # Plain strings and path objects are paths themselves. Check them
        # first because Path objects also have a .name (just the base name).
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name
        lowered = path.lower()
        if lowered.endswith(".pdf"):
            ocr_response = _ocr_pdf(path)
        elif lowered.endswith((".png", ".jpg", ".jpeg")):
            # Declare the real media type instead of labelling every image JPEG.
            mime_type = "image/png" if lowered.endswith(".png") else "image/jpeg"
            ocr_response = _ocr_image(path, mime_type)
        else:
            return "Unsupported file type. Please provide a PDF or an image (png, jpeg, jpg).", "", {}
        return get_combined_markdown(ocr_response)
    except Exception as e:
        return f"Error during OCR: {str(e)}", "", {}
# Matches the "![alt](" opening of a markdown image and captures the alt text.
# Compiled once because it runs for every chunk.
_IMAGE_REF_PATTERN = re.compile(r"!\[(.*?)\]\(")


def extract_image_names_from_markdown(markdown_text: str) -> List[str]:
    """Return the alt text of every markdown image in ``markdown_text``.

    The alt text is used as the image name, i.e. the key looked up in the
    OCR ``image_data`` mapping. Plain links (``[text](url)`` with no
    leading ``!``) are ignored.

    Args:
        markdown_text: Markdown to scan.

    Returns:
        The names in order of appearance, duplicates included. The list is
        empty if there are no images.
    """
    return _IMAGE_REF_PATTERN.findall(markdown_text)
def chunk_markdown(
    markdown_text: str,
    image_data: Dict[str, str],
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    strip_headers: bool = True
) -> List[Document]:
    """Split markdown into header-scoped chunks and attach referenced images.

    1. Split on ``#``, ``##`` and ``###`` headings. ``"Header 1"``,
       ``"Header 2"`` and ``"Header 3"`` are the metadata names given to
       those levels.
    2. If ``chunk_size`` is positive, re-split any section longer than
       ``chunk_size`` characters with a recursive character splitter.
       It uses ``chunk_overlap`` characters of overlap and records
       ``start_index``. Shorter sections are kept as they are.
    3. Set ``metadata["images"]`` on every chunk. It maps each image name
       found in the chunk text to ``image_data[name]``, or to ``None`` when
       the name is not a key.

    Args:
        markdown_text: The markdown to split.
        image_data: Image name -> image data, used to fill chunk metadata.
        chunk_size: Maximum characters per chunk; ``0`` (or less) disables step 2.
        chunk_overlap: Overlap between character-split pieces.
        strip_headers: Passed to the header splitter's ``strip_headers``.

    Returns:
        The list of chunk ``Document`` objects.

    Raises:
        Any exception raised while splitting or annotating. It is logged
        first, then re-raised unchanged.
    """
    header_levels = [("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")]
    try:
        header_splitter = MarkdownHeaderTextSplitter(
            headers_to_split_on=header_levels,
            strip_headers=strip_headers,
        )
        logger.info("Splitting markdown by headers")
        documents = header_splitter.split_text(markdown_text)

        if chunk_size > 0:
            char_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                separators=["\n\n", "\n", ".", " ", ""],
                keep_separator=True,
                add_start_index=True,
            )
            logger.info(f"Applying character-level splitting with chunk_size={chunk_size}")
            resized = []
            for doc in documents:
                if len(doc.page_content) <= chunk_size:
                    resized.append(doc)
                else:
                    resized.extend(char_splitter.split_documents([doc]))
            documents = resized

        for doc in documents:
            referenced = extract_image_names_from_markdown(doc.page_content)
            doc.metadata["images"] = {name: image_data.get(name) for name in referenced}

        logger.info(f"Created {len(documents)} chunks")
        return documents
    except Exception as e:
        logger.error(f"Error processing markdown: {str(e)}")
        raise
def text_to_base64_dummy(text: str, chunk_index: int):
    """Return a base64-encoded 200x200 solid-white PNG as a UTF-8 string.

    Placeholder helper: ``text`` and ``chunk_index`` are accepted but
    ignored, so every call returns the same image. It appears to be unused
    by the rest of this script.
    """
    canvas = Image.new('RGB', (200, 200), color='white')
    with io.BytesIO() as png_buffer:
        canvas.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def process_file_and_save(file, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name):
    """OCR an uploaded file, chunk its markdown, and push the chunks as a HF dataset.

    Args:
        file: The upload, passed unchanged to ``perform_ocr_file``.
        chunk_size: Max characters per chunk; 0 disables character-level
            splitting. Slider values may be floats, so this is cast to int.
        chunk_overlap: Overlap between character-split chunks (cast to int).
        strip_headers: Whether header lines are removed from chunk content.
        hf_token: Hugging Face token used to create the repo and push to it.
        repo_name: Target dataset repo id, e.g. ``"username/dataset-name"``.

    Returns:
        A status message for the UI. Every failure is reported as a message
        rather than raised.

    The dataset has columns ``chunk_id``, ``content`` and ``metadata``.
    Chunk text keeps image references by name, and the image data itself is
    stored under ``metadata["images"]``.
    """
    try:
        combined_markdown, raw_markdown, image_data = perform_ocr_file(file)
        # perform_ocr_file reports every failure (including unsupported file
        # types) as a message with empty raw markdown. Looking for the word
        # "Error" in the text would also reject documents that mention errors.
        if not raw_markdown.strip():
            return combined_markdown.strip() or "No text was extracted from the file."

        # Chunk the raw markdown, which references images by name. The
        # combined markdown inlines the image data, and the character splitter
        # would cut those long data strings into meaningless text chunks. The
        # images are attached to each chunk's metadata by chunk_markdown.
        chunks = chunk_markdown(
            raw_markdown, image_data, int(chunk_size), int(chunk_overlap), strip_headers
        )

        dataset = Dataset.from_dict({
            "chunk_id": list(range(len(chunks))),
            "content": [chunk.page_content for chunk in chunks],
            "metadata": [chunk.metadata for chunk in chunks],
        })
        api = HfApi()
        api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
        dataset.push_to_hub(repo_name, token=hf_token)
        return f"Dataset created with {len(chunks)} chunks and saved to Hugging Face at {repo_name}"
    except Exception as e:
        # Keep the traceback in the logs; the UI only receives the message.
        logger.exception("Failed to process and save file")
        return f"Error: {str(e)}"
# Gradio interface: collects the upload, chunking options and Hugging Face
# credentials, and shows the status message returned by process_file_and_save.
with gr.Blocks(title="PDF/Image OCR, Markdown Chunking, and Dataset Creator") as demo:
    gr.Markdown("# PDF/Image OCR, Markdown Chunking, and Dataset Creator")
    gr.Markdown("Upload a PDF or image, extract text/images with Mistral OCR, chunk the markdown by headers, and save to Hugging Face.")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload PDF or Image")
            chunk_size = gr.Slider(0, 2000, value=1000, step=100, label="Max Chunk Size (0 to disable)")
            chunk_overlap = gr.Slider(0, 500, value=200, step=50, label="Chunk Overlap")
            strip_headers = gr.Checkbox(label="Strip Headers from Content", value=True)
            # Password-type textbox so the token is masked in the UI.
            hf_token = gr.Textbox(label="Hugging Face Token", type="password")
            repo_name = gr.Textbox(label="Hugging Face Repository Name (e.g., username/dataset-name)")
            submit_btn = gr.Button("Process and Save")
        with gr.Column():
            output = gr.Textbox(label="Result")
    # The input order must match process_file_and_save's parameter order.
    submit_btn.click(
        fn=process_file_and_save,
        inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
        outputs=output,
    )

# Launch only when run as a script, so importing this module (e.g. to reuse
# process_file_and_save) does not start a server.
# NOTE(review): share=True asks Gradio for a public link when run locally,
# which would let anyone with the URL spend this app's Mistral API key --
# confirm this is intended.
if __name__ == "__main__":
    demo.launch(share=True)