|
|
|
import os
|
|
import uuid
|
|
import json
|
|
import requests
|
|
from loguru import logger
|
|
|
|
from magic_pdf.data.dataset import PymuDocDataset
|
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
|
from magic_pdf.data.io.s3 import S3Writer
|
|
from magic_pdf.data.data_reader_writer.base import DataWriter
|
|
|
|
from inference_svm_model import SVMModel
|
|
|
|
class Processor:
    """
    Downloads a PDF, runs the magic_pdf OCR pipeline on it and returns the
    resulting Markdown with classifier-rejected images stripped out.
    """

    # Timeout (seconds) passed to requests.get so a stalled server cannot
    # block the caller indefinitely.
    DOWNLOAD_TIMEOUT = 60

    def __init__(self, config_path: str = "/home/user/magic-pdf.json"):
        """
        Build the S3 writer and SVM model and load the pipeline settings.

        Args:
            config_path: Path of the magic-pdf JSON config file. It must
                contain "layout-config", "formula-config" and "table-config"
                sections; the default is the location the original code
                hard-coded.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If an expected config section or key is missing.
        """
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )

        self.svm_model = SVMModel()

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"

        # Base URL prepended (together with the per-document key) to every
        # image link in the generated Markdown.
        self.prefix = self._build_prefix(
            os.getenv("S3_ENDPOINT", ""),
            os.getenv("S3_BUCKET_NAME", ""),
        )

    @staticmethod
    def _build_prefix(endpoint: str, bucket: str) -> str:
        """Return "<endpoint>/<bucket>/document-extracts/" without a doubled slash."""
        return f"{endpoint.rstrip('/')}/{bucket}/document-extracts/"

    def _download(self, file_url: str) -> bytes:
        """Fetch file_url and return the response body; raise on a non-200 status."""
        response = requests.get(file_url, timeout=self.DOWNLOAD_TIMEOUT)
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download PDF: {file_url}")
        return response.content

    def process(self, file_url: str, key: str) -> str:
        """
        Convert the PDF at file_url to Markdown.

        Args:
            file_url: URL the PDF is downloaded from.
            key: Per-document path segment appended to self.prefix when
                building the image links in the Markdown.

        Returns:
            The Markdown text after remove_redundant_images has been applied.

        Raises:
            RuntimeError: If the download does not return HTTP 200.
        """
        logger.info("Processing file: {}", file_url)
        pdf_bytes = self._download(file_url)

        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # The writer classifies every image the pipeline hands it and records
        # the ones it does not upload.
        image_writer = ImageWriter(self.s3_writer, self.svm_model)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)

        # NOTE(review): ImageWriter uploads each image under the bare `path`
        # it receives, while the links produced here are prefixed with
        # self.prefix + key -- confirm both resolve to the same object.
        md_content = pipe_result.get_markdown(self.prefix + key + "/")

        return image_writer.remove_redundant_images(md_content)
|
|
|
|
class ImageWriter(DataWriter):
    """
    Receives each extracted image. Classifies it, uploads if relevant, or flags
    it for removal if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        """
        Args:
            s3_writer: Writer used to upload the images that are kept.
            svm_model: Classifier whose classify_image(data) result decides
                whether an image is kept.
        """
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        # Paths of images that were classified as irrelevant and therefore
        # never uploaded; their Markdown references must be stripped later.
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        """
        Classify one image and either upload it or remember it as redundant.

        Args:
            path: Path the pipeline assigned to the image.
            data: Raw image bytes.
        """
        # NOTE(review): a label equal to 1 is treated as "keep"; confirm this
        # matches what SVMModel.classify_image returns.
        label = self.svm_model.classify_image(data)

        if label == 1:
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        """
        Strip Markdown image references that point at a redundant image.

        A reference is removed when it has the form ![alt](target) and target
        is exactly a recorded path or ends with "/<path>", so links carrying a
        URL prefix in front of the path are matched too.

        Args:
            md_content: Markdown text produced by the pipeline.

        Returns:
            md_content without the references to images that were not uploaded.
        """
        import re  # local import: the file's import block is left untouched

        # Empty paths are skipped; an empty alternative would match every
        # image target that ends with "/".
        escaped_paths = [re.escape(p) for p in self._redundant_images_paths if p]
        if not escaped_paths:
            return md_content

        # One alternation over all paths so the text is scanned a single time.
        pattern = re.compile(
            r"!\[[^\]]*\]\((?:[^)\s]*/)?(?:" + "|".join(escaped_paths) + r")\)"
        )
        return pattern.sub("", md_content)
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: converts one PDF, then a small list of PDFs.
    processor = Processor()

    # Processor.process requires a `key`, which it appends to the image URL
    # prefix of the generated Markdown; a fresh UUID gives each document its
    # own value.
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, uuid.uuid4().hex)
    print("Single file Markdown:\n", markdown_result)

    # Processor has no batch method, so the URLs are processed one at a time.
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = [
        processor.process(url, uuid.uuid4().hex) for url in multiple_urls
    ]
    print("Batch results:", batch_results)