import os
from typing import Union, List

import fitz
import torch
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from transformers import AutoModel
from transformers import Pipeline


class MinerUPipeline(Pipeline):
    """PDF-to-Markdown pipeline combining layout, formula and table models.

    The pipeline rasterizes each PDF page, runs a layout model on it, routes
    each detected region to a type-specific handler and finally joins the
    per-page results into one Markdown string.
    """

    def __init__(self, model_path, **kwargs):
        """Load all sub-models from ``model_path``.

        Args:
            model_path: Root directory that contains the model sub-folders
                referenced below.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(**kwargs)

        # Layout model (detectron2 predictor built from config + weights).
        cfg = get_cfg()
        cfg.merge_from_file(os.path.join(model_path, "models/Layout/config.json"))
        cfg.MODEL.WEIGHTS = os.path.join(model_path, "models/Layout/model_final.pth")
        self.layout_model = DefaultPredictor(cfg)

        # Remaining models.
        # SECURITY: torch.load unpickles arbitrary objects; only point
        # ``model_path`` at weights from a trusted source.
        self.formula_detector = torch.load(
            os.path.join(model_path, "models/MFD/weights.pt")
        )
        self.formula_recognizer = AutoModel.from_pretrained(
            os.path.join(model_path, "models/MFR/UniMERNet")
        )
        # NOTE(review): this path has no "models/" prefix, unlike the three
        # paths above -- confirm against the on-disk layout.
        self.table_recognizer = AutoModel.from_pretrained(
            os.path.join(model_path, "TabRec/StructEqTable")
        )

    def _sanitize_parameters(self, **kwargs):
        """Return the (preprocess, forward, postprocess) keyword dictionaries.

        ``Pipeline`` declares this method abstract, so it must exist for the
        class to be instantiable. This pipeline exposes no call-time
        options, hence three empty dictionaries.
        """
        return {}, {}, {}

    def preprocess(self, pdf_path) -> List[torch.Tensor]:
        """Rasterize every page of a PDF into a float tensor.

        Args:
            pdf_path: Path of the PDF file to open.

        Returns:
            One tensor per page, laid out as (components, height, width).
        """
        pages = []
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf_path) as doc:
            for page in doc:
                pix = page.get_pixmap()
                # ``pix.samples`` is a flat, read-only bytes buffer of
                # height * width * n values; copy it into a writable buffer
                # and restore the 3-D shape before moving channels first.
                flat = torch.frombuffer(bytearray(pix.samples), dtype=torch.uint8)
                img = flat.reshape(pix.height, pix.width, pix.n)
                pages.append(img.permute(2, 0, 1).float())
        return pages

    def _forward(self, pages):
        """Run layout analysis on each page and process the found regions.

        Args:
            pages: Iterable of page inputs as produced by ``preprocess``.

        Returns:
            One dict per page with the keys ``"text"``, ``"formulas"`` and
            ``"tables"``, each holding a list of processed regions.
        """
        results = []
        for page in pages:
            # 1. Layout analysis.
            # NOTE(review): the loop below assumes the layout model yields
            # objects exposing ``.type`` (and ``.image`` in the handlers) --
            # confirm against the predictor's actual output.
            layout = self.layout_model(page)

            # 2. Route each region to its handler according to its type.
            text_regions = []
            formula_regions = []
            table_regions = []
            for region in layout:
                if region.type == "text":
                    # NOTE(review): ``_process_text`` is not defined in the
                    # visible source -- confirm it exists elsewhere.
                    text_regions.append(self._process_text(region))
                elif region.type == "formula":
                    formula_regions.append(self._process_formula(region))
                elif region.type == "table":
                    table_regions.append(self._process_table(region))

            results.append({
                "text": text_regions,
                "formulas": formula_regions,
                "tables": table_regions,
            })
        return results

    def _process_formula(self, region):
        """Detect formulas in the region image, then recognize them."""
        detected = self.formula_detector(region.image)
        return self.formula_recognizer(detected)

    def _process_table(self, region):
        """Run the table recognizer on the region image."""
        return self.table_recognizer(region.image)

    def postprocess(self, model_outputs):
        """Convert per-page results into a single Markdown string.

        For every page, text entries are emitted first, then formulas
        wrapped in ``$$ ... $$``, then tables via their ``to_markdown()``
        method. All pieces are joined with blank lines.

        Args:
            model_outputs: List of per-page dicts as returned by ``_forward``.

        Returns:
            The assembled Markdown document.
        """
        markdown = []
        for page in model_outputs:
            # Combine text, formulas and tables in that order.
            markdown.extend(page["text"])
            markdown.extend([f"$${formula}$$" for formula in page["formulas"]])
            markdown.extend([table.to_markdown() for table in page["tables"]])
        return "\n\n".join(markdown)