import gradio as gr
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain.schema import Document
from typing import List, Dict, Any, Tuple
import logging
import re
import base64
import mimetypes # Added
from datasets import Dataset
from huggingface_hub import HfApi
import huggingface_hub # Added for token checking and errors
import os
from mistralai import Mistral # Assuming this is the correct import for the client
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Mistral OCR Setup ---
api_key = os.environ.get("MISTRAL_API_KEY")
if not api_key:
logger.warning("MISTRAL_API_KEY environment variable not set. Attempting to use Hugging Face token.")
try:
api_key = huggingface_hub.get_token()
if not api_key:
# If running locally, this might still fail if not logged in.
logger.warning("Could not retrieve token from Hugging Face login.")
# Error will be raised later if client init fails or during HF push if token still missing
else:
logger.info("Using Hugging Face token as MISTRAL_API_KEY.")
except Exception as e:
logger.warning(f"Could not check Hugging Face login for token: {e}")
# Proceed without API key, client initialization might fail
# Initialize Mistral Client
client = None
if api_key:
try:
client = Mistral(api_key=api_key)
logger.info("Mistral client initialized successfully.")
except Exception as e:
logger.error(f"Failed to initialize Mistral client: {e}", exc_info=True)
# Raise a clearer error for Gradio startup if client fails
raise RuntimeError(f"Failed to initialize Mistral client. Check API key and mistralai installation. Error: {e}")
else:
# This path might be hit if no env var and no HF token found
logger.error("Mistral API key is not available. OCR functionality will fail.")
# We could raise an error here, or let it fail when client methods are called.
# Let's allow Gradio to load but OCR will fail clearly later.
# --- Helper Functions ---
def encode_image_bytes(image_bytes: bytes) -> str:
"""Encodes image bytes to a base64 string."""
return base64.b64encode(image_bytes).decode('utf-8')
def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
"""
Combines markdown from OCR pages, replacing image IDs with base64 data URIs.
Args:
ocr_response: The response object from the Mistral OCR API.
Returns:
A tuple containing:
- combined_markdown_with_images: Markdown string with image references replaced by base64 data URIs.
- combined_raw_markdown: Original markdown string without image replacement.
- image_data_map: A dictionary mapping image IDs to their base64 data URIs.
Raises ValueError on unexpected response structure.
"""
processed_markdowns = []
raw_markdowns = []
image_data_map = {} # Collect image_id -> base64_data_uri
if not hasattr(ocr_response, 'pages') or not ocr_response.pages:
logger.warning("OCR response has no 'pages' attribute or pages list is empty.")
return "", "", {}
try:
# Collect all image data first (assuming image_base64 includes data URI prefix from Mistral)
for page_idx, page in enumerate(ocr_response.pages):
if hasattr(page, 'images') and page.images:
for img in page.images:
if hasattr(img, 'id') and hasattr(img, 'image_base64') and img.image_base64:
image_data_map[img.id] = img.image_base64 # Assuming this is the full data URI
else:
logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'. Image: {img}")
# else: # Don't warn if a page simply has no images
# logger.debug(f"Page {page_idx} has no 'images' attribute or no images found.")
# Process markdown for each page
for page_idx, page in enumerate(ocr_response.pages):
if not hasattr(page, 'markdown'):
logger.warning(f"Page {page_idx} in OCR response lacks 'markdown' attribute. Skipping.")
continue # Skip page if no markdown
current_raw_markdown = page.markdown if page.markdown else ""
raw_markdowns.append(current_raw_markdown)
current_processed_markdown = current_raw_markdown
# Find all image references like ![alt_text](image_id)
# Regex to find the image ID (content within parentheses)
img_refs = re.findall(r"!\[.*?\]\((.*?)\)", current_processed_markdown)
for img_id in img_refs:
if img_id in image_data_map:
base64_data_uri = image_data_map[img_id]
# Escape potential regex special characters in img_id before using in replace
escaped_img_id = re.escape(img_id)
# Replace ![alt_text](image_id) with ![alt_text](base64_data_uri)
# Use a specific regex for replacement: find the exact ![...](image_id) pattern
pattern = r"(!\[.*?\]\()" + escaped_img_id + r"(\))"
# Check if replacement target exists before replacing
if re.search(pattern, current_processed_markdown):
current_processed_markdown = re.sub(
pattern,
r"\1" + base64_data_uri + r"\2",
current_processed_markdown
)
else:
# This case shouldn't happen often if img_id came from findall on the same string
logger.warning(f"Page {page_idx}: Found img_id '{img_id}' but couldn't find exact pattern '{pattern}' for replacement.")
else:
# Only log warning if the ID looks like an expected image ID pattern (e.g., 'image_X')
# Avoid warning for regular URLs that might be in the markdown
if not img_id.startswith(('http:', 'https:', 'data:')): # Check if it's not already a URL
logger.warning(f"Page {page_idx}: Image ID '{img_id}' found in markdown but not in collected image data.")
processed_markdowns.append(current_processed_markdown)
return "\n\n".join(processed_markdowns), "\n\n".join(raw_markdowns), image_data_map
except AttributeError as ae:
logger.error(f"Attribute error accessing OCR response structure: {ae}", exc_info=True)
raise ValueError(f"Unexpected OCR response structure. Check Mistral API changes. Error: {ae}")
except Exception as e:
logger.error(f"Error processing OCR response markdown: {e}", exc_info=True)
raise
def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
"""
Performs OCR on an uploaded file (PDF or image) using the Mistral API.
Args:
file_obj: The file object from Gradio's gr.File component.
Returns:
A tuple containing:
- processed_markdown: Markdown string with base64 images, or error message.
- raw_markdown: Original markdown string.
- image_data_map: Dictionary mapping image IDs to base64 data URIs.
"""
if not client:
return "Error: Mistral client not initialized. Check API key setup.", "", {}
if not file_obj:
# This check might be redundant if called from process_file_and_save, but good practice
return "Error: No file provided to OCR function.", "", {}
try:
# Gradio may pass either a plain path string (type="filepath") or a file-like object with '.name'
file_path = file_obj if isinstance(file_obj, str) else file_obj.name
# Use the original filename if available (older Gradio file objects expose 'orig_name'), else the path's basename
file_name = getattr(file_obj, 'orig_name', os.path.basename(file_path))
logger.info(f"Performing OCR on file: {file_name} (temp path: {file_path})")
# Determine file type from extension
file_ext = os.path.splitext(file_name)[1].lower()
ocr_response = None
uploaded_file_id = None
if file_ext == '.pdf':
try:
with open(file_path, "rb") as f:
logger.info(f"Uploading PDF {file_name} to Mistral...")
# Pass as tuple (filename, file-like object)
uploaded_pdf = client.files.upload(
file=(file_name, f),
purpose="ocr"
)
uploaded_file_id = uploaded_pdf.id
logger.info(f"PDF uploaded successfully. File ID: {uploaded_file_id}")
logger.info(f"Getting signed URL for file ID: {uploaded_file_id}")
signed_url_response = client.files.get_signed_url(file_id=uploaded_file_id)
logger.info(f"Got signed URL: {signed_url_response.url[:50]}...")
logger.info("Sending PDF URL to Mistral OCR (model: mistral-ocr-latest)...")
ocr_response = client.ocr.process(
model="mistral-ocr-latest",
document={
"type": "document_url",
"document_url": signed_url_response.url,
},
include_image_base64=True
)
logger.info("OCR processing complete for PDF.")
finally:
# Ensure cleanup even if OCR fails after upload
if uploaded_file_id:
try:
logger.info(f"Deleting temporary Mistral file: {uploaded_file_id}")
client.files.delete(file_id=uploaded_file_id)
except Exception as delete_err:
logger.warning(f"Failed to delete temporary Mistral file {uploaded_file_id}: {delete_err}")
elif file_ext in ['.png', '.jpg', '.jpeg', '.webp', '.bmp']:
try:
with open(file_path, "rb") as f:
image_bytes = f.read()
if not image_bytes:
return f"Error: Uploaded image file '{file_name}' is empty.", "", {}
base64_encoded_image = encode_image_bytes(image_bytes)
# Determine MIME type
mime_type, _ = mimetypes.guess_type(file_path)
if not mime_type or not mime_type.startswith('image'):
logger.warning(f"Could not determine MIME type for {file_name} using extension. Defaulting to image/jpeg.")
mime_type = 'image/jpeg' # Fallback
data_uri = f"data:{mime_type};base64,{base64_encoded_image}"
logger.info(f"Sending image {file_name} ({mime_type}) as data URI to Mistral OCR (model: mistral-ocr-latest)...")
ocr_response = client.ocr.process(
model="mistral-ocr-latest",
document={
"type": "image_url",
"image_url": data_uri
},
include_image_base64=True
)
logger.info(f"OCR processing complete for image {file_name}.")
except Exception as img_ocr_err:
logger.error(f"Error during image OCR for {file_name}: {img_ocr_err}", exc_info=True)
return f"Error during OCR for image '{file_name}': {img_ocr_err}", "", {}
else:
unsupported_msg = f"Unsupported file type: '{file_name}'. Please provide a PDF or an image (png, jpg, jpeg, webp, bmp)."
logger.warning(unsupported_msg)
return unsupported_msg, "", {}
# Process the OCR response (common path for PDF/Image)
if ocr_response:
logger.info("Processing OCR response to combine markdown and images...")
processed_md, raw_md, img_map = get_combined_markdown(ocr_response)
logger.info("Markdown and image data extraction complete.")
return processed_md, raw_md, img_map
else:
# This case might occur if OCR processing itself failed silently or returned None
logger.error(f"OCR processing for '{file_name}' did not return a valid response.")
return f"Error: OCR processing failed for '{file_name}'. No response received.", "", {}
except FileNotFoundError:
logger.error(f"Temporary file not found: {file_path}", exc_info=True)
return f"Error: Could not read the uploaded file '{file_name}'. Ensure it uploaded correctly.", "", {}
except Exception as e:
logger.error(f"Unexpected error during OCR processing file {file_name}: {e}", exc_info=True)
# Provide more context in the error message returned to the user
return f"Error during OCR processing for '{file_name}': {str(e)}", "", {}
def chunk_markdown(
markdown_text_with_images: str,
chunk_size: int = 1000,
chunk_overlap: int = 200,
strip_headers: bool = True
) -> List[Document]:
"""
Chunks markdown text, preserving headers in metadata and adding embedded image info.
Args:
markdown_text_with_images: The markdown string containing base64 data URIs for images.
chunk_size: The target size for chunks (characters). 0 to disable recursive splitting.
chunk_overlap: The overlap between consecutive chunks (characters).
strip_headers: Whether to remove header syntax (e.g., '# ') from the chunk content.
Returns:
A list of Langchain Document objects representing the chunks. Returns empty list if input is empty.
"""
if not markdown_text_with_images or not markdown_text_with_images.strip():
logger.warning("chunk_markdown received empty or whitespace-only input string.")
return []
try:
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"), # Added more levels
("######", "Header 6"),
]
# Initialize MarkdownHeaderTextSplitter
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
strip_headers=strip_headers,
return_each_line=False # Process blocks
)
logger.info("Splitting markdown by headers...")
header_chunks = markdown_splitter.split_text(markdown_text_with_images)
logger.info(f"Split into {len(header_chunks)} chunks based on headers.")
if not header_chunks:
logger.warning("MarkdownHeaderTextSplitter returned zero chunks.")
# The input may simply contain no headers. For now we return an empty list, since
# header-based splitting is the intended first pass. An alternative would be to wrap
# the whole text in a single Document(page_content=markdown_text_with_images) instead.
return []
final_chunks = []
# If chunk_size is specified and > 0, further split large chunks
if chunk_size > 0:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
# More robust separators
separators=["\n\n", "\n", r"(?<=\. )", r"(?<=\? )", r"(?<=! )", ", ", "; ", " ", ""],
keep_separator=False,
add_start_index=True # Add start index relative to the parent (header) chunk
)
logger.info(f"Applying recursive character splitting (size={chunk_size}, overlap={chunk_overlap})...")
processed_chunks_count = 0
for i, header_chunk in enumerate(header_chunks):
# Check if page_content exists and is longer than chunk_size
if header_chunk.page_content and len(header_chunk.page_content) > chunk_size:
logger.debug(f"Header chunk {i} (length {len(header_chunk.page_content)}) needs recursive splitting.")
try:
# split_documents preserves metadata from the parent chunk
sub_chunks = text_splitter.split_documents([header_chunk])
final_chunks.extend(sub_chunks)
processed_chunks_count += len(sub_chunks)
logger.debug(f" -> Split into {len(sub_chunks)} sub-chunks.")
except Exception as split_err:
logger.error(f"Error splitting header chunk {i}: {split_err}", exc_info=True)
# Option: Add the original large chunk instead? Or skip? Let's skip broken ones.
logger.warning(f"Skipping header chunk {i} due to splitting error.")
continue
else:
# If the chunk is already small enough or empty, just add it
if header_chunk.page_content: # Add only if it has content
final_chunks.append(header_chunk)
processed_chunks_count += 1
logger.debug(f"Header chunk {i} (length {len(header_chunk.page_content)}) kept as is.")
else:
logger.debug(f"Header chunk {i} was empty, skipping.")
logger.info(f"Recursive character splitting finished. Processed {processed_chunks_count} chunks.")
else:
# If chunk_size is 0, use only non-empty header chunks
logger.info("chunk_size is 0, using only non-empty header-based chunks.")
final_chunks = [chunk for chunk in header_chunks if chunk.page_content]
# Post-process final chunks: Extract embedded image data URIs and add to metadata
logger.info("Extracting embedded image data URIs for final chunk metadata...")
for chunk in final_chunks:
images_in_chunk = []
if chunk.page_content:
try:
# Regex to find all base64 data URIs in the chunk content
# Non-greedy alt text `.*?`, robust base64 chars `[A-Za-z0-9+/=]+`
# Ensure the closing parenthesis `\)` is matched correctly
pattern = r"!\[.*?\]\((data:image/[a-zA-Z+]+;base64,[A-Za-z0-9+/=]+)\)"
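# e.g. matches "![fig](data:image/png;base64,iVBORw0KGgo=)" and captures only the data URI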
images_in_chunk = re.findall(pattern, chunk.page_content)
except Exception as regex_err:
logger.error(f"Regex error extracting images from chunk: {regex_err}", exc_info=True)
# Leave images list empty for this chunk
# Ensure metadata exists and add images list (can be empty)
if not hasattr(chunk, 'metadata'):
chunk.metadata = {}
chunk.metadata["images_base64"] = images_in_chunk # Use a more specific key name
logger.info(f"Created {len(final_chunks)} final chunks after processing and filtering.")
return final_chunks
except Exception as e:
logger.error(f"Error during markdown chunking process: {str(e)}", exc_info=True)
raise # Re-raise to be caught by the main processing function
# --- Main Processing Function ---
def process_file_and_save(
file_obj: Any, # Gradio File object
chunk_size: int,
chunk_overlap: int,
strip_headers: bool,
hf_token: str,
repo_name: str
) -> str:
"""
Orchestrates the OCR, chunking, and saving process to Hugging Face Hub.
Args:
file_obj: The uploaded file object from Gradio.
chunk_size: Max chunk size for text splitting (chars). 0 disables recursive splitting.
chunk_overlap: Overlap for text splitting (chars).
strip_headers: Whether to remove markdown headers from chunk content.
hf_token: Hugging Face API token (write permission).
repo_name: Name for the Hugging Face dataset repository (e.g., 'username/my-ocr-dataset').
Returns:
A string indicating success or failure, suitable for display in Gradio.
"""
# --- Input Validation ---
if not file_obj:
return "Error: No file uploaded. Please upload a PDF or image file."
if not repo_name or '/' not in repo_name:
return "Error: Invalid Hugging Face Repository Name. Use format 'username/dataset-name'."
# Validate chunking parameters
if chunk_size < 0:
logger.warning("Chunk size cannot be negative. Setting to 0 (header splits only).")
chunk_size = 0
if chunk_overlap < 0:
logger.warning("Chunk overlap cannot be negative. Setting to 0.")
chunk_overlap = 0
if chunk_size > 0 and chunk_overlap >= chunk_size:
logger.warning(f"Chunk overlap ({chunk_overlap}) >= chunk size ({chunk_size}). Adjusting overlap to {min(200, chunk_size // 2)}.")
chunk_overlap = min(200, chunk_size // 2) # Set a reasonable overlap
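# e.g. chunk_size=300 with chunk_overlap=400 becomes chunk_overlap = min(200, 300 // 2) = 150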
# Handle Hugging Face Token
if not hf_token:
logger.info("No explicit HF token provided. Trying to use token from local Hugging Face login.")
try:
hf_token = huggingface_hub.get_token()
if not hf_token:
return "Error: Hugging Face Token is required. Please provide a token or log in using `huggingface-cli login`."
logger.info("Using HF token from local login for dataset operations.")
except Exception as e:
logger.error(f"Error checking HF login for token: {e}", exc_info=True)
return f"Error: Hugging Face Token is required. Could not verify HF login: {e}"
try:
# file_obj may be a plain path string (type="filepath") or a file-like object exposing 'name'/'orig_name'
source_filename = os.path.basename(file_obj) if isinstance(file_obj, str) else getattr(file_obj, 'orig_name', os.path.basename(file_obj.name))
logger.info(f"--- Starting processing for file: {source_filename} ---")
# --- Step 1: Perform OCR ---
logger.info("Step 1: Performing OCR...")
processed_markdown, _, _ = perform_ocr_file(file_obj) # raw_markdown and image_map not directly used later
# Check if OCR returned an error message or was empty/invalid
if not processed_markdown or isinstance(processed_markdown, str) and (
processed_markdown.startswith("Error:") or processed_markdown.startswith("Unsupported file type:")):
logger.error(f"OCR failed or returned error/unsupported: {processed_markdown}")
return processed_markdown # Return the error message directly
if not isinstance(processed_markdown, str) or len(processed_markdown.strip()) == 0:
logger.error("OCR processing returned empty or invalid markdown content.")
return "Error: OCR returned empty or invalid content."
logger.info("Step 1: OCR finished successfully.")
# --- Step 2: Chunk the markdown ---
logger.info("Step 2: Chunking the markdown...")
chunks = chunk_markdown(processed_markdown, chunk_size, chunk_overlap, strip_headers)
if not chunks:
logger.error("Chunking resulted in zero chunks. Check OCR output and chunking parameters.")
return "Error: Failed to chunk the document (possibly empty after OCR or no headers found)."
logger.info(f"Step 2: Chunking finished, produced {len(chunks)} chunks.")
# --- Step 3: Prepare dataset ---
logger.info("Step 3: Preparing data for Hugging Face dataset...")
data: Dict[str, List[Any]] = {
"chunk_id": [],
"text": [], # Renamed 'content' to 'text'
"metadata": [],
"source_filename": [],
}
for i, chunk in enumerate(chunks):
chunk_id = f"{source_filename}_chunk_{i}"
data["chunk_id"].append(chunk_id)
data["text"].append(chunk.page_content if chunk.page_content else "") # Ensure text is string
# Ensure metadata is serializable (dicts, lists, primitives) for HF Datasets
serializable_metadata = {}
if hasattr(chunk, 'metadata') and chunk.metadata:
for k, v in chunk.metadata.items():
if isinstance(v, (str, int, float, bool, list, dict, type(None))):
serializable_metadata[k] = v
else:
# Convert potentially problematic types (like Langchain objects) to string
logger.warning(f"Chunk {chunk_id}: Metadata key '{k}' has non-standard type {type(v)}. Converting to string.")
try:
serializable_metadata[k] = str(v)
except Exception as str_err:
logger.error(f"Chunk {chunk_id}: Failed to convert metadata key '{k}' to string: {str_err}")
serializable_metadata[k] = f"ERROR_CONVERTING_{type(v).__name__}"
data["metadata"].append(serializable_metadata)
data["source_filename"].append(source_filename)
# --- Step 4: Create and push dataset to Hugging Face ---
logger.info(f"Step 4: Creating Hugging Face Dataset object for repo '{repo_name}'...")
try:
# Explicitly define features for robustness, especially if metadata varies
# features = datasets.Features({
# "chunk_id": datasets.Value("string"),
# "text": datasets.Value("string"),
# "metadata": datasets.features.Features({}), # Define known metadata fields if possible, or leave open
# "source_filename": datasets.Value("string"),
# })
# dataset = Dataset.from_dict(data, features=features)
dataset = Dataset.from_dict(data) # Simpler approach, infers features
logger.info(f"Dataset object created with {len(chunks)} rows.")
except Exception as ds_err:
logger.error(f"Failed to create Dataset object: {ds_err}", exc_info=True)
return f"Error: Failed to create dataset structure. Check logs. ({ds_err})"
logger.info(f"Connecting to Hugging Face Hub API to push to '{repo_name}'...")
try:
api = HfApi(token=hf_token) # Pass token explicitly
# Create repo if it doesn't exist
try:
api.repo_info(repo_id=repo_name, repo_type="dataset")
logger.info(f"Repository '{repo_name}' already exists. Will overwrite content.")
except huggingface_hub.utils.RepositoryNotFoundError:
logger.info(f"Repository '{repo_name}' does not exist. Creating...")
api.create_repo(repo_id=repo_name, repo_type="dataset", private=False) # Default to public
logger.info(f"Successfully created repository '{repo_name}'.")
# Push the dataset
logger.info(f"Pushing dataset to '{repo_name}'...")
commit_message = f"Add/update OCR data from file: {source_filename}"
# push_to_hub overwrites the dataset by default; pass the token explicitly rather than relying on a cached login
dataset.push_to_hub(repo_name, token=hf_token, commit_message=commit_message)
repo_url = f"https://huggingface.co/datasets/{repo_name}"
logger.info(f"Dataset successfully pushed to {repo_url}")
return f"Success! Dataset with {len(chunks)} chunks saved to Hugging Face: {repo_url}"
except huggingface_hub.utils.HfHubHTTPError as hf_http_err:
logger.error(f"Hugging Face Hub HTTP Error: {hf_http_err}", exc_info=True)
return f"Error: Hugging Face Hub Error pushing to '{repo_name}'. Status: {hf_http_err.response.status_code}. Check token permissions, repo name, and network. Details: {hf_http_err}"
except Exception as push_err:
logger.error(f"Failed to push dataset to '{repo_name}': {push_err}", exc_info=True)
return f"Error: Failed to push dataset to Hugging Face repository '{repo_name}'. ({push_err})"
except Exception as e:
# Catch any unexpected errors during the overall process
logger.error(f"An unexpected error occurred processing '{source_filename}': {str(e)}", exc_info=True)
return f"An unexpected error occurred: {str(e)}"
finally:
logger.info(f"--- Finished processing for file: {source_filename} ---")
# --- Gradio Interface ---
with gr.Blocks(title="Mistral OCR & Dataset Creator", theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
gr.Markdown("# Mistral OCR, Markdown Chunking, and Hugging Face Dataset Creator")
gr.Markdown(
"""
Upload a PDF or image file (PNG, JPG, WEBP, BMP). The application will:
1. Extract text and images using **Mistral OCR**.
2. Embed images as base64 data URIs directly within the extracted markdown text.
3. Chunk the resulting markdown based on **headers** and optionally **recursively by character count**.
4. Store any embedded base64 images found **within each chunk** in the chunk's metadata (`metadata['images_base64']`).
5. Create or update a **Hugging Face Dataset** with the processed chunks (`chunk_id`, `text`, `metadata`, `source_filename`).
"""
)
with gr.Row():
with gr.Column(scale=1):
file_input = gr.File(
label="Upload PDF or Image File",
file_types=['.pdf', '.png', '.jpg', '.jpeg', '.webp', '.bmp'],
type="filepath" # Ensures we get a path usable by `open()`
)
gr.Markdown("## Chunking Options")
chunk_size = gr.Slider(
minimum=0, maximum=8000, value=1000, step=100, # Increased max size
label="Max Chunk Size (Characters)",
info="Approximate target size. Set to 0 to disable recursive splitting (uses only header splits)."
)
chunk_overlap = gr.Slider(
minimum=0, maximum=1000, value=200, step=50,
label="Chunk Overlap (Characters)",
info="Number of characters to overlap between consecutive chunks (if recursive splitting is enabled)."
)
strip_headers = gr.Checkbox(
label="Strip Markdown Headers (#) from Chunk Content",
value=True,
info="If checked, removes '#', '##' etc. from the start of the text in each chunk."
)
gr.Markdown("## Hugging Face Output Options")
repo_name = gr.Textbox(
label="Target Hugging Face Dataset Repository",
placeholder="your-username/your-dataset-name",
info="The dataset will be pushed here (e.g., 'my-org/my-ocr-data'). Will be created if it doesn't exist."
)
hf_token = gr.Textbox(
label="Hugging Face Token (write permission)",
type="password",
placeholder="hf_...",
info="Required to create/push the dataset. If blank, will try using token from local `huggingface-cli login`.",
# value=os.environ.get("HF_TOKEN", "") # Optionally pre-fill from env var if desired
)
submit_btn = gr.Button("Process File and Save to Hugging Face", variant="primary")
with gr.Column(scale=1):
output = gr.Textbox(
label="Processing Log / Result Status",
lines=20,
interactive=False,
placeholder="Processing steps and final result will appear here..."
)
submit_btn.click(
fn=process_file_and_save,
inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
outputs=output
)
gr.Examples(
examples=[
[None, 1000, 200, True, "", "hf-username/my-first-ocr-dataset"],
[None, 2000, 400, True, "", "hf-username/large-chunk-ocr-data"],
[None, 0, 0, False, "", "hf-username/header-only-ocr-data"], # Example for header-only splitting
],
inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
outputs=output,
fn=process_file_and_save, # Make examples clickable
cache_examples=False # Avoid caching as it involves API calls and file processing
)
gr.Markdown("--- \n *Requires `MISTRAL_API_KEY` environment variable or being logged in via `huggingface-cli login`.*")
# --- Launch the Gradio App ---
if __name__ == "__main__":
# Check if client initialization failed earlier
if not client and api_key: # Check if key was present but init failed
print("\nCRITICAL: Mistral client failed to initialize. The application cannot perform OCR.")
print("Please check your MISTRAL_API_KEY and network connection.\n")
# Optionally exit, or let Gradio launch with limited functionality
# exit(1)
elif not client and not api_key:
print("\nWARNING: Mistral client not initialized because no API key was found.")
print("OCR functionality will fail. Please set MISTRAL_API_KEY or log in via `huggingface-cli login`.\n")
# share=True creates a public link (useful for Colab/Spaces)
# debug=True provides detailed errors in the console during development
demo.launch(share=os.getenv('GRADIO_SHARE', 'False').lower() == 'true', debug=True)
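# Example local run (assuming this file is saved as app.py; values are placeholders):
#   MISTRAL_API_KEY=... GRADIO_SHARE=true python app.py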