# MinerU / model_loader.py
# (Hugging Face file-page residue: uploaded by kitjesen in "Upload 6 files", commit 8afa9a1, verified, 1.08 kB)
import torch
from transformers import AutoModel, AutoTokenizer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
import os
class MinerUModelLoader:
    """Load the layout, formula and table models stored under one directory."""

    @staticmethod
    def load_models(base_path, *, device=None):
        """Load every model found under ``base_path`` and return them by name.

        Args:
            base_path: Root directory that contains the ``models/`` tree
                (``models/Layout``, ``models/MFD``, ``models/MFR/UniMERNet``
                and ``models/TabRec/StructEqTable``).
            device: Optional target device (e.g. ``"cpu"``, ``"cuda:0"`` or a
                ``torch.device``). When ``None`` (the default) nothing is
                overridden and each library places its model where it would
                by default.

        Returns:
            A dict with the keys ``"layout"``, ``"formula_detector"``,
            ``"formula_recognizer"`` and ``"table_recognizer"``.
        """
        models = {}

        # Layout model: a detectron2 predictor built from the on-disk config
        # and weights.
        cfg = get_cfg()
        cfg.merge_from_file(os.path.join(base_path, "models/Layout/config.json"))
        cfg.MODEL.WEIGHTS = os.path.join(base_path, "models/Layout/model_final.pth")
        if device is not None:
            # detectron2 keeps the device as a string config value; str() also
            # lets callers pass a torch.device.
            cfg.MODEL.DEVICE = str(device)
        models["layout"] = DefaultPredictor(cfg)

        # Formula detection model. Whatever object the checkpoint holds is
        # stored as-is.
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only point base_path at trusted checkpoints.
        # map_location=None is torch.load's default, so omitting `device`
        # leaves the call unchanged.
        models["formula_detector"] = torch.load(
            os.path.join(base_path, "models/MFD/weights.pt"),
            map_location=device,
        )

        # Formula recognition and table recognition models, both loaded from
        # local Hugging Face-style model directories.
        hf_model_dirs = (
            ("formula_recognizer", "models/MFR/UniMERNet"),
            ("table_recognizer", "models/TabRec/StructEqTable"),
        )
        for key, relative_dir in hf_model_dirs:
            model = AutoModel.from_pretrained(os.path.join(base_path, relative_dir))
            if device is not None:
                model = model.to(device)
            models[key] = model

        return models