# Hugging Face Spaces page residue captured with this file: "Spaces: Running".
import base64
import logging
import mimetypes
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import huggingface_hub
from datasets import Dataset
from huggingface_hub import HfApi, get_token
from langchain.schema import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from mistralai import Mistral
# Configure logging
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Mistral OCR Setup ---
# Module-level state read by the functions below:
#   hf_token_global - cache filled in by get_hf_token() once a token is found.
#   client          - Mistral SDK client, or None when no API key is available.
hf_token_global = None
client = None

# Prefer the dedicated Mistral key; otherwise try the locally stored
# Hugging Face token as a stand-in.
api_key = os.getenv("MISTRAL_API_KEY")
if not api_key:
    logger.warning("MISTRAL_API_KEY not set. Attempting to use Hugging Face token.")
    api_key = get_token()
    if not api_key:
        logger.warning("No API key found.")
    else:
        logger.info("Using Hugging Face token as MISTRAL_API_KEY.")

if not api_key:
    # The app still starts; perform_ocr_file() reports the missing client.
    logger.error("Mistral API key not available. OCR will fail.")
else:
    try:
        client = Mistral(api_key=api_key)
        logger.info("Mistral client initialized successfully.")
    except Exception as init_err:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception(f"Failed to initialize Mistral client: {init_err}")
        raise RuntimeError(f"Failed to initialize Mistral client: {init_err}")
# --- Helper Functions ---
def encode_image_bytes(image_bytes: bytes) -> str:
    """Return the standard base64 encoding of ``image_bytes`` as a text string."""
    encoded = base64.b64encode(image_bytes)
    return str(encoded, 'utf-8')
# Markdown image reference: group 1 = "![alt](", group 2 = target, group 3 = ")".
_MD_IMAGE_REF = re.compile(r"(!\[.*?\]\()(.*?)(\))")


def _inline_page_images(markdown: str, image_data_map: Dict[str, str], page_idx: int) -> str:
    """Replace known image IDs in ``markdown`` with their stored image data."""

    def _substitute(match):
        target = match.group(2)
        if target in image_data_map:
            # Build the result by concatenation instead of a replacement
            # template: template strings interpret backslashes and would read
            # "\1" followed by a leading digit of the payload as a different
            # group number.
            return match.group(1) + image_data_map[target] + match.group(3)
        if not target.startswith(('http:', 'https:', 'data:')):
            logger.warning(f"Page {page_idx}: Image ID '{target}' not in image data.")
        return match.group(0)

    return _MD_IMAGE_REF.sub(_substitute, markdown)


def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
    """Combine per-page OCR markdown, inlining extracted images.

    Args:
        ocr_response: Object exposing ``pages``; each page may expose
            ``images`` (items with ``id`` and ``image_base64``) and
            ``markdown``.

    Returns:
        A tuple ``(processed_markdown, raw_markdown, image_data_map)``:
        the pages joined by blank lines with image IDs replaced by their
        ``image_base64`` value, the same pages joined without replacement,
        and the mapping from image ID to ``image_base64``. Returns
        ``("", "", {})`` when the response has no pages.

    Raises:
        Exception: any error raised while walking the response is logged
            and re-raised unchanged.
    """
    pages = getattr(ocr_response, 'pages', None)
    if not pages:
        logger.warning("OCR response has no 'pages' attribute or pages list is empty.")
        return "", "", {}

    processed_markdowns = []
    raw_markdowns = []
    # Accumulated across pages, so a page can reference images collected on
    # an earlier page.
    image_data_map = {}
    try:
        for page_idx, page in enumerate(pages):
            for img in getattr(page, 'images', None) or []:
                if hasattr(img, 'id') and hasattr(img, 'image_base64') and img.image_base64:
                    image_data_map[img.id] = img.image_base64
                else:
                    logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'.")

            if not hasattr(page, 'markdown'):
                logger.warning(f"Page {page_idx} lacks 'markdown' attribute. Skipping.")
                continue

            raw_markdown = page.markdown or ""
            raw_markdowns.append(raw_markdown)
            processed_markdowns.append(
                _inline_page_images(raw_markdown, image_data_map, page_idx)
            )
        return "\n\n".join(processed_markdowns), "\n\n".join(raw_markdowns), image_data_map
    except Exception as e:
        logger.error(f"Error processing OCR response markdown: {e}", exc_info=True)
        raise
