|
|
|
import os |
|
import uuid |
|
import json |
|
import requests |
|
import logging |
|
import torch |
|
import gc |
|
from magic_pdf.data.dataset import PymuDocDataset |
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze |
|
from magic_pdf.data.io.s3 import S3Writer |
|
from magic_pdf.data.data_reader_writer.base import DataWriter |
|
from inference_svm_model import SVMModel |
|
import concurrent.futures |
|
import boto3 |
|
from io import BytesIO |
|
|
|
# Root logging setup: INFO and above goes both to the console (stream
# handler) and to a ``mineru.log`` file in the current working directory.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('mineru.log'),
    ],
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the classes and functions below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
class Processor:
    """Download a PDF over HTTP and convert it to Markdown with magic_pdf.

    On construction the processor creates the S3 writer (credentials, bucket
    and endpoint come from the ``S3_ACCESS_KEY``, ``S3_SECRET_KEY``,
    ``S3_BUCKET_NAME`` and ``S3_ENDPOINT`` environment variables), loads the
    SVM model and reads the layout/formula settings from the magic-pdf JSON
    config file.
    """

    # (connect, read) timeouts in seconds for the PDF download. Without a
    # timeout ``requests`` waits indefinitely on an unresponsive server.
    DOWNLOAD_TIMEOUT = (10, 120)

    def __init__(self, config_path: str = "/home/user/magic-pdf.json"):
        """Create the S3 writer and SVM model and load the pipeline settings.

        Args:
            config_path: Path of the magic-pdf JSON config. It must contain
                ``layout-config.model`` and ``formula-config.enable``.

        Raises:
            Exception: Any initialization failure is logged and re-raised.
        """
        try:
            self.s3_writer = s3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket=os.getenv("S3_BUCKET_NAME"),
                endpoint_url=os.getenv("S3_ENDPOINT"),
            )
            self.svm_model = SVMModel()
            logger.info("Classification model initialized successfully")

            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)

            self.layout_mode = config["layout-config"]["model"]
            self.formula_enable = config["formula-config"]["enable"]
            self.table_enable = False
            self.language = "en"
            # Every document's images are stored under "<prefix><key>/".
            self.prefix = "document-extracts/"
            logger.info("Processor initialized successfully")
        except Exception as e:
            # logger.exception keeps the traceback next to the message.
            logger.exception("Failed to initialize Processor: %s", e)
            raise

    def cleanup_gpu(self):
        """Run garbage collection and empty PyTorch's CUDA cache.

        Best effort: any error is logged and swallowed so that cleanup never
        masks the result (or the exception) of the call it follows.
        """
        try:
            gc.collect()
            torch.cuda.empty_cache()
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            logger.error("Error during GPU cleanup: %s", e)

    def process(self, file_url: str, key: str) -> str:
        """Download one PDF and return its Markdown after image post-processing.

        The PDF is analysed with ``doc_analyze`` (OCR enabled), the OCR pipe is
        run with an ``ImageWriter`` that receives the extracted images, and the
        resulting Markdown is passed through ``ImageWriter.post_process``.
        GPU memory is cleaned up whether or not processing succeeds.

        Args:
            file_url: HTTP(S) URL of the PDF to download.
            key: Document identifier; images go under ``"<prefix><key>/"``.

        Returns:
            The final Markdown text.

        Raises:
            Exception: If the download does not answer with HTTP 200. Errors
                raised by ``requests`` (including timeouts) and by the
                pipeline propagate unchanged.
        """
        logger.info("Processing file: %s", file_url)
        try:
            response = requests.get(file_url, timeout=self.DOWNLOAD_TIMEOUT)
            if response.status_code != 200:
                logger.error(
                    "Failed to download PDF from %s. Status code: %d",
                    file_url,
                    response.status_code,
                )
                raise Exception(f"Failed to download PDF: {file_url}")
            pdf_bytes = response.content
            logger.info("Downloaded %d bytes for file_url='%s'", len(pdf_bytes), file_url)

            dataset = PymuDocDataset(pdf_bytes)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_mode,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable,
            )
            logger.info("doc_analyze complete for key='%s'. Started extracting images...", key)

            # The same prefix is used for the uploads, for the image links in
            # the Markdown and for post-processing, so build it once.
            image_prefix = f"{self.prefix}{key}/"
            image_writer = ImageWriter(self.s3_writer, image_prefix, self.svm_model)
            pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
            logger.info("OCR pipeline completed for key='%s'.", key)

            md_content = pipe_result.get_markdown(image_prefix)
            final_markdown = image_writer.post_process(image_prefix, md_content)
            logger.info(
                "Completed PDF process for key='%s'. Final MD length=%d",
                key,
                len(final_markdown),
            )
            return final_markdown
        finally:
            self.cleanup_gpu()
