MinerU / mineru_single.py
princhman's picture
Fix path errors
ae9e9b0
raw
history blame
4.37 kB
#!/usr/bin/env python3
import os
import uuid
import json
import requests
from loguru import logger
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.data.data_reader_writer.base import DataWriter
from inference_svm_model import load_svm_model, classify_image
class Processor:
    """Convert remote PDFs to Markdown with the magic_pdf OCR pipeline.

    On construction the processor creates an ``S3Writer`` from the
    ``S3_ACCESS_KEY`` / ``S3_SECRET_KEY`` / ``S3_BUCKET_NAME`` / ``S3_ENDPOINT``
    environment variables, loads the SVM image classifier from
    ``SVM_MODEL_PATH`` and reads the layout / formula / table switches from a
    magic-pdf JSON config file.
    """

    # (connect, read) timeout in seconds for the PDF download. Without a
    # timeout ``requests`` waits indefinitely on a stalled server.
    DOWNLOAD_TIMEOUT = (10, 120)

    def __init__(self, config_path: str = "/home/user/magic-pdf.json"):
        """Set up the S3 writer, the SVM classifier and the pipeline options.

        Args:
            config_path: Path of the magic-pdf JSON config. It must contain
                ``layout-config.model``, ``formula-config.enable`` and
                ``table-config.enable``. Defaults to the location that was
                previously hard-coded.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If one of the expected config keys is missing.
        """
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )

        model_path = os.getenv("SVM_MODEL_PATH", "./model_classification/svm_model.joblib")
        self.svm_model = load_svm_model(model_path)
        # Passed to classify_image; ImageWriter uploads only "relevant" images.
        self.label_map = {0: "irrelevant", 1: "relevant"}

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"

        # Base URL used for image links in the generated Markdown.
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"

    def process(self, file_url: str) -> str:
        """Download one PDF and return its Markdown rendering.

        Args:
            file_url: HTTP(S) URL of the PDF to convert.

        Returns:
            The Markdown produced by the OCR pipeline after
            ``ImageWriter.remove_redundant_images`` has been applied.

        Raises:
            Exception: If the download does not answer with HTTP 200.
            requests.exceptions.RequestException: On connection errors or
                when the download exceeds ``DOWNLOAD_TIMEOUT``.
        """
        logger.info("Processing file: {}", file_url)

        response = requests.get(file_url, timeout=self.DOWNLOAD_TIMEOUT)
        if response.status_code != 200:
            logger.error("Download of {} returned HTTP {}", file_url, response.status_code)
            raise Exception(f"Failed to download PDF: {file_url}")

        pdf_bytes = response.content
        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # The writer classifies every extracted image and uploads the
        # relevant ones while the pipeline runs.
        image_writer = ImageWriter(self.s3_writer, self.svm_model, self.label_map)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)

        # NOTE(review): the Markdown links are built from
        # ``self.prefix + folder_name``, but ImageWriter uploads each image
        # under the raw ``path`` it receives, and ``folder_name`` is only
        # generated after the pipeline has run. Confirm that the uploaded
        # objects are actually reachable under these links.
        folder_name = str(uuid.uuid4())
        md_content = pipe_result.get_markdown(self.prefix + folder_name + "/")

        # Remove references to images classified as "irrelevant"
        final_markdown = image_writer.remove_redundant_images(md_content)
        return final_markdown

    def process_batch(self, file_urls: list[str]) -> dict:
        """Convert several PDFs, isolating failures per URL.

        Args:
            file_urls: URLs to process, in order.

        Returns:
            A dict mapping each URL to its Markdown, or to the string
            ``"Error: <message>"`` when processing that URL raised.
        """
        results = {}
        for url in file_urls:
            try:
                results[url] = self.process(url)
            except Exception as e:
                # Best-effort batch: keep going, but leave a traceback in the
                # log so failures are not only visible in the result string.
                logger.exception("Failed to process {}", url)
                results[url] = f"Error: {str(e)}"
        return results
class ImageWriter(DataWriter):
    """
    Receives each extracted image. Classifies it, uploads if relevant, or flags
    it for removal if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model, label_map):
        """Store the upload target and the classifier.

        Args:
            s3_writer: Writer used to upload images classified as relevant.
            svm_model: Model object handed to ``classify_image``.
            label_map: Label mapping handed to ``classify_image``.
        """
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self.label_map = label_map
        # Paths of images that were not uploaded; their Markdown references
        # are stripped later by remove_redundant_images.
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        """Classify one image and upload it only when it is relevant.

        Args:
            path: Destination path passed through to ``s3_writer.write``.
            data: Raw image bytes.
        """
        import tempfile

        # classify_image takes a file path, so spill the bytes to a temp file.
        # mkstemp creates the file atomically with an unpredictable name.
        fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as tmp_file:
                tmp_file.write(data)
            label_str = classify_image(tmp_path, self.svm_model, self.label_map)
        finally:
            # Always clean up, even when classification raises.
            os.remove(tmp_path)

        if label_str == "relevant":
            # Upload to S3
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        """Strip Markdown references to images that were not uploaded.

        A reference is removed when it has the form ``![](<url>)`` and
        ``<url>`` is either exactly a recorded path or ends in ``/<path>``,
        so links that carry a directory or bucket prefix are matched too.

        Args:
            md_content: Markdown text to clean.

        Returns:
            The Markdown with the matching image references deleted.
        """
        import re

        if not self._redundant_images_paths:
            return md_content

        escaped_paths = "|".join(re.escape(p) for p in self._redundant_images_paths)
        # Optional "<anything>/" prefix keeps the match on a path-segment
        # boundary, so "xa.jpg" is never mistaken for "a.jpg".
        pattern = re.compile(r"!\[\]\((?:[^()\s]*/)?(?:" + escaped_paths + r")\)")
        return pattern.sub("", md_content)
if __name__ == "__main__":
    # Demo run: convert a single PDF, then a small batch, and print both.
    demo = Processor()

    print("Single file Markdown:\n", demo.process("https://example.com/somefile.pdf"))

    batch_urls = [
        "https://example.com/file1.pdf",
        "https://example.com/file2.pdf",
    ]
    print("Batch results:", demo.process_batch(batch_urls))