#!/usr/bin/env python3
import os
import uuid
import json

import requests
from loguru import logger
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.data.data_reader_writer.base import DataWriter

# Local module providing the image-relevance classifier.
from inference_svm_model import SVMModel

class Processor:
    def __init__(self):
        # S3 credentials and target bucket come from the environment.
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )
        self.svm_model = SVMModel()

        # Pull layout/formula/table settings from the magic-pdf config file.
        with open("/home/user/magic-pdf.json", "r") as f:
            config = json.load(f)
        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = "en"

        # Public URL prefix under which extracted images will be addressable.
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"
    def process(self, file_url: str, key: str) -> str:
        logger.info("Processing file: {}", file_url)
        response = requests.get(file_url, timeout=60)  # don't hang forever on a dead link
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download PDF: {file_url}")
        pdf_bytes = response.content

        # Run layout analysis / OCR over the raw PDF bytes.
        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # ImageWriter classifies each extracted image and uploads only the
        # relevant ones to S3 (see the class below).
        image_writer = ImageWriter(self.s3_writer, self.svm_model)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
        md_content = pipe_result.get_markdown(self.prefix + key + "/")

        # Remove references to images classified as "irrelevant".
        final_markdown = image_writer.remove_redundant_images(md_content)
        return final_markdown
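
    def process_batch(self, file_urls: list) -> dict:
        # Minimal sketch: process_batch is called in __main__ but was not
        # defined in this file. This assumed implementation generates a fresh
        # UUID key per URL and records per-URL failures instead of aborting.
        results = {}
        for url in file_urls:
            try:
                results[url] = self.process(url, str(uuid.uuid4()))
            except Exception as e:
                logger.error("Failed to process {}: {}", url, e)
                results[url] = f"ERROR: {e}"
        return results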

class ImageWriter(DataWriter):
    """
    Receives each extracted image. Classifies it, uploads it to S3 if relevant,
    or flags it for removal from the Markdown if irrelevant.
    """
    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        self._redundant_image_paths = []

    def write(self, path: str, data: bytes) -> None:
        # classify_image is assumed to return 1 for "relevant" images.
        label = self.svm_model.classify_image(data)
        if label == 1:
            # Relevant: upload to S3 so the Markdown link resolves.
            self.s3_writer.write(path, data)
        else:
            # Irrelevant: remember the path so its reference can be stripped.
            self._redundant_image_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        # Strip the Markdown image reference for every irrelevant image.
        for path in self._redundant_image_paths:
            md_content = md_content.replace(f"![]({path})", "")
        return md_content

if __name__ == "__main__":
    processor = Processor()

    # Single file: process() needs a storage key as well as the URL.
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, str(uuid.uuid4()))
    print("Single file Markdown:\n", markdown_result)

    # Batch of files.
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = processor.process_batch(multiple_urls)
    print("Batch results:", batch_results)