import json
import os
import tempfile
import uuid

import requests
from loguru import logger

from magic_pdf.data.data_reader_writer.base import DataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

from inference_svm_model import load_svm_model, classify_image


class Processor:
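    """
    Downloads PDFs, extracts them to Markdown with magic-pdf, and keeps only
    the images the SVM classifier labels as relevant.
    """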
    def __init__(self):
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )
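        # Binary SVM classifier that decides which extracted images to keep.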
        model_path = os.getenv("SVM_MODEL_PATH", "./model_classification/svm_model.joblib")
        self.svm_model = load_svm_model(model_path)
        self.label_map = {0: "irrelevant", 1: "relevant"}
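        # Mirror the layout/formula/table settings from the magic-pdf JSON config.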
        with open("/home/user/magic-pdf.json", "r") as f:
            config = json.load(f)

        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"
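        # Public URL prefix under which extracted images are stored in S3.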
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"

    def process(self, file_url: str) -> str:
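        """Download a PDF, extract it to Markdown, and drop irrelevant images."""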
        logger.info("Processing file: {}", file_url)
        response = requests.get(file_url, timeout=60)
        if response.status_code != 200:
            raise Exception(f"Failed to download PDF: {file_url}")
        pdf_bytes = response.content
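        # Run layout analysis with OCR, formula, and table models as configured.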
        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )
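        # Route every extracted image through ImageWriter, which classifies it
        # and either uploads it to S3 or marks it for removal; image links in
        # the Markdown are rendered under a unique per-document folder.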
        image_writer = ImageWriter(self.s3_writer, self.svm_model, self.label_map)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)

        folder_name = str(uuid.uuid4())
        md_content = pipe_result.get_markdown(self.prefix + folder_name + "/")
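        # Strip Markdown references to the images the classifier rejected.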
        final_markdown = image_writer.remove_redundant_images(md_content)
        return final_markdown

    def process_batch(self, file_urls: list[str]) -> dict:
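        """Process each URL independently, mapping it to Markdown or an error string."""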
        results = {}
        for url in file_urls:
            try:
                md = self.process(url)
                results[url] = md
            except Exception as e:
                results[url] = f"Error: {e}"
        return results


class ImageWriter(DataWriter):
    """
    Receives each extracted image, classifies it with the SVM model, uploads it
    to S3 if relevant, or flags it for removal if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model, label_map):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self.label_map = label_map
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
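        # classify_image expects a path on disk, so spill the image bytes to a
        # temporary file, classify it, then delete the file.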
        tmp_name = f"{uuid.uuid4()}.jpg"
        tmp_path = os.path.join(tempfile.gettempdir(), tmp_name)
        with open(tmp_path, "wb") as f:
            f.write(data)

        try:
            label_str = classify_image(tmp_path, self.svm_model, self.label_map)
        finally:
            os.remove(tmp_path)
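        # Upload relevant images; remember irrelevant ones so their Markdown
        # references can be stripped later.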
        if label_str == "relevant":
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
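        # Remove the Markdown image references (assumed to be emitted in the
        # standard "![](<path>)" form) for every image classified as irrelevant.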
        for path in self._redundant_images_paths:
            md_content = md_content.replace(f"![]({path})", "")
        return md_content


if __name__ == "__main__":
    processor = Processor()

    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url)
    print("Single file Markdown:\n", markdown_result)

    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = processor.process_batch(multiple_urls)
    print("Batch results:", batch_results)