|
|
|
|
|
class s3Writer:
    """Thin wrapper around a boto3 S3 client bound to a single bucket."""

    def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str):
        """Create the boto3 ``s3`` client and remember the target bucket.

        Args:
            ak: Access key id handed to boto3.
            sk: Secret access key handed to boto3.
            bucket: Name of the bucket every ``write`` call uploads to.
            endpoint_url: Endpoint URL handed to boto3.
        """
        self.client = boto3.client(
            's3',
            endpoint_url=endpoint_url,
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
        )
        self.bucket = bucket

    def write(self, path: str, data: bytes) -> None:
        """Upload ``data`` to the bucket under the object key ``path``.

        The bytes are wrapped in an in-memory buffer and sent with
        ``upload_fileobj``. Any failure is logged and re-raised.
        """
        try:
            self.client.upload_fileobj(
                Fileobj=BytesIO(data),
                Bucket=self.bucket,
                Key=path,
            )
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise
|
|
|
|
|
class ImageWriter(DataWriter):
    """Upload each extracted image to S3 and caption it in the final Markdown.

    ``write`` receives every extracted image, uploads it through
    ``s3_writer`` and keeps the bytes. ``post_process`` then requests a short
    description per image and inserts it into the Markdown image tags.

    NOTE(review): ``svm_model`` and ``_redundant_images_paths`` are stored but
    never used in this class; every image is uploaded, none is filtered out.
    """

    # Upper bound on simultaneous description requests. One thread per image
    # (the previous behaviour) is unbounded for image-heavy documents.
    MAX_DESCRIPTION_WORKERS = 16

    # Alt text used when no usable description was obtained for an image.
    FALLBACK_DESCRIPTION = "Image description unavailable"

    def __init__(self, s3_writer: s3Writer, base_path: str, svm_model: SVMModel):
        """Store the collaborators and start with no recorded images.

        Args:
            s3_writer: Uploader used by ``write``.
            base_path: Prefix prepended to each image file name to form its
                S3 object key.
            svm_model: Stored on the instance; not used by this class.
        """
        self.s3_writer = s3_writer
        self.base_path = base_path
        self.svm_model = svm_model
        self._redundant_images_paths = []
        # Maps the ``path`` given to ``write`` to a record of that image:
        #   {
        #       "data": <image bytes>,
        #       "full_path": <S3 object key>,
        #       "description": <text>,   # added by post_process on success
        #   }
        self.descriptions = {}

    def write(self, path: str, data: bytes) -> None:
        """Upload one extracted image and remember it for ``post_process``.

        Only the last ``/``-separated component of ``path`` is kept, so the
        object key is ``base_path`` followed by the bare file name.
        """
        full_path = f"{self.base_path}" + path.split("/")[-1]
        self.s3_writer.write(full_path, data)
        self.descriptions[path] = {
            "data": data,
            "full_path": full_path,
        }

    @staticmethod
    def _to_alt_text(description: str) -> str:
        """Collapse all whitespace so the text fits on one line inside ``![...]``."""
        return " ".join(description.split())

    def post_process(self, key: str, md_content: str) -> str:
        """Describe every recorded image and caption its Markdown image tag.

        Args:
            key: Image directory prefix. The caller visible in this file
                passes the same string it gave to ``get_markdown``.
            md_content: Markdown produced by the OCR pipeline.

        Returns:
            ``md_content`` with each matching empty-alt image tag rewritten to
            ``![<description>](<full_path>)``. An image whose description
            could not be obtained gets ``FALLBACK_DESCRIPTION`` instead of
            raising ``KeyError``.
        """
        if not self.descriptions:
            return md_content

        worker_count = min(len(self.descriptions), self.MAX_DESCRIPTION_WORKERS)
        with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
            future_to_path = {
                executor.submit(call_gemini_for_image_description, info["data"]): path
                for path, info in self.descriptions.items()
            }
            for future in concurrent.futures.as_completed(future_to_path):
                path = future_to_path[future]
                try:
                    description = future.result()
                except Exception as e:
                    logger.error("[ERROR] Processing %s: %s", path, e)
                    continue
                # Accept non-empty strings only: a failed call may hand back
                # something that is not text.
                if isinstance(description, str) and description.strip():
                    self.descriptions[path]["description"] = description

        for path, info in self.descriptions.items():
            alt_text = self._to_alt_text(
                info.get("description", self.FALLBACK_DESCRIPTION)
            )
            full_path = info["full_path"]
            captioned_tag = f"![{alt_text}]({full_path})"
            # NOTE(review): the search/replacement literals of this call were
            # empty strings in the reviewed source (a no-op), most likely
            # because the Markdown image syntax was lost. The tag forms below
            # are a reconstruction - confirm against what get_markdown emits.
            # dict.fromkeys drops duplicate candidates while keeping order.
            for target in dict.fromkeys((f"{key}{path}", full_path, path)):
                md_content = md_content.replace(f"![]({target})", captioned_tag)
        return md_content