def _ocr_pdf_via_upload(file_path: str, file_name: str) -> Any:
    """Upload a PDF to Mistral, OCR it through a signed URL, then delete the upload."""
    uploaded_file_id = None
    try:
        with open(file_path, "rb") as f:
            file_content = f.read()
        logger.info(f"Uploading PDF {file_name} to Mistral...")
        # Per the Mistral SDK documentation the upload payload is a mapping
        # with "file_name" and "content" (not a requests-style tuple).
        uploaded_pdf = client.files.upload(
            file={"file_name": file_name, "content": file_content},
            purpose="ocr"
        )
        uploaded_file_id = uploaded_pdf.id
        logger.info(f"PDF uploaded successfully. File ID: {uploaded_file_id}")
        signed_url_response = client.files.get_signed_url(file_id=uploaded_file_id)
        return client.ocr.process(
            model="mistral-ocr-latest",
            document={"type": "document_url", "document_url": signed_url_response.url},
            include_image_base64=True
        )
    finally:
        # Best-effort cleanup: a failed delete must not mask the OCR outcome.
        if uploaded_file_id:
            try:
                client.files.delete(file_id=uploaded_file_id)
            except Exception as delete_err:
                logger.warning(f"Failed to delete temporary file {uploaded_file_id}: {delete_err}")


def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
    """Run Mistral OCR on an uploaded PDF or image.

    Args:
        file_obj: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing ``name`` (the path) and optionally ``orig_name``
            (the display name). The file component in this module is declared
            with ``type="filepath"``, for which Gradio is documented to pass a
            plain path string.

    Returns:
        ``(processed_markdown, raw_markdown, image_data_map)`` as produced by
        ``get_combined_markdown``. On failure the first element is a message
        beginning with ``"Error:"`` and the other two are ``""`` and ``{}``.
        Every failure uses that prefix, since callers detect errors by it.
    """
    if not client:
        return "Error: Mistral client not initialized.", "", {}
    if not file_obj:
        return "Error: No file provided.", "", {}
    try:
        if isinstance(file_obj, (str, os.PathLike)):
            file_path = os.fspath(file_obj)
        else:
            file_path = file_obj.name
        file_name = getattr(file_obj, 'orig_name', None) or os.path.basename(file_path)
        logger.info(f"Performing OCR on file: {file_name}")
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.pdf':
            ocr_response = _ocr_pdf_via_upload(file_path, file_name)
        elif file_ext in ('.png', '.jpg', '.jpeg', '.webp', '.bmp'):
            with open(file_path, "rb") as f:
                image_bytes = f.read()
            if not image_bytes:
                return f"Error: Uploaded image file '{file_name}' is empty.", "", {}
            # Guess from the same name the extension check used, so the MIME
            # type and the accepted extension cannot disagree.
            mime_type, _ = mimetypes.guess_type(file_name)
            mime_type = mime_type or 'image/jpeg'
            data_uri = f"data:{mime_type};base64,{encode_image_bytes(image_bytes)}"
            ocr_response = client.ocr.process(
                model="mistral-ocr-latest",
                document={"type": "image_url", "image_url": data_uri},
                include_image_base64=True
            )
        else:
            return f"Error: Unsupported file type: '{file_name}'.", "", {}

        if ocr_response:
            return get_combined_markdown(ocr_response)
        return f"Error: OCR failed for '{file_name}'.", "", {}
    except Exception as e:
        logger.error(f"Error during OCR: {e}", exc_info=True)
        return f"Error: OCR processing failed: {e}", "", {}
