File size: 3,462 Bytes
11551ca
 
 
b273357
a8b6881
b273357
 
a8b6881
 
 
b273357
a8b6881
0aedcf3
a8b6881
 
 
 
 
 
 
 
11551ca
b273357
0aedcf3
a8b6881
 
 
b273357
a8b6881
 
 
 
b273357
 
 
 
 
0aedcf3
b273357
 
a8b6881
b273357
a8b6881
 
 
b273357
 
 
 
 
 
 
 
 
0aedcf3
b273357
 
a8b6881
0aedcf3
b273357
 
 
 
 
a8b6881
b273357
 
 
 
0aedcf3
b273357
 
a8b6881
 
 
0aedcf3
b273357
0aedcf3
b273357
 
11551ca
a8b6881
11551ca
b273357
a8b6881
 
 
11551ca
a8b6881
 
b273357
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
import json
import os
import re
import uuid

import requests
from loguru import logger

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.io.s3 import S3Writer
from magic_pdf.data.data_reader_writer.base import DataWriter

from inference_svm_model import SVMModel

class Processor:
    """Download a PDF, run magic_pdf OCR analysis on it, and return Markdown.

    Extracted images are routed through ``ImageWriter``, which classifies each
    one with ``SVMModel`` and uploads only the accepted ones to S3.
    """

    def __init__(
        self,
        config_path: str = "/home/user/magic-pdf.json",
        language: str = "en",
        request_timeout: float = 60.0,
    ):
        """Create the S3 writer and SVM model, and load model settings.

        Args:
            config_path: Path to the magic-pdf JSON config. It must contain
                ``layout-config.model``, ``formula-config.enable`` and
                ``table-config.enable``.
            language: OCR language passed to ``doc_analyze`` and
                ``pipe_ocr_mode``.
            request_timeout: Seconds ``requests.get`` may wait on the server
                (connect / between reads), so a stalled download cannot hang
                the process indefinitely.

        Raises:
            OSError: If the config file cannot be opened.
            ValueError: If the config file is not valid JSON.
            KeyError: If an expected config key is missing.
        """
        # Credentials, bucket and endpoint come from the environment; unset
        # variables are passed through as None.
        self.s3_writer = S3Writer(
            ak=os.getenv("S3_ACCESS_KEY"),
            sk=os.getenv("S3_SECRET_KEY"),
            bucket=os.getenv("S3_BUCKET_NAME"),
            endpoint_url=os.getenv("S3_ENDPOINT"),
        )

        self.svm_model = SVMModel()

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        self.layout_mode = config["layout-config"]["model"]
        self.formula_enable = config["formula-config"]["enable"]
        self.table_enable = config["table-config"]["enable"]
        self.language = language
        self.request_timeout = request_timeout

        # URL prefix used for image links in the generated Markdown:
        # "<endpoint>/<bucket>/document-extracts/".
        endpoint = os.getenv("S3_ENDPOINT", "").rstrip("/")
        bucket = os.getenv("S3_BUCKET_NAME", "")
        self.prefix = f"{endpoint}/{bucket}/document-extracts/"

    def process(self, file_url: str, key: str) -> str:
        """Convert the PDF at *file_url* to Markdown.

        Args:
            file_url: HTTP(S) URL of the PDF to download.
            key: Per-document identifier; image links in the Markdown are
                rooted at ``self.prefix + key + "/"``.

        Returns:
            The Markdown text, with links to images the SVM classifier
            rejected removed.

        Raises:
            RuntimeError: If the download does not return HTTP 200.
            requests.RequestException: On network errors or timeout.
        """
        logger.info("Processing file: {}", file_url)
        response = requests.get(file_url, timeout=self.request_timeout)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to download PDF: {file_url} (HTTP {response.status_code})"
            )
        pdf_bytes = response.content

        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        # NOTE(review): ImageWriter uploads each image under the `path` that
        # magic_pdf hands it, while the Markdown links below point to
        # "<prefix><key>/<path>". Confirm S3Writer stores objects where those
        # links expect them; otherwise the image links will not resolve.
        image_writer = ImageWriter(self.s3_writer, self.svm_model)

        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)

        md_content = pipe_result.get_markdown(self.prefix + key + "/")

        # Drop links to images the classifier marked as irrelevant.
        return image_writer.remove_redundant_images(md_content)

class ImageWriter(DataWriter):
    """Writer for images extracted by magic_pdf.

    Each image is classified with the SVM model. Images labelled ``1`` are
    uploaded to S3; all others are recorded so that their Markdown links can
    be stripped afterwards with ``remove_redundant_images``.
    """

    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
        self.s3_writer = s3_writer
        self.svm_model = svm_model
        # Paths (exactly as received by write()) of images rejected by the
        # classifier.
        self._redundant_images_paths = []

    def write(self, path: str, data: bytes) -> None:
        """Classify one image; upload it if relevant, otherwise remember its path."""
        label = self.svm_model.classify_image(data)

        # NOTE(review): compared against the integer 1. If classify_image
        # returns a string label (e.g. "1"), every image is treated as
        # irrelevant — confirm its return type.
        if label == 1:
            self.s3_writer.write(path, data)
        else:
            self._redundant_images_paths.append(path)

    def remove_redundant_images(self, md_content: str) -> str:
        """Remove Markdown image links that point to rejected images.

        A link is removed when its target is the recorded path itself or ends
        with ``/<path>``. The second form covers links that were written with
        an image-directory or URL prefix in front of the path. Any alt text is
        accepted.

        Args:
            md_content: Markdown produced by the pipeline.

        Returns:
            The Markdown with matching ``![...](...)`` links removed.
        """
        if not self._redundant_images_paths:
            return md_content

        alternatives = "|".join(
            re.escape(p) for p in self._redundant_images_paths
        )
        # ![alt](optional/prefix/<path>) — the prefix must end in "/" so one
        # filename cannot match as a mere suffix of another.
        pattern = re.compile(
            r"!\[[^\]]*\]\((?:[^)\s]*/)?(?:" + alternatives + r")\)"
        )
        return pattern.sub("", md_content)

if __name__ == "__main__":
    processor = Processor()

    # process() requires a per-document key (it namespaces image links), so
    # generate a unique one for each file.
    single_url = "https://example.com/somefile.pdf"
    markdown_result = processor.process(single_url, str(uuid.uuid4()))
    print("Single file Markdown:\n", markdown_result)

    # Process each URL individually, each with its own key.
    multiple_urls = ["https://example.com/file1.pdf", "https://example.com/file2.pdf"]
    batch_results = {
        url: processor.process(url, str(uuid.uuid4())) for url in multiple_urls
    }
    print("Batch results:", batch_results)