#!/usr/bin/env python3
import gc
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict, List

import boto3
import fitz
import requests
import torch

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
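
# Pipeline overview (as wired below): download or read a PDF, build a page
# subset with PyMuPDF, run MinerU's doc_analyze + pipe_ocr_mode over it,
# upload extracted images to S3, and reduce the markdown to headings/images.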
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('topic_processor.log')
    ]
)
logger = logging.getLogger(__name__)
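
# Required environment variables (read at runtime): S3_ACCESS_KEY,
# S3_SECRET_KEY, S3_ENDPOINT, and GEMINI_API_KEY.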

class s3Writer:
    """Thin wrapper around a boto3 S3 client that uploads raw bytes."""

    def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str):
        self.bucket = bucket
        self.client = boto3.client(
            's3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url
        )

    def write(self, path: str, data: bytes) -> None:
        try:
            file_obj = BytesIO(data)
            self.client.upload_fileobj(file_obj, self.bucket, path)
            logger.info(f"Uploaded to S3: {path}")
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise

class S3ImageWriter:
    """Uploads extracted images to S3 and remembers their S3 paths so that
    post_process() can rewrite the markdown image links to point at them."""

    def __init__(self, s3_writer: s3Writer, base_path: str, gemini_api_key: str):
        self.s3_writer = s3_writer
        self.base_path = base_path if base_path.endswith("/") else base_path + "/"
        self.gemini_api_key = gemini_api_key
        self.descriptions = {}

    def write(self, path: str, data: bytes) -> None:
        full_path = f"{self.base_path}{os.path.basename(path)}"
        self.s3_writer.write(full_path, data)
        self.descriptions[path] = {
            "data": data,
            "s3_path": full_path
        }

    def post_process(self, key: str, md_content: str) -> str:
        # Rewrite each local image reference ![](<key><path>) to its S3 path.
        for path, info in self.descriptions.items():
            s3_path = info.get("s3_path")
            md_content = md_content.replace(f"![]({key}{path})", f"![]({s3_path})")
        return md_content

def delete_non_heading_text(md_content: str) -> str:
    """Keep only markdown heading lines and image references."""
    filtered_lines = []
    for line in md_content.splitlines():
        stripped = line.lstrip()
        if stripped.startswith('#') or stripped.startswith('![]('):
            filtered_lines.append(line)
    return "\n".join(filtered_lines)

class TopicExtractionProcessor:
    def __init__(self, gemini_api_key: str = None):
        try:
            self.s3_writer = s3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket="quextro-resources",
                endpoint_url=os.getenv("S3_ENDPOINT")
            )
            config_path = "/home/user/magic-pdf.json"
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    config = json.load(f)
                self.layout_model = config.get("layout-config", {}).get("model", "doclayout_yolo")
                self.formula_enable = config.get("formula-config", {}).get("enable", True)
            else:
                self.layout_model = "doclayout_yolo"
                self.formula_enable = True
            self.table_enable = False
            self.language = "en"
            # The Gemini API key comes from the argument or the environment;
            # never hardcode a fallback secret here.
            self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
            logger.info("TopicExtractionProcessor initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize TopicExtractionProcessor: %s", str(e))
            raise
    def cleanup_gpu(self):
        try:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            logger.error("Error during GPU cleanup: %s", e)
    def process(self, input_file: Dict[str, Any]) -> str:
        # Resolve the key first so the except-block logging below can use it.
        key = input_file.get("key", "")
        try:
            url = input_file.get("url", "")
            pages = input_file.get("page", [])
            if not url or not pages:
                raise ValueError("Missing required 'url' or 'page' in input file")
            if url.startswith(("http://", "https://")):
                response = requests.get(url, timeout=60)
                response.raise_for_status()
                pdf_bytes = response.content
            else:
                with open(url, "rb") as f:
                    pdf_bytes = f.read()
            pages = self.parse_page_range(pages)
            logger.info("Processing %s with pages %s", key, pages)
            subset_pdf = self.create_subset_pdf(pdf_bytes, pages)
            logger.info(f"Created subset PDF with size: {len(subset_pdf)} bytes")
            dataset = PymuDocDataset(subset_pdf)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_model,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable
            )
base_path = f"/topic-extraction/{key}/"
writer = S3ImageWriter(self.s3_writer, "/topic-extraction/", self.gemini_api_key)
md_prefix = "/topic-extraction/"
pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
md_content = pipe_result.get_markdown(md_prefix)
post_processed = writer.post_process(md_prefix, md_content)
#remove non-heading text from the markdown output
final_markdown = delete_non_heading_text(post_processed)
return final_markdown
except Exception as e:
logger.error("Processing failed for %s: %s", key, str(e))
raise
finally:
self.cleanup_gpu()
    def create_subset_pdf(self, pdf_bytes: bytes, page_indices: List[int]) -> bytes:
        """Create a PDF subset from specified pages"""
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        new_doc = fitz.open()
        try:
            for p in sorted(set(page_indices)):
                if 0 <= p < doc.page_count:
                    new_doc.insert_pdf(doc, from_page=p, to_page=p)
                else:
                    raise ValueError(f"Page index {p} out of range (0-{doc.page_count-1})")
            return new_doc.tobytes()
        finally:
            new_doc.close()
            doc.close()
    def parse_page_range(self, page_field) -> List[int]:
        """Parse 1-indexed page numbers (list or comma-separated string)
        into 0-indexed page indices."""
        if isinstance(page_field, list):
            return [int(p) - 1 for p in page_field]
        if isinstance(page_field, str):
            parts = [p.strip() for p in page_field.split(',')]
            return [int(p) - 1 for p in parts]
        raise ValueError("Invalid page field type")

def main():
    """Local test execution without RabbitMQ"""
    test_input = {
        "key": "local_test",
        "url": "/home/user/app/input_output/a-level-pearson-mathematics-specification.pdf",  # Local PDF path
        "page": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42]
    }
    processor = TopicExtractionProcessor()
    try:
        logger.info("Starting test processing.")
        result = processor.process(test_input)
        logger.info("Processing completed successfully")
        print("Markdown:\n", result)
    except Exception as e:
        logger.error("Test failed: %s", str(e))


if __name__ == "__main__":
    main()
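
# Illustrative output shape (hypothetical values, not taken from the PDF):
# the returned markdown contains only heading lines and image links, e.g.
#   ## Topic heading
#   ![](/topic-extraction/local_test/some_figure.jpg)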