def chunk_markdown(
    markdown_text_with_images: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    strip_headers: bool = True
) -> List[Document]:
    """Split markdown by headers, then optionally by character count.

    Args:
        markdown_text_with_images: Markdown text, possibly containing images
            embedded as ``data:image/...;base64,...`` URIs.
        chunk_size: Maximum characters per chunk. ``0`` (or less) disables the
            character-level pass and keeps one chunk per header section.
        chunk_overlap: Characters shared between consecutive sub-chunks of an
            oversized section.
        strip_headers: Passed to ``MarkdownHeaderTextSplitter``; when true the
            header lines are removed from chunk content.

    Returns:
        Non-empty ``Document`` chunks. Each chunk's metadata gains an
        ``"images_base64"`` list holding the complete data URIs found in that
        chunk's content. Returns ``[]`` for blank input.
    """
    if not markdown_text_with_images or not markdown_text_with_images.strip():
        logger.warning("chunk_markdown received empty input.")
        return []

    headers_to_split_on = [
        ("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3"),
        ("####", "Header 4"), ("#####", "Header 5"), ("######", "Header 6"),
    ]
    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=headers_to_split_on, strip_headers=strip_headers
    )
    header_chunks = [
        chunk
        for chunk in markdown_splitter.split_text(markdown_text_with_images)
        if chunk.page_content
    ]
    if not header_chunks:
        return []

    if chunk_size > 0:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len,
            # The sentence-boundary entries are regex lookbehinds, so the
            # splitter must be told to treat separators as patterns;
            # otherwise they are escaped and matched literally (i.e. never).
            separators=["\n\n", "\n", r"(?<=\. )", r"(?<=\? )", r"(?<=! )", ", ", "; ", " ", ""],
            is_separator_regex=True,
            add_start_index=True
        )
        final_chunks = []
        for header_chunk in header_chunks:
            if len(header_chunk.page_content) > chunk_size:
                final_chunks.extend(text_splitter.split_documents([header_chunk]))
            else:
                final_chunks.append(header_chunk)
    else:
        final_chunks = header_chunks

    # NOTE(review): an embedded data URI longer than chunk_size is cut by the
    # character-level pass, so such an image will not match below and ends up
    # spread over several chunks - confirm whether that is acceptable.
    embedded_image = re.compile(
        r"!\[.*?\]\((data:image/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=]+)\)"
    )
    for chunk in final_chunks:
        if not hasattr(chunk, 'metadata'):
            chunk.metadata = {}
        chunk.metadata["images_base64"] = embedded_image.findall(chunk.page_content)
    return final_chunks
def get_hf_token(explicit_token: Optional[str] = None) -> Optional[str]:
    """Resolve a Hugging Face token, trying several sources in order.

    Lookup order:
        1. ``explicit_token`` if, after trimming whitespace, it starts with
           ``hf_`` (never cached).
        2. The module-level cache ``hf_token_global``.
        3. The ``HF_TOKEN`` environment variable (must start with ``hf_``).
        4. The token stored by ``huggingface_hub`` (e.g. via CLI login).

    Tokens found in steps 3 and 4 are cached in ``hf_token_global``.

    Args:
        explicit_token: Token typed by the user; may be ``None`` or blank.

    Returns:
        The token string, or ``None`` when no source yields one.
    """
    global hf_token_global

    # Trim before validating so a pasted token with stray whitespace is
    # accepted rather than silently ignored.
    candidate = (explicit_token or "").strip()
    if candidate.startswith('hf_'):
        return candidate

    if hf_token_global:
        return hf_token_global

    env_token = (os.environ.get("HF_TOKEN") or "").strip()
    if env_token.startswith('hf_'):
        hf_token_global = env_token
        return env_token

    try:
        stored_token = huggingface_hub.get_token()
    except Exception as e:
        logger.warning(f"Could not retrieve token from Hugging Face config: {e}")
        return None
    if stored_token:
        hf_token_global = stored_token
        return stored_token
    return None
