|
|
|
import os |
|
import uuid |
|
import json |
|
import requests |
|
import logging |
|
import torch |
|
import gc |
|
|
|
from magic_pdf.data.dataset import PymuDocDataset |
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze |
|
from magic_pdf.data.io.s3 import S3Writer |
|
from magic_pdf.data.data_reader_writer.base import DataWriter |
|
|
|
from inference_svm_model import SVMModel |
|
|
|
# Root logging configuration: applied once at import time so every module
# logger in the process shares the same timestamped, leveled format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
|
|
|
class Processor:
    """
    Downloads PDFs, runs the magic-pdf OCR/layout pipeline on them, and
    returns the extracted Markdown with irrelevant images removed.

    Configuration comes from environment variables (S3 credentials/bucket)
    and from /home/user/magic-pdf.json (formula/table toggles).
    """

    # Seconds to wait for the PDF download before requests raises a timeout.
    # Prevents a stalled server from hanging the worker forever.
    DOWNLOAD_TIMEOUT_SECONDS = 30

    def __init__(self):
        try:
            # S3 credentials and destination come from the environment; this
            # writer is shared by every ImageWriter created in process().
            self.s3_writer = S3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket=os.getenv("S3_BUCKET_NAME"),
                endpoint_url=os.getenv("S3_ENDPOINT"),
            )

            # Binary relevant/irrelevant image classifier.
            self.svm_model = SVMModel()
            logger.info("Classification model initialized successfully")

            # Only the formula/table toggles are read from the config file;
            # layout model and language are fixed here.
            with open("/home/user/magic-pdf.json", "r") as f:
                config = json.load(f)

            self.layout_mode = "doclayout_yolo"
            self.formula_enable = config["formula-config"]["enable"]
            self.table_enable = config["table-config"]["enable"]
            self.language = "en"

            # Public URL prefix under which uploaded images will be served;
            # process() appends "<key>/" per document.
            endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
            bucket = os.getenv("S3_BUCKET_NAME", "")
            self.prefix = f"{endpoint}/{bucket}/document-extracts/"

            logger.info("Processor initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Processor: %s", str(e))
            raise

    def cleanup_gpu(self):
        """
        Release GPU memory: run garbage collection, then clear PyTorch's
        CUDA cache. This helps prevent VRAM accumulation across documents.
        """
        try:
            gc.collect()
            torch.cuda.empty_cache()
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            # Best effort only — cleanup must never mask the real outcome.
            logger.error("Error during GPU cleanup: %s", e)

    def process(self, file_url: str, key: str = "") -> str:
        """
        Process a single PDF, returning final Markdown with irrelevant
        images removed.

        Args:
            file_url: HTTP(S) URL of the PDF to download.
            key: Unique identifier used to namespace this document's images
                in S3 and in the Markdown URLs. A random UUID hex string is
                generated when omitted (backward-compatible default).

        Returns:
            The final Markdown string.

        Raises:
            Exception: If the PDF cannot be downloaded (non-200 response),
                plus whatever requests/magic-pdf raise on failure.
        """
        if not key:
            key = uuid.uuid4().hex
        logger.info("Processing file: %s", file_url)

        try:
            # Bounded download: without a timeout, requests can block forever.
            response = requests.get(file_url, timeout=self.DOWNLOAD_TIMEOUT_SECONDS)
            if response.status_code != 200:
                logger.error("Failed to download PDF from %s. Status code: %d", file_url, response.status_code)
                raise Exception(f"Failed to download PDF: {file_url}")

            pdf_bytes = response.content
            logger.info("Downloaded %d bytes for file_url='%s'", len(pdf_bytes), file_url)

            dataset = PymuDocDataset(pdf_bytes)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_mode,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable
            )
            logger.info("doc_analyze complete for key='%s'. Started extracting images...", key)

            # Each extracted image is classified: relevant ones are uploaded
            # to S3, the rest are flagged for removal from the Markdown.
            image_writer = ImageWriter(self.s3_writer, self.svm_model)
            pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
            logger.info("OCR pipeline completed for key='%s'.", key)

            md_content = pipe_result.get_markdown(self.prefix + key + "/")
            final_markdown = image_writer.remove_redundant_images(md_content)
            logger.info("Completed PDF process for key='%s'. Final MD length=%d", key, len(final_markdown))
            return final_markdown
        finally:
            # Always reclaim VRAM, even when the pipeline raised.
            self.cleanup_gpu()

    def process_batch(self, file_urls) -> dict:
        """
        Process several PDFs sequentially, best effort.

        Args:
            file_urls: Iterable of PDF URLs; each gets a fresh generated key.

        Returns:
            Dict mapping each URL to its extracted Markdown string, or to
            the exception raised while processing that document.
        """
        results = {}
        for url in file_urls:
            try:
                results[url] = self.process(url)
            except Exception as exc:
                # One bad PDF must not abort the rest of the batch.
                logger.error("Batch item failed for %s: %s", url, exc)
                results[url] = exc
        return results
|
|
|
class ImageWriter(DataWriter):
    """
    Receives each extracted image from the OCR pipeline. Classifies it,
    uploads it to S3 if relevant, or flags it for removal if irrelevant.
    """

    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        # Shared S3 destination and image classifier (owned by Processor).
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        # Paths of images judged irrelevant; their Markdown references are
        # stripped later by remove_redundant_images().
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        """
        Called for each extracted image. If relevant, upload to S3;
        otherwise mark the image path for removal.
        """
        label = self.svm_model.classify_image(data)

        # NOTE(review): this compares against the int 1 — confirm that
        # classify_image returns an int label rather than a string, or the
        # relevant branch can never be taken.
        if label == 1:
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        """
        Remove the Markdown image tag of every image flagged as irrelevant.

        Args:
            md_content: Markdown produced by the OCR pipeline.

        Returns:
            The Markdown with flagged image references removed.
        """
        for path in self._redundant_images_paths:
            # Bug fix: the original replaced an empty f-string (f""), a
            # guaranteed no-op, so flagged images were never removed. Target
            # the standard Markdown image syntax for this path instead.
            md_content = md_content.replace(f"![]({path})", "")
        return md_content
|
|
|
if __name__ == "__main__":
    processor = Processor()

    # Single document. Processor.process takes (file_url, key); generate a
    # unique key to namespace this document's images. (The previous call
    # omitted the required key argument and raised a TypeError.)
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, uuid.uuid4().hex)
    print("Single file Markdown:\n", markdown_result)

    # Batch: process each URL sequentially with its own generated key,
    # calling process() directly per document. (The previous code called a
    # process_batch method that is not defined on Processor.)
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = {url: processor.process(url, uuid.uuid4().hex) for url in multiple_urls}
    print("Batch results:", batch_results)