#!/usr/bin/env python3
import os
import json
import logging
import gc
import fitz
import requests
import torch
import boto3
from io import BytesIO
from typing import Dict, List, Any
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("topic_processor.log"),
    ],
)
logger = logging.getLogger(__name__)
class s3Writer:
    """Thin wrapper around boto3 for uploading in-memory byte payloads to S3."""

    def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str):
        self.bucket = bucket
        self.client = boto3.client(
            "s3",
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
        )

    def write(self, path: str, data: bytes) -> None:
        try:
            file_obj = BytesIO(data)
            self.client.upload_fileobj(file_obj, self.bucket, path)
            logger.info("Uploaded to S3: %s", path)
        except Exception as e:
            logger.error("Failed to upload to S3: %s", e)
            raise
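

# A minimal usage sketch for s3Writer. The bucket name is the one used by
# TopicExtractionProcessor below; credentials and endpoint are assumed to
# come from the environment:
#
#   writer = s3Writer(os.getenv("S3_ACCESS_KEY"), os.getenv("S3_SECRET_KEY"),
#                     "quextro-resources", os.getenv("S3_ENDPOINT"))
#   writer.write("topic-extraction/test/figure1.jpg", b"<image bytes>")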
class S3ImageWriter:
    """Uploads extracted images to S3 and rewrites their markdown links."""

    def __init__(self, s3_writer: s3Writer, base_path: str, gemini_api_key: str):
        self.s3_writer = s3_writer
        self.base_path = base_path if base_path.endswith("/") else base_path + "/"
        self.gemini_api_key = gemini_api_key
        self.descriptions = {}

    def write(self, path: str, data: bytes) -> None:
        full_path = f"{self.base_path}{os.path.basename(path)}"
        self.s3_writer.write(full_path, data)
        self.descriptions[path] = {
            "data": data,
            "s3_path": full_path,
        }

    def post_process(self, key: str, md_content: str) -> str:
        # Rewrite each local image reference so it points at the uploaded S3
        # copy. The markdown image literals were lost in extraction; the
        # standard ![](<path>) form is assumed here.
        for path, info in self.descriptions.items():
            s3_path = info.get("s3_path")
            md_content = md_content.replace(f"![]({path})", f"![]({s3_path})")
        return md_content
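

# Sketch of the rewrite post_process performs, assuming the pipeline emits
# relative image links of the form ![](images/<name>.jpg):
#
#   writer = S3ImageWriter(s3, "/topic-extraction/job1/", api_key)
#   writer.write("images/fig1.jpg", b"...")  # uploads /topic-extraction/job1/fig1.jpg
#   writer.post_process("job1", "![](images/fig1.jpg)")
#   # -> "![](/topic-extraction/job1/fig1.jpg)"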
def delete_non_heading_text(md_content: str) -> str:
    """Keep only heading lines (and markdown image lines) from the output."""
    filtered_lines = []
    for line in md_content.splitlines():
        stripped = line.lstrip()
        # The second literal was garbled in extraction; '![' (keep image
        # lines) is the assumed intent, matching the link rewrite above.
        if stripped.startswith("#") or stripped.startswith("!["):
            filtered_lines.append(line)
    return "\n".join(filtered_lines)
class TopicExtractionProcessor:
    def __init__(self, gemini_api_key: str = None):
        try:
            self.s3_writer = s3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket="quextro-resources",
                endpoint_url=os.getenv("S3_ENDPOINT"),
            )
            config_path = "/home/user/magic-pdf.json"
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    config = json.load(f)
                self.layout_model = config.get("layout-config", {}).get("model", "doclayout_yolo")
                self.formula_enable = config.get("formula-config", {}).get("enable", True)
            else:
                self.layout_model = "doclayout_yolo"
                self.formula_enable = True
            self.table_enable = False
            self.language = "en"
            # Read the credential from the environment; never hard-code an
            # API key as a fallback.
            self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
            logger.info("TopicExtractionProcessor initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize TopicExtractionProcessor: %s", e)
            raise
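
    # Shape of /home/user/magic-pdf.json assumed by the reads above:
    #   {"layout-config": {"model": "doclayout_yolo"},
    #    "formula-config": {"enable": true}}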
    def cleanup_gpu(self):
        try:
            gc.collect()
            # empty_cache() requires a CUDA build; guard for CPU-only hosts.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            logger.error("Error during GPU cleanup: %s", e)
    def process(self, input_file: Dict[str, Any]) -> str:
        # Resolve the key before the try block so the error log below can
        # always reference it.
        key = input_file.get("key", "")
        try:
            url = input_file.get("url", "")
            pages = input_file.get("page", [])
            if not url or not pages:
                raise ValueError("Missing required 'url' or 'page' in input file")
            if url.startswith(("http://", "https://")):
                response = requests.get(url, timeout=60)
                response.raise_for_status()
                pdf_bytes = response.content
            else:
                with open(url, "rb") as f:
                    pdf_bytes = f.read()
            pages = self.parse_page_range(pages)
            logger.info("Processing %s with pages %s", key, pages)
            subset_pdf = self.create_subset_pdf(pdf_bytes, pages)
            logger.info("Created subset PDF with size: %d bytes", len(subset_pdf))
            dataset = PymuDocDataset(subset_pdf)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_model,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable,
            )
            # Write images under a per-job prefix and use the same prefix for
            # the markdown image links.
            base_path = f"/topic-extraction/{key}/"
            writer = S3ImageWriter(self.s3_writer, base_path, self.gemini_api_key)
            pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
            md_content = pipe_result.get_markdown(base_path)
            post_processed = writer.post_process(key, md_content)
            # Remove non-heading text from the markdown output.
            final_markdown = delete_non_heading_text(post_processed)
            return final_markdown
        except Exception as e:
            logger.error("Processing failed for %s: %s", key, e)
            raise
        finally:
            self.cleanup_gpu()
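
    # Expected input shape for process(), mirroring the local test in main():
    #   {"key": "<job id>",
    #    "url": "<http(s) URL or local file path>",
    #    "page": [15, 16, 17]}          # 1-indexed; "15, 16, 17" also accepted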
    def create_subset_pdf(self, pdf_bytes: bytes, page_indices: List[int]) -> bytes:
        """Create a PDF subset from specified pages"""
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        new_doc = fitz.open()
        try:
            for p in sorted(set(page_indices)):
                if 0 <= p < doc.page_count:
                    new_doc.insert_pdf(doc, from_page=p, to_page=p)
                else:
                    raise ValueError(f"Page index {p} out of range (0-{doc.page_count - 1})")
            return new_doc.tobytes()
        finally:
            new_doc.close()
            doc.close()
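
    # Example: create_subset_pdf(pdf_bytes, [0, 2]) returns a new PDF
    # containing only the first and third pages of the original document.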
    def parse_page_range(self, page_field) -> List[int]:
        """Parse page range from input (1-indexed to 0-indexed)"""
        if isinstance(page_field, list):
            return [int(p) - 1 for p in page_field]
        if isinstance(page_field, str):
            parts = [p.strip() for p in page_field.split(",")]
            return [int(p) - 1 for p in parts]
        raise ValueError("Invalid page field type")
def main():
    """Local test execution without RabbitMQ"""
    test_input = {
        "key": "local_test",
        "url": "/home/user/app/input_output/a-level-pearson-mathematics-specification.pdf",  # Local PDF path
        "page": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42],
    }
    processor = TopicExtractionProcessor()
    try:
        logger.info("Starting test processing.")
        result = processor.process(test_input)
        logger.info("Processing completed successfully")
        print("Markdown:\n", result)
    except Exception as e:
        logger.error("Test failed: %s", e)


if __name__ == "__main__":
    main()