def process_file_and_save(
    file_obj: Any, chunk_size: int, chunk_overlap: int,
    strip_headers: bool, hf_token: str, repo_name: str
) -> str:
    """Run OCR on an upload, chunk the markdown and push it to a HF dataset.

    Args:
        file_obj: Uploaded file, either a path string / ``os.PathLike`` or an
            object exposing ``name`` and optionally ``orig_name``.
        chunk_size: Maximum characters per chunk; values below 0 become 0
            (header-only chunking).
        chunk_overlap: Overlap between sub-chunks; values below 0 become 0 and
            an overlap not smaller than ``chunk_size`` is reduced to
            ``min(200, chunk_size // 2)``.
        strip_headers: Whether header lines are removed from chunk content.
        hf_token: Token typed by the user; other sources are tried when it is
            blank or not ``hf_``-prefixed.
        repo_name: Target dataset repository as ``"owner/name"``.

    Returns:
        A human-readable status string: a success message with the dataset
        URL, or a message describing the failure. This function does not
        raise; unexpected exceptions are logged and reported in the string.
    """
    if not file_obj:
        return "Error: No file uploaded."
    repo_name = (repo_name or "").strip()
    if not repo_name or '/' not in repo_name:
        return "Error: Invalid repository name (use 'username/dataset-name')."

    # UI sliders may deliver floats; the splitter parameters must be ints.
    chunk_size = max(0, int(chunk_size))
    chunk_overlap = max(0, int(chunk_overlap))
    if chunk_size > 0 and chunk_overlap >= chunk_size:
        chunk_overlap = min(200, chunk_size // 2)

    effective_hf_token = get_hf_token(hf_token)
    if not effective_hf_token:
        return (
            "Error: No valid Hugging Face token found.\n"
            "Please either:\n"
            "1. Provide a token in the input field (starts with 'hf_')\n"
            "2. Set HF_TOKEN environment variable\n"
            "3. Run `huggingface-cli login` in your terminal"
        )

    # perform_ocr_file reports failures in-band as the first tuple element.
    # All prefixes it is known to use are listed so that a failure message is
    # never chunked and uploaded as if it were document text.
    ocr_failure_prefixes = ("Error:", "Error during OCR:", "Unsupported file type:")

    try:
        if isinstance(file_obj, (str, os.PathLike)):
            upload_path = os.fspath(file_obj)
        else:
            upload_path = file_obj.name
        source_filename = getattr(file_obj, 'orig_name', None) or os.path.basename(upload_path)
        logger.info(f"--- Starting processing for file: {source_filename} ---")

        processed_markdown, _, _ = perform_ocr_file(file_obj)
        if not processed_markdown:
            return f"Error: OCR returned no text for '{source_filename}'."
        if processed_markdown.startswith(ocr_failure_prefixes):
            return processed_markdown

        chunks = chunk_markdown(processed_markdown, chunk_size, chunk_overlap, strip_headers)
        if not chunks:
            return "Error: Failed to chunk the document."

        data = {
            "chunk_id": [f"{source_filename}_chunk_{i}" for i, _chunk in enumerate(chunks)],
            "text": [chunk.page_content or "" for chunk in chunks],
            "metadata": [chunk.metadata for chunk in chunks],
            "source_filename": [source_filename] * len(chunks),
        }
        dataset = Dataset.from_dict(data)

        api = HfApi(token=effective_hf_token)
        try:
            user_info = api.whoami()
            logger.info(f"Authenticated as: {user_info['name']}")
        except Exception as auth_err:
            return f"Error: Invalid HF token - authentication failed: {auth_err}"

        try:
            api.repo_info(repo_id=repo_name, repo_type="dataset")
            logger.info(f"Repository '{repo_name}' exists.")
        except huggingface_hub.utils.RepositoryNotFoundError:
            api.create_repo(repo_id=repo_name, repo_type="dataset", private=False)
            logger.info(f"Created repository '{repo_name}'.")

        dataset.push_to_hub(repo_name, token=effective_hf_token,
                            commit_message=f"Add OCR data from {source_filename}")
        repo_url = f"https://huggingface.co/datasets/{repo_name}"
        return f"Success! Dataset with {len(chunks)} chunks saved to: {repo_url}"
    except huggingface_hub.utils.HfHubHTTPError as hf_http_err:
        status = getattr(hf_http_err.response, 'status_code', 'Unknown')
        if status == 401:
            return "Error: Invalid or unauthorized Hugging Face token."
        elif status == 403:
            return "Error: Token lacks write permission."
        return f"Error: Hugging Face Hub Error (Status {status}): {hf_http_err}"
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        return f"Unexpected error: {str(e)}"
# --- Gradio Interface ---
# NOTE(review): the captured source carried no indentation; the nesting below
# (controls in the left column, status box in the right, event wiring at the
# Blocks level) is the conventional layout - confirm against the deployed app.
with gr.Blocks(
    title="Mistral OCR & Dataset Creator",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
) as demo:
    gr.Markdown("# Mistral OCR, Markdown Chunking, and Hugging Face Dataset Creator")
    gr.Markdown(
        "\n"
        "Upload a PDF or image file. The application will:\n"
        "1. Extract text and images using Mistral OCR\n"
        "2. Embed images as base64 data URIs in markdown\n"
        "3. Chunk markdown by headers and optionally character count\n"
        "4. Store embedded images in chunk metadata\n"
        "5. Create/update a Hugging Face Dataset\n"
    )

    with gr.Row():
        # Left column: upload, chunking parameters and output destination.
        with gr.Column(scale=1):
            file_input = gr.File(
                type="filepath",
                file_types=['.pdf', '.png', '.jpg', '.jpeg', '.webp', '.bmp'],
                label="Upload PDF or Image File",
            )

            gr.Markdown("## Chunking Options")
            chunk_size = gr.Slider(
                label="Max Chunk Size (Characters)",
                minimum=0, maximum=8000, step=100, value=1000,
            )
            chunk_overlap = gr.Slider(
                label="Chunk Overlap (Characters)",
                minimum=0, maximum=1000, step=50, value=200,
            )
            strip_headers = gr.Checkbox(value=True, label="Strip Headers from Content")

            gr.Markdown("## Hugging Face Output Options")
            repo_name = gr.Textbox(
                placeholder="your-username/your-dataset-name",
                label="HF Dataset Repository",
            )
            hf_token = gr.Textbox(
                placeholder="hf_...",
                type="password",
                label="Hugging Face Token",
            )
            submit_btn = gr.Button("Process and Save", variant="primary")

        # Right column: textual status of the last run.
        with gr.Column(scale=1):
            output = gr.Textbox(label="Result Status", lines=20, interactive=False)

    # The button and the example rows feed the same components, in the
    # positional order process_file_and_save expects.
    _pipeline_inputs = [file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name]

    submit_btn.click(
        fn=process_file_and_save,
        inputs=_pipeline_inputs,
        outputs=output,
    )

    # Example rows only pre-fill settings; no file is attached (None).
    gr.Examples(
        examples=[
            [None, 1000, 200, True, "", "hf-username/my-first-ocr-dataset"],
            [None, 2000, 400, True, "", "hf-username/large-chunk-ocr-data"],
            [None, 0, 0, False, "", "hf-username/header-only-ocr-data"],
        ],
        inputs=_pipeline_inputs,
        outputs=output,
        fn=process_file_and_save,
        cache_examples=False,
    )

    gr.Markdown("*Requires MISTRAL_API_KEY or HF token*")
if __name__ == "__main__":
    # Warn on the console when no credential source was found at startup;
    # the UI still launches so a token can be typed in.
    startup_token = get_hf_token()
    credentials_missing = not startup_token and not client
    if credentials_missing:
        print("\nWARNING: Neither Mistral API key nor HF token found.")
        print("Set MISTRAL_API_KEY and/or HF_TOKEN, or use `huggingface-cli login`")

    # Public share link only when GRADIO_SHARE is "true" (case-insensitive).
    share_requested = os.environ.get('GRADIO_SHARE', 'False').lower() == 'true'
    demo.launch(
        share=share_requested,
        debug=True,
        auth_message="Provide a valid Hugging Face token if prompted",
    )