#!/usr/bin/env python3
"""Download PDFs, extract them with magic-pdf, classify each extracted image with
an SVM, upload the relevant images to S3, and return the cleaned Markdown."""
import json
import os
import tempfile
import uuid

import requests
from loguru import logger
from magic_pdf.data.data_reader_writer.base import DataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

from inference_svm_model import load_svm_model, classify_image


class Processor:
    """Orchestrates the download -> extract -> classify -> upload pipeline."""

    def __init__(self):
        # S3 destination for the extracted images that survive classification.
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )

        # SVM used to decide whether an extracted image is worth keeping.
        model_path = os.getenv("SVM_MODEL_PATH", "./model_classification/svm_model.joblib")
        self.svm_model = load_svm_model(model_path)
        self.label_map = {0: "irrelevant", 1: "relevant"}

        # Read layout/formula/table settings from the magic-pdf config file.
        with open("/home/user/magic-pdf.json", "r") as f:
            config = json.load(f)
        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"

        # Public URL prefix under which uploaded images are referenced in the
        # generated Markdown.
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"

    def process(self, file_url: str) -> str:
        """Download a single PDF, extract it, and return the cleaned Markdown."""
        logger.info("Processing file: {}", file_url)

        response = requests.get(file_url)
        if response.status_code != 200:
            raise Exception(f"Failed to download PDF: {file_url}")
        pdf_bytes = response.content

        # Run layout analysis / OCR over the whole document.
        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # The ImageWriter classifies every extracted image and uploads only the
        # relevant ones to S3.
        image_writer = ImageWriter(self.s3_writer, self.svm_model, self.label_map)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)

        # Each document gets its own folder so image URLs never collide.
        folder_name = str(uuid.uuid4())
        md_content = pipe_result.get_markdown(self.prefix + folder_name + "/")

        # Remove references to images classified as "irrelevant".
        final_markdown = image_writer.remove_redundant_images(md_content)
        return final_markdown

    def process_batch(self, file_urls: list[str]) -> dict[str, str]:
        """Process several PDFs, mapping each URL to its Markdown or an error message."""
        results = {}
        for url in file_urls:
            try:
                results[url] = self.process(url)
            except Exception as e:
                results[url] = f"Error: {e}"
        return results


class ImageWriter(DataWriter):
    """
    Receives each extracted image. Classifies it, uploads it to S3 if relevant,
    or flags it for removal from the Markdown if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model, label_map):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self.label_map = label_map
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        # classify_image expects a file on disk, so persist the bytes to a
        # temporary file for the duration of the prediction.
        tmp_path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.jpg")
        with open(tmp_path, "wb") as f:
            f.write(data)

        label_str = classify_image(tmp_path, self.svm_model, self.label_map)
        os.remove(tmp_path)

        if label_str == "relevant":
            # Upload to S3 so the Markdown link resolves.
            self.s3_writer.write(path, data)
        else:
            # Remember the path so its Markdown reference can be stripped later.
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        """Strip image references for everything classified as irrelevant."""
        for path in self._redundant_images_paths:
            md_content = md_content.replace(f"![]({path})", "")
        return md_content


if __name__ == "__main__":
    processor = Processor()

    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url)
    print("Single file Markdown:\n", markdown_result)

    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = processor.process_batch(multiple_urls)
    print("Batch results:", batch_results)
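

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what the imported inference_svm_model
# module could look like. This is NOT the actual module -- it assumes the
# classifier is a scikit-learn estimator persisted with joblib and that
# classify_image predicts from resized grayscale pixel features via Pillow;
# the real feature extraction, input size, and model type may differ. The
# sketch is kept commented out because it belongs in its own file,
# inference_svm_model.py, not in this script.
# ---------------------------------------------------------------------------
# import joblib
# import numpy as np
# from PIL import Image
#
#
# def load_svm_model(model_path: str):
#     """Load the SVM (or sklearn pipeline) saved with joblib."""
#     return joblib.load(model_path)
#
#
# def classify_image(image_path: str, model, label_map: dict) -> str:
#     """Turn the image into a feature vector, predict, and map the class id to a label."""
#     img = Image.open(image_path).convert("L").resize((64, 64))  # assumed input size
#     features = np.asarray(img, dtype=np.float32).flatten() / 255.0
#     class_id = int(model.predict([features])[0])
#     return label_map[class_id]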