#!/usr/bin/env python3
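"""mineru_single.py — download a PDF, run magic-pdf OCR analysis, upload
relevant extracted images to S3, and return the document as Markdown."""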
import os
import uuid
import json
import requests
from loguru import logger
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.data.data_reader_writer.base import DataWriter
from inference_svm_model import SVMModel
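
# Required environment variables: S3_ACCESS_KEY, S3_SECRET_KEY,
# S3_BUCKET_NAME, and S3_ENDPOINT (used both for uploads and for building
# the public image-URL prefix below).
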
class Processor:
    def __init__(self):
        # S3 credentials and target bucket come from the environment.
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )
        self.svm_model = SVMModel()

        # Pull layout/formula/table settings from the magic-pdf config file.
        with open("/home/user/magic-pdf.json", "r") as f:
            config = json.load(f)
        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"

        # Public URL prefix under which extracted images are served.
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"

    def process(self, file_url: str, key: str) -> str:
        """Download one PDF, run OCR analysis, and return its Markdown."""
        logger.info("Processing file: {}", file_url)
        response = requests.get(file_url)
        if response.status_code != 200:
            raise Exception(f"Failed to download PDF: {file_url}")
        pdf_bytes = response.content

        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # The writer classifies each extracted image and uploads the relevant
        # ones; image links in the Markdown point under self.prefix + key.
        image_writer = ImageWriter(self.s3_writer, self.svm_model)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
        md_content = pipe_result.get_markdown(self.prefix + key + "/")

        # Remove references to images classified as "irrelevant".
        final_markdown = image_writer.remove_redundant_images(md_content)
        return final_markdown
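
    # NOTE: process_batch is not in the original snippet; this is a minimal
    # sketch inferred from the __main__ usage below, assuming one generated
    # S3 key per URL and per-file error capture.
    def process_batch(self, file_urls: list) -> dict:
        """Process several PDFs, mapping each URL to its Markdown or error."""
        results = {}
        for url in file_urls:
            try:
                results[url] = self.process(url, key=str(uuid.uuid4()))
            except Exception as e:
                results[url] = f"Error: {e}"
        return results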


class ImageWriter(DataWriter):
    """
    Receives each extracted image. Classifies it, uploads it if relevant, or
    flags it for removal if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        # Classify the image; only relevant images (label 1) are uploaded.
        label = self.svm_model.classify_image(data)
        if label == 1:
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        # Strip the Markdown reference for every image flagged as irrelevant.
        for path in self._redundant_images_paths:
            md_content = md_content.replace(f"![]({path})", "")
        return md_content


if __name__ == "__main__":
    processor = Processor()

    # Single file. The key namespaces the uploaded images; a random UUID is
    # used here for illustration.
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, key=str(uuid.uuid4()))
    print("Single file Markdown:\n", markdown_result)

    # Batch of files.
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = processor.process_batch(multiple_urls)
    print("Batch results:", batch_results)