#!/usr/bin/env python3
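"""
Single-document MinerU pipeline: download a PDF, run OCR / layout analysis with
magic_pdf, classify extracted images with the SVM model, upload the relevant
images to S3, and return the resulting Markdown.

Expects the S3_ACCESS_KEY, S3_SECRET_KEY, S3_BUCKET_NAME and S3_ENDPOINT
environment variables, plus a MinerU config file at /home/user/magic-pdf.json.
"""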
import os
import uuid
import json
import requests
import logging
import torch
import gc
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.data.data_reader_writer.base import DataWriter
from inference_svm_model import SVMModel
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)
class Processor:
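    """
    Downloads a PDF, runs MinerU's OCR / layout pipeline on it, and produces
    Markdown with irrelevant images removed and relevant ones stored in S3.
    """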
    def __init__(self):
        try:
            self.s3_writer = S3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket=os.getenv("S3_BUCKET_NAME"),
                endpoint_url=os.getenv("S3_ENDPOINT"),
            )
            self.svm_model = SVMModel()
            logger.info("Classification model initialized successfully")
            with open("/home/user/magic-pdf.json", "r") as f:
                config = json.load(f)
            self.layout_mode = "doclayout_yolo"
            # self.layout_mode = config["layout-config"]["model"]
            self.formula_enable = config["formula-config"]["enable"]
            self.table_enable = config["table-config"]["enable"]
            self.language = "en"
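            # Public URL prefix under which uploaded page images are referenced
            # in the generated Markdown.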
            endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
            bucket = os.getenv("S3_BUCKET_NAME", "")
            self.prefix = f"{endpoint}/{bucket}/document-extracts/"
            logger.info("Processor initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Processor: %s", str(e))
            raise
    def cleanup_gpu(self):
        """
        Release GPU memory by running garbage collection and clearing
        PyTorch's CUDA cache. This helps prevent VRAM accumulation
        across documents.
        """
        try:
            gc.collect()  # release Python-side references first
            torch.cuda.empty_cache()  # then free cached memory on the GPU
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            logger.error("Error during GPU cleanup: %s", e)
    def process(self, file_url: str, key: str) -> str:
        """
        Process a single PDF, returning the final Markdown with irrelevant images removed.
        """
        logger.info("Processing file: %s", file_url)
        try:
            response = requests.get(file_url)
            if response.status_code != 200:
                logger.error("Failed to download PDF from %s. Status code: %d", file_url, response.status_code)
                raise Exception(f"Failed to download PDF: {file_url}")
            pdf_bytes = response.content
            logger.info("Downloaded %d bytes for file_url='%s'", len(pdf_bytes), file_url)

            # Analyze the PDF with OCR and layout detection
            dataset = PymuDocDataset(pdf_bytes)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_mode,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable
            )
            logger.info("doc_analyze complete for key='%s'. Started extracting images...", key)

            # Classify extracted images and drop the irrelevant ones
            image_writer = ImageWriter(self.s3_writer, self.svm_model)
            pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
            logger.info("OCR pipeline completed for key='%s'.", key)

            md_content = pipe_result.get_markdown(self.prefix + key + "/")
            final_markdown = image_writer.remove_redundant_images(md_content)
            logger.info("Completed PDF process for key='%s'. Final MD length=%d", key, len(final_markdown))
            return final_markdown
        finally:
            # Release GPU memory after every document, even if processing failed.
            self.cleanup_gpu()
class ImageWriter(DataWriter):
    """
    Receives each extracted image, classifies it, uploads it to S3 if it is
    relevant, and otherwise flags it for removal from the Markdown.
    """
    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self._redundant_images_paths = []
    def write(self, path: str, data: bytes) -> None:
        """
        Called for each extracted image. If relevant, upload it to S3; otherwise mark it for removal.
        """
        label = self.svm_model.classify_image(data)
        if label == 1:
            # Label 1 means the classifier considers the image relevant: keep it.
            self.s3_writer.write(path, data)
        else:
            # Irrelevant image: remember its path so its Markdown reference can be stripped.
            self._redundant_images_paths.append(path)
    def remove_redundant_images(self, md_content: str) -> str:
        """Strip the Markdown image reference for every image flagged as irrelevant."""
        for path in self._redundant_images_paths:
            md_content = md_content.replace(f"![]({path})", "")
        return md_content
if __name__ == "__main__":
    processor = Processor()

    # Process a single PDF. NOTE: the URL is a placeholder and the key is an
    # example value; process() stores images under document-extracts/<key>/.
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, key=str(uuid.uuid4()))
    print("Single file Markdown:\n", markdown_result)

    # Process several PDFs by calling process() once per URL.
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = {url: processor.process(url, key=str(uuid.uuid4())) for url in multiple_urls}
    print("Batch results:", batch_results)