import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from typing import List
import logging
import re
import base64
import io
from PIL import Image
from datasets import Dataset
from huggingface_hub import HfApi
import os
from mistralai import Mistral

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Mistral OCR setup
api_key = os.environ.get("MISTRAL_API_KEY")
if not api_key:
    raise ValueError("MISTRAL_API_KEY environment variable not set")
client = Mistral(api_key=api_key)

# Encode an image file to a base64 string
def encode_image(image_path):
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except FileNotFoundError:
        return "Error: The file was not found."
    except Exception as e:
        return f"Error: {e}"

# Replace image references in markdown with inline base64 data URIs,
# e.g. "![img-0.jpeg](img-0.jpeg)" -> "![img-0.jpeg](data:image/jpeg;base64,...)"
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    for img_name, base64_str in images_dict.items():
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str

# Combine per-page markdown from the OCR response, both with images inlined
# and in raw form
def get_combined_markdown(ocr_response) -> tuple:
    markdowns = []
    raw_markdowns = []
    for page in ocr_response.pages:
        image_data = {img.id: img.image_base64 for img in page.images}
        markdowns.append(replace_images_in_markdown(page.markdown, image_data))
        raw_markdowns.append(page.markdown)
    return "\n\n".join(markdowns), "\n\n".join(raw_markdowns)

# Perform OCR on an uploaded file (PDF or image)
def perform_ocr_file(file):
    try:
        if file.name.lower().endswith(".pdf"):
            with open(file.name, "rb") as f:
                uploaded_pdf = client.files.upload(
                    file={
                        "file_name": file.name,
                        "content": f,
                    },
                    purpose="ocr",
                )
            signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
            ocr_response = client.ocr.process(
                model="mistral-ocr-latest",
                document={
                    "type": "document_url",
                    "document_url": signed_url.url,
                },
                include_image_base64=True,
            )
            client.files.delete(file_id=uploaded_pdf.id)
        elif file.name.lower().endswith((".png", ".jpg", ".jpeg")):
            base64_image = encode_image(file.name)
            # encode_image reports failures as "Error: ..." strings, so check
            # before embedding the result into a data URI
            if base64_image.startswith("Error"):
                return f"Error during OCR: {base64_image}", ""
            # Note: the MIME type is hardcoded to jpeg; image/png would be
            # more precise for PNG uploads
            ocr_response = client.ocr.process(
                model="mistral-ocr-latest",
                document={
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}",
                },
                include_image_base64=True,
            )
        else:
            return "Error: Unsupported file type. Please provide a PDF or an image (png, jpeg, jpg).", ""
        combined_markdown, raw_markdown = get_combined_markdown(ocr_response)
        return combined_markdown, raw_markdown
    except Exception as e:
        return f"Error during OCR: {str(e)}", ""
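# Usage sketch (illustrative only, not executed by the app): perform_ocr_file
# reads only file.name, so any object with a .name path attribute works.
# "sample.pdf" is a hypothetical path.
#
#   from types import SimpleNamespace
#   md, raw = perform_ocr_file(SimpleNamespace(name="sample.pdf"))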
# Chunk markdown text, optionally keeping numbered list items intact
def chunk_markdown(
    markdown_text: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    preserve_numbering: bool = True,
) -> List[Document]:
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if chunk_overlap < 0:
        raise ValueError("chunk_overlap cannot be negative")
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be less than chunk_size")
    try:
        document = Document(page_content=markdown_text, metadata={"source": "ocr_output"})
        # When preserve_numbering is on, is_separator_regex=True makes every
        # separator a regex, so literal separators such as "." must be escaped
        separators = (
            [r"\n\d+\.\s+", "\n\n", "\n", r"\.", " ", ""]
            if preserve_numbering
            else ["\n\n", "\n", ".", " ", ""]
        )
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=separators,
            keep_separator=True,
            add_start_index=True,
            is_separator_regex=preserve_numbering,
        )
        logger.info("Splitting markdown text into chunks")
        chunks = text_splitter.split_documents([document])
        if preserve_numbering:
            # Merge chunks so each numbered item ("1.", "12.", ...) starts a
            # new chunk and its continuation text stays attached to it
            merged_chunks = []
            current_chunk = None
            for chunk in chunks:
                content = chunk.page_content.strip()
                if current_chunk is None:
                    current_chunk = chunk
                elif re.match(r"\d+\.\s", content):
                    merged_chunks.append(current_chunk)
                    current_chunk = chunk
                else:
                    current_chunk.page_content += "\n" + content
                    current_chunk.metadata["end_index"] = (
                        chunk.metadata["start_index"] + len(content)
                    )
            if current_chunk:
                merged_chunks.append(current_chunk)
            chunks = merged_chunks
        logger.info(f"Created {len(chunks)} chunks")
        return chunks
    except Exception as e:
        logger.error(f"Error processing markdown: {str(e)}")
        raise

# Placeholder image generation for chunks without an embedded image
def text_to_base64_dummy(text: str, chunk_index: int):
    img = Image.new("RGB", (200, 200), color="white")
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

# Process file: OCR -> Chunk -> Save to the Hugging Face Hub
def process_file_and_save(file, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name):
    try:
        # Step 1: Perform OCR
        combined_markdown, raw_markdown = perform_ocr_file(file)
        if combined_markdown.startswith("Error"):
            return combined_markdown

        # Step 2: Chunk the markdown
        chunks = chunk_markdown(combined_markdown, chunk_size, chunk_overlap, preserve_numbering)

        # Step 3: Prepare the dataset columns
        data = {
            "chunk_id": [],
            "content": [],
            "metadata": [],
            "page_image": [],
        }
        for i, chunk in enumerate(chunks):
            data["chunk_id"].append(i)
            data["content"].append(chunk.page_content)
            data["metadata"].append(chunk.metadata)

            # Extract the first inlined base64 image, if the chunk has one;
            # matching on the data URI is more robust than matching image names
            img_base64 = None
            start = chunk.page_content.find("data:image")
            if start != -1:
                end = chunk.page_content.find(")", start)
                if end != -1:
                    img_base64 = chunk.page_content[start:end]
            if not img_base64:
                img_base64 = text_to_base64_dummy(chunk.page_content, i)
            data["page_image"].append(img_base64)

        # Step 4: Create and push the dataset to Hugging Face
        dataset = Dataset.from_dict(data)
        api = HfApi()
        api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
        dataset.push_to_hub(repo_name, token=hf_token)

        return f"Dataset created with {len(chunks)} chunks and saved to Hugging Face at {repo_name}"
    except Exception as e:
        return f"Error: {str(e)}"
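# Standalone usage sketch for chunk_markdown (illustrative only, not executed
# by the app):
#
#   sample = "1. First item with some detail.\n\n2. Second item with more detail."
#   for c in chunk_markdown(sample, chunk_size=60, chunk_overlap=10):
#       print(c.metadata.get("start_index"), repr(c.page_content[:40]))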
# Gradio Interface
with gr.Blocks(title="PDF/Image OCR, Chunking, and Dataset Creator") as demo:
    gr.Markdown("# PDF/Image OCR, Chunking, and Dataset Creator")
    gr.Markdown(
        "Upload a PDF or image, extract text/images with Mistral OCR, "
        "chunk the markdown, and save to Hugging Face."
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload PDF or Image")
            chunk_size = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
            chunk_overlap = gr.Slider(0, 500, value=200, step=50, label="Chunk Overlap")
            preserve_numbering = gr.Checkbox(label="Preserve Numbering", value=True)
            hf_token = gr.Textbox(label="Hugging Face Token", type="password")
            repo_name = gr.Textbox(label="Hugging Face Repository Name (e.g., username/dataset-name)")
            submit_btn = gr.Button("Process and Save")
        with gr.Column():
            output = gr.Textbox(label="Result")

    submit_btn.click(
        fn=process_file_and_save,
        inputs=[file_input, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name],
        outputs=output,
    )

demo.launch(share=True)
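# To run locally (package names inferred from the imports above; exact
# versions may vary, and "app.py" is an assumed filename):
#
#   pip install gradio langchain mistralai datasets huggingface_hub pillow
#   export MISTRAL_API_KEY=<your key>
#   python app.py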