Spaces:
Running
Running
import gradio as gr | |
from langchain.document_loaders import PyMuPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema import Document | |
from typing import List | |
import logging | |
from pathlib import Path | |
import requests | |
import base64 | |
import io | |
import fitz | |
from PIL import Image | |
from datasets import Dataset | |
from huggingface_hub import HfApi | |
import os | |
# Module-wide logging setup: the root logger is configured at INFO level
# (basicConfig is a no-op if the root logger already has handlers), and this
# module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
def _merge_numbered_chunks(chunks: List["Document"]) -> List["Document"]:
    """Glue every chunk that does not start a numbered item onto its predecessor.

    A chunk "starts a numbered item" when its stripped text begins with one or
    more digits, a period, and then whitespace or end of text (e.g. "3. Foo",
    "12. Bar"). This is the same shape the numbering separator in ``chunk_pdf``
    splits on. The first chunk always opens a group. Continuation text is
    appended (stripped, after a newline) to the open group's ``page_content``,
    and the group's ``metadata["end_index"]`` is updated. Otherwise the group
    keeps the metadata of its first chunk.

    The input chunk objects are mutated in place; the group heads are returned
    in their original order.
    """
    import re

    # The previous check only recognised single-digit prefixes ("0." .. "9."),
    # so items numbered 10 and above were merged into the item before them.
    # Requiring whitespace after the period also stops decimals such as "1.5"
    # from being mistaken for list numbers.
    item_start = re.compile(r"\d+\.(?:\s|$)")

    merged: List["Document"] = []
    current = None
    for chunk in chunks:
        content = chunk.page_content.strip()
        if current is None or item_start.match(content):
            if current is not None:
                merged.append(current)
            current = chunk
            continue
        current.page_content += "\n" + content
        # NOTE(review): start_index comes from the splitter and appears to be an
        # offset within the chunk's own source page. After a merge that crosses
        # a page boundary, this end_index therefore does not describe the first
        # chunk's page -- confirm whether merges should stop at page breaks.
        current.metadata["end_index"] = chunk.metadata["start_index"] + len(content)
    if current is not None:
        merged.append(current)
    return merged


# Original chunk_pdf function (slightly modified for Gradio)
def chunk_pdf(
    file_path: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    encoding: str = "utf-8",
    preserve_numbering: bool = True,
) -> List["Document"]:
    """Load a PDF (local path or HTTP(S) URL) and split it into text chunks.

    Args:
        file_path: Local filesystem path, or an ``http://``/``https://`` URL.
            A URL is first downloaded to a uniquely named temporary file,
            which is deleted before returning.
        chunk_size: Target chunk length in characters (``length_function=len``)
            given to the splitter. Must be positive.
        chunk_overlap: Characters of overlap between consecutive chunks.
            Must be non-negative and smaller than ``chunk_size``.
        encoding: Currently unused; kept for interface compatibility.
        preserve_numbering: If true, prefer splitting just before numbered
            items (a newline followed by e.g. "12. "). Each un-numbered chunk
            is then merged into the preceding numbered one, so merged chunks
            may exceed ``chunk_size``.

    Returns:
        The chunk documents (an empty list if the loader returned no pages).
        Every chunk carries the splitter's ``start_index`` metadata; chunks
        that absorbed continuation text also carry ``end_index``.

    Raises:
        ValueError: If ``chunk_size`` / ``chunk_overlap`` are invalid.
        FileNotFoundError: If a local ``file_path`` does not exist.
        requests.RequestException: If downloading a URL fails (including
            HTTP error statuses via ``raise_for_status``).
    """
    import tempfile

    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if chunk_overlap < 0:
        raise ValueError("chunk_overlap cannot be negative")
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be less than chunk_size")

    temp_file = None
    try:
        local_path = file_path
        if file_path.startswith(("http://", "https://")):
            logger.info(f"Downloading PDF from {file_path}")
            with requests.get(file_path, stream=True, timeout=10) as response:
                response.raise_for_status()
                # One unique temp file per call: a fixed name like "temp.pdf"
                # would be shared (and clobbered) by concurrent web requests.
                with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
                    # Record the path before writing so `finally` also removes
                    # a partially downloaded file.
                    temp_file = Path(tmp.name)
                    for block in response.iter_content(chunk_size=8192):
                        tmp.write(block)
            local_path = str(temp_file)
        elif not Path(file_path).exists():
            raise FileNotFoundError(f"PDF file not found at: {file_path}")

        logger.info(f"Loading PDF from {local_path}")
        pages = PyMuPDFLoader(local_path).load()
        if not pages:
            logger.warning(f"No content extracted from {local_path}")
            return []

        if preserve_numbering:
            # With is_separator_regex=True *every* separator is a regex, so the
            # sentence separator must be the escaped r"\." -- a bare "." matches
            # any character and would split between arbitrary characters.
            separators = [r"\n\d+\.\s+", "\n\n", "\n", r"\.", " ", ""]
        else:
            separators = ["\n\n", "\n", ".", " ", ""]

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=separators,
            keep_separator=True,
            add_start_index=True,
            is_separator_regex=preserve_numbering,
        )
        logger.info(f"Splitting {len(pages)} pages into chunks")
        chunks = text_splitter.split_documents(pages)
        if preserve_numbering:
            chunks = _merge_numbered_chunks(chunks)
        logger.info(f"Created {len(chunks)} chunks")
        return chunks
    except Exception as e:
        logger.error(f"Error processing PDF {file_path}: {str(e)}")
        raise
    finally:
        if temp_file is not None and temp_file.exists():
            temp_file.unlink()
# Custom function to convert PDF page to base64
def pdf_page_to_base64(pdf_path: str, page_number: int):
    """Render one page of a PDF as a PNG and return it base64-encoded.

    Args:
        pdf_path: Path to a local PDF file.
        page_number: One-based page number (1 is the first page).

    Returns:
        The PNG bytes of the rendered page, base64-encoded, as a ``str``.

    Raises:
        ValueError: If ``page_number`` is less than 1. Without this check,
            0 or a negative value would reach ``load_page`` as a negative
            index, which PyMuPDF treats as counting from the last page, so
            the wrong page would be rendered silently.
    """
    if page_number < 1:
        raise ValueError(f"page_number must be >= 1 (one-based), got {page_number}")
    # The context manager closes the document (and its file handle) even if
    # rendering fails; previously the opened document was never closed.
    with fitz.open(pdf_path) as pdf_document:
        page = pdf_document.load_page(page_number - 1)  # fitz pages are zero-based
        pix = page.get_pixmap()
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Function to process PDF and create dataset
def process_pdf_and_save(pdf_file, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name):
    """Chunk an uploaded PDF, attach page images, and push a dataset to the Hub.

    This is the Gradio click handler. Every failure is reported as a returned
    ``"Error: ..."`` string rather than raised, so it appears in the Result box.

    Args:
        pdf_file: Value of the upload component. Accepts either a path string
            or an object exposing the path as ``.name`` (the shape differs
            between Gradio versions), or ``None`` when nothing was uploaded.
        chunk_size: Slider value; converted to ``int`` for ``chunk_pdf``.
        chunk_overlap: Slider value; converted to ``int`` for ``chunk_pdf``.
        preserve_numbering: Forwarded to ``chunk_pdf``.
        hf_token: Hugging Face token used to create the repo and push.
        repo_name: Target dataset repo id, e.g. "username/dataset-name".

    Returns:
        A human-readable success message, or an ``"Error: ..."`` message.
    """
    # Validate form input up front so the user gets a clear message instead
    # of an AttributeError or an authentication failure from the Hub.
    if pdf_file is None:
        return "Error: Please upload a PDF file."
    if not hf_token:
        return "Error: Please provide a Hugging Face token."
    if not repo_name or not repo_name.strip():
        return "Error: Please provide a repository name (e.g., username/dataset-name)."
    repo_name = repo_name.strip()

    try:
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        chunks = chunk_pdf(pdf_path, int(chunk_size), int(chunk_overlap), "utf-8", preserve_numbering)

        data = {
            "chunk_id": [],
            "content": [],
            "metadata": [],
            "page_image": [],
        }
        # Zero-based page index -> base64 PNG, so each page is rendered only
        # once even when many chunks come from it.
        page_images = {}
        for i, chunk in enumerate(chunks):
            # PyMuPDFLoader's "page" metadata is zero-based, while
            # pdf_page_to_base64 takes a one-based page number.
            page_index = chunk.metadata.get("page", 0)
            if page_index not in page_images:
                page_images[page_index] = pdf_page_to_base64(pdf_path, page_index + 1)
            data["chunk_id"].append(i)
            data["content"].append(chunk.page_content)
            data["metadata"].append(chunk.metadata)
            data["page_image"].append(page_images[page_index])

        dataset = Dataset.from_dict(data)

        api = HfApi()
        api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
        dataset.push_to_hub(repo_name, token=hf_token)
        return f"Dataset created with {len(chunks)} chunks and saved to Hugging Face at {repo_name}"
    except Exception as e:
        # UI boundary: keep the full traceback in the logs for the operator and
        # show only the short message to the user (never log the token).
        logger.exception("Failed to build dataset from uploaded PDF %s", pdf_file)
        return f"Error: {str(e)}"
# Gradio Interface
# Two-column page: the left column collects the PDF and every setting, and the
# right column shows the status string returned by process_pdf_and_save.
APP_TITLE = "PDF Chunking and Dataset Creator"

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown("# " + APP_TITLE)
    gr.Markdown("Upload a PDF, configure chunking parameters, and save the dataset to Hugging Face.")

    with gr.Row():
        with gr.Column():
            pdf_input = gr.File(label="Upload PDF")
            chunk_size = gr.Slider(minimum=500, maximum=2000, value=1000, step=100, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, value=200, step=50, label="Chunk Overlap")
            preserve_numbering = gr.Checkbox(value=True, label="Preserve Numbering")
            hf_token = gr.Textbox(type="password", label="Hugging Face Token")
            repo_name = gr.Textbox(label="Hugging Face Repository Name (e.g., username/dataset-name)")
            submit_btn = gr.Button("Process and Save")
        with gr.Column():
            output = gr.Textbox(label="Result")

    # Component order must match process_pdf_and_save's positional parameters.
    handler_inputs = [pdf_input, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name]
    submit_btn.click(fn=process_pdf_and_save, inputs=handler_inputs, outputs=output)

demo.launch(share=True)