|
|
|
|
|
def call_gemini_for_image_description(image_data: bytes) -> str:
    """Ask Gemini (``gemini-2.0-flash``) for a short label of one image.

    The image is sent base64-encoded as ``image/jpeg`` together with a fixed
    prompt, at temperature 0.

    The API key is read from the ``GEMINI_API_KEY`` (or ``GOOGLE_API_KEY``)
    environment variable. The key that used to be embedded in this function
    was committed to source and should be revoked and replaced.

    Args:
        image_data: Raw bytes of the image.

    Returns:
        Always a string: the stripped response text;
        ``"Image description unavailable"`` when the response has no text;
        ``"Error describing image"`` when no API key is configured or the
        request fails (the error is logged, never raised).
    """
    from google import genai
    from google.genai import types
    import base64

    error_description = "Error describing image"

    api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
    if not api_key:
        logger.error(
            "Error getting image description: %s",
            "GEMINI_API_KEY environment variable is not set",
        )
        return error_description

    prompt = (
        "The provided image is a part of a question paper or markscheme.\n"
        "Extract all the necessary information from the image to be able to identify the question.\n"
        "To identify the question, we only need the following: question number and question part.\n"
        "Don't include redundant information.\n"
        'For example, if image contains text like: "Q1 Part A Answer: Life on earth was created by diety..."\n'
        'you should return just "Q1 Part A Mark Scheme"\n'
        "If there is no text on this image, return the description of the image. 20 words max.\n"
        "\n"
        "If there are not enough data, consider information from the surrounding context.\n"
        "Additionally, if the image contains a truncated part, you must describe it and mark as a\n"
        "part of some another image that goes before or after current image.\n"
        "\n"
        "If the image is of a multiple-choice question’s options, then modify your answer by appending\n"
        "'MCQ: A [option] B [option] C [option] D [option]' (replacing [option] with the actual options).\n"
        "Otherwise, follow the above instructions strictly.\n"
    )

    try:
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            config=types.GenerateContentConfig(temperature=0.),
            contents=[
                {
                    "parts": [
                        {"text": prompt},
                        {
                            "inline_data": {
                                "mime_type": "image/jpeg",
                                "data": base64.b64encode(image_data).decode('utf-8'),
                            }
                        },
                    ]
                }
            ],
        )

        if response and response.text:
            return response.text.strip()
        return "Image description unavailable"
    except Exception as e:
        logger.error("Error getting image description: %s", e)
        # Return text, as the annotation promises, so callers can use the
        # value directly as a description (this used to be a 3-tuple).
        return error_description
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    class LocalFileProcessor(Processor):
        """``Processor`` variant that reads the PDF from a local file.

        Initialization and GPU cleanup are inherited from the module-level
        ``Processor``; only the source of the PDF bytes differs.
        """

        def process(self, file_path: str, key: str) -> str:
            """Convert the PDF at ``file_path`` to Markdown.

            Args:
                file_path: Path of a local PDF file.
                key: Document identifier; images go under ``"<prefix><key>/"``.

            Returns:
                The final Markdown text after image post-processing.
            """
            logger.info("Processing file: %s", file_path)
            try:
                with open(file_path, "rb") as f:
                    pdf_bytes = f.read()
                logger.info("Loaded %d bytes from file_path='%s'", len(pdf_bytes), file_path)

                dataset = PymuDocDataset(pdf_bytes)
                inference = doc_analyze(
                    dataset,
                    ocr=True,
                    lang=self.language,
                    layout_model=self.layout_mode,
                    formula_enable=self.formula_enable,
                    table_enable=self.table_enable,
                )
                logger.info("doc_analyze complete for key='%s'. Started extracting images...", key)

                image_prefix = f"{self.prefix}{key}/"
                image_writer = ImageWriter(self.s3_writer, image_prefix, self.svm_model)
                pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
                logger.info("OCR pipeline completed for key='%s'.", key)

                md_content = pipe_result.get_markdown(image_prefix)
                final_markdown = image_writer.post_process(image_prefix, md_content)
                logger.info(
                    "Completed PDF process for key='%s'. Final MD length=%d",
                    key,
                    len(final_markdown),
                )
                return final_markdown
            finally:
                # Free GPU memory whether or not the conversion succeeded.
                self.cleanup_gpu()

    def main() -> None:
        """Convert a single local PDF and print the Markdown.

        Usage: ``python <this file> [pdf_path] [key]``. Both arguments are
        optional and default to the values that were previously hard-coded.
        """
        file_path = sys.argv[1] if len(sys.argv) > 1 else "./output1.pdf"
        key = sys.argv[2] if len(sys.argv) > 2 else "1234323"

        processor = LocalFileProcessor()
        markdown_result = processor.process(file_path, key=key)
        print("Single file Markdown:\n", markdown_result)

    main()
|
|