|
|
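"""Topic extraction pipeline: download a PDF, build a subset of the requested pages,
run magic_pdf OCR/layout analysis, upload extracted images to S3, and return
Markdown reduced to heading lines."""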
|
import os |
|
import json |
|
import logging |
|
import gc |
|
import fitz  # PyMuPDF
|
import requests |
|
import torch |
|
import boto3 |
|
from io import BytesIO |
|
from typing import Dict, List, Any |
|
|
|
from magic_pdf.data.dataset import PymuDocDataset |
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", |
|
handlers=[ |
|
logging.StreamHandler(), |
|
logging.FileHandler('topic_processor.log') |
|
] |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
class s3Writer: |
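    """Thin wrapper around a boto3 S3 client that uploads byte payloads to a fixed bucket."""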
|
def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str): |
|
self.bucket = bucket |
|
self.client = boto3.client( |
|
's3', |
|
aws_access_key_id=ak, |
|
aws_secret_access_key=sk, |
|
endpoint_url=endpoint_url |
|
) |
|
|
|
def write(self, path: str, data: bytes) -> None: |
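        """Upload raw bytes to the configured bucket under the given key."""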
|
try: |
|
file_obj = BytesIO(data) |
|
self.client.upload_fileobj(file_obj, self.bucket, path) |
|
logger.info(f"Uploaded to S3: {path}") |
|
except Exception as e: |
|
logger.error(f"Failed to upload to S3: {str(e)}") |
|
raise |
|
|
|
class S3ImageWriter: |
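    """Image writer passed to the magic_pdf pipeline: uploads extracted images to S3 and records where each one was stored."""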
|
def __init__(self, s3_writer: s3Writer, base_path: str, gemini_api_key: str): |
|
self.s3_writer = s3_writer |
|
self.base_path = base_path if base_path.endswith("/") else base_path + "/" |
|
self.gemini_api_key = gemini_api_key |
|
self.descriptions = {} |
|
|
|
def write(self, path: str, data: bytes) -> None: |
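        """Upload image bytes under the base path and remember the original-to-S3 path mapping."""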
|
full_path = f"{self.base_path}{os.path.basename(path)}" |
|
self.s3_writer.write(full_path, data) |
|
self.descriptions[path] = { |
|
"data": data, |
|
"s3_path": full_path |
|
} |
|
|
|
def post_process(self, key: str, md_content: str) -> str: |
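        """Rewrite image references in the Markdown so they point at the uploaded S3 objects."""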
|
for path, info in self.descriptions.items(): |
|
s3_path = info.get("s3_path") |
|
            # Rewrite references to the local image path so they point at the uploaded S3 object.
            md_content = md_content.replace(f"![]({path})", f"![]({s3_path})")
|
return md_content |
|
|
|
def delete_non_heading_text(md_content: str) -> str: |
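    """Strip body text from the Markdown, keeping heading lines and image references."""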
|
filtered_lines = [] |
|
for line in md_content.splitlines(): |
|
stripped = line.lstrip() |
|
        if stripped.startswith('#') or stripped.startswith('!['):  # keep heading lines and image references
|
filtered_lines.append(line) |
|
return "\n".join(filtered_lines) |
|
|
|
class TopicExtractionProcessor: |
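    """Runs the magic_pdf OCR/layout pipeline on selected PDF pages to extract topic headings."""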
|
def __init__(self, gemini_api_key: str = None): |
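        """Configure the S3 writer and load layout/formula settings from magic-pdf.json when available."""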
|
try: |
|
self.s3_writer = s3Writer( |
|
ak=os.getenv("S3_ACCESS_KEY"), |
|
sk=os.getenv("S3_SECRET_KEY"), |
|
bucket="quextro-resources", |
|
endpoint_url=os.getenv("S3_ENDPOINT") |
|
) |
|
|
|
config_path = "/home/user/magic-pdf.json" |
|
if os.path.exists(config_path): |
|
with open(config_path, "r") as f: |
|
config = json.load(f) |
|
self.layout_model = config.get("layout-config", {}).get("model", "doclayout_yolo") |
|
self.formula_enable = config.get("formula-config", {}).get("enable", True) |
|
else: |
|
self.layout_model = "doclayout_yolo" |
|
self.formula_enable = True |
|
|
|
self.table_enable = False |
|
self.language = "en" |
|
            self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
|
|
|
logger.info("TopicExtractionProcessor initialized successfully") |
|
except Exception as e: |
|
logger.error("Failed to initialize TopicExtractionProcessor: %s", str(e)) |
|
raise |
|
|
|
def cleanup_gpu(self): |
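        """Release cached GPU memory after a processing run."""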
|
try: |
|
gc.collect() |
|
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
|
logger.info("GPU memory cleaned up.") |
|
except Exception as e: |
|
logger.error("Error during GPU cleanup: %s", e) |
|
|
|
def process(self, input_file: Dict[str, Any]) -> str: |
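        """Fetch the PDF (URL or local path), analyse the requested pages, and return heading-only Markdown."""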
|
try: |
|
key = input_file.get("key", "") |
|
url = input_file.get("url", "") |
|
pages = input_file.get("page", []) |
|
|
|
if not url or not pages: |
|
raise ValueError("Missing required 'url' or 'page' in input file") |
|
|
|
if url.startswith(("http://", "https://")): |
|
                response = requests.get(url, timeout=60)  # avoid hanging indefinitely on slow downloads
|
response.raise_for_status() |
|
pdf_bytes = response.content |
|
else: |
|
with open(url, "rb") as f: |
|
pdf_bytes = f.read() |
|
|
|
pages = self.parse_page_range(pages) |
|
logger.info("Processing %s with pages %s", key, pages) |
|
|
|
subset_pdf = self.create_subset_pdf(pdf_bytes, pages) |
|
logger.info(f"Created subset PDF with size: {len(subset_pdf)} bytes") |
|
|
|
|
|
dataset = PymuDocDataset(subset_pdf) |
|
inference = doc_analyze( |
|
dataset, |
|
ocr=True, |
|
lang=self.language, |
|
layout_model=self.layout_model, |
|
formula_enable=self.formula_enable, |
|
table_enable=self.table_enable |
|
) |
|
|
|
            # Write images under a per-document prefix and render the Markdown relative to it.
            base_path = f"/topic-extraction/{key}/"
            writer = S3ImageWriter(self.s3_writer, base_path, self.gemini_api_key)
            pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
            md_content = pipe_result.get_markdown(base_path)
            post_processed = writer.post_process(key, md_content)
|
|
|
|
|
final_markdown = delete_non_heading_text(post_processed) |
|
|
|
return final_markdown |
|
|
|
except Exception as e: |
|
logger.error("Processing failed for %s: %s", key, str(e)) |
|
raise |
|
finally: |
|
self.cleanup_gpu() |
|
|
|
def create_subset_pdf(self, pdf_bytes: bytes, page_indices: List[int]) -> bytes: |
|
"""Create a PDF subset from specified pages""" |
|
doc = fitz.open(stream=pdf_bytes, filetype="pdf") |
|
new_doc = fitz.open() |
|
|
|
try: |
|
for p in sorted(set(page_indices)): |
|
if 0 <= p < doc.page_count: |
|
new_doc.insert_pdf(doc, from_page=p, to_page=p) |
|
else: |
|
raise ValueError(f"Page index {p} out of range (0-{doc.page_count-1})") |
|
return new_doc.tobytes() |
|
finally: |
|
new_doc.close() |
|
doc.close() |
|
|
|
def parse_page_range(self, page_field) -> List[int]: |
|
"""Parse page range from input (1-indexed to 0-indexed)""" |
|
if isinstance(page_field, list): |
|
return [int(p) - 1 for p in page_field] |
|
if isinstance(page_field, str): |
|
parts = [p.strip() for p in page_field.split(',')] |
|
return [int(p) - 1 for p in parts] |
|
raise ValueError("Invalid page field type") |
|
|
|
def main(): |
|
"""Local test execution without RabbitMQ""" |
|
test_input = { |
|
"key": "local_test", |
|
"url": "/home/user/app/input_output/a-level-pearson-mathematics-specification.pdf", |
|
"page":[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42] |
|
} |
|
|
|
processor = TopicExtractionProcessor() |
|
|
|
try: |
|
logger.info("Starting test processing.") |
|
result = processor.process(test_input) |
|
logger.info("Processing completed successfully") |
|
print("Markdown:\n", result) |
|
except Exception as e: |
|
logger.error("Test failed: %s", str(e)) |
|
|
|
if __name__ == "__main__": |
|
main() |