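"""Export BAAI/bge-m3 to ONNX with all three of its heads (dense, sparse, ColBERT)
and optimize the resulting graph with ONNX Runtime via optimum.

Example invocation (script name and paths are illustrative):

    python export_bge_m3_onnx.py --output ./bge-m3-onnx --device cuda --optimize O4
"""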
from __future__ import annotations

import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict

import torch
from huggingface_hub import snapshot_download
from optimum.exporters.onnx import export
from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
from optimum.onnxruntime import ORTModelForCustomTasks, ORTOptimizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig
from torch import Tensor
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig


class BGEM3InferenceModel(PreTrainedModel):
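    """Hugging Face-compatible wrapper around the bge-m3 checkpoint.

    Loads the XLM-RoBERTa backbone via AutoModel together with the two extra
    heads shipped in the checkpoint (sparse_linear.pt and colbert_linear.pt),
    so one forward pass yields dense, sparse, and ColBERT representations.
    """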
    config_class = XLMRobertaConfig
    base_model_prefix = "BGEM3InferenceModel"
    model_tags = ["BAAI/bge-m3"]

    def __init__(self, model_name: str = "BAAI/bge-m3"):
        super().__init__(PretrainedConfig())

        model_name = snapshot_download(repo_id=model_name)

        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

        self.sparse_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=1,
        )
        sparse_state_dict = torch.load(os.path.join(model_name, "sparse_linear.pt"), map_location="cpu")
        self.sparse_linear.load_state_dict(sparse_state_dict)

        self.colbert_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=self.model.config.hidden_size,
        )
        colbert_state_dict = torch.load(os.path.join(model_name, "colbert_linear.pt"), map_location="cpu")
        self.colbert_linear.load_state_dict(colbert_state_dict)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
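        """CLS pooling: the first token's hidden state is the dense embedding."""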
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
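        """Per-token sparse weights: ReLU over a 1-d linear projection of the hidden states."""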
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(self, last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
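        """Project per-token hidden states (CLS excluded) and zero out padding positions."""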
        with torch.no_grad():
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
        return colbert_vecs * attention_mask[:, 1:][:, :, None].float()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
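        """Run the backbone once and derive dense, sparse, and ColBERT outputs from its last hidden state."""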
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        output = {}
        dense_vecs = self.dense_embedding(last_hidden_state)
        output["dense_vecs"] = torch.nn.functional.normalize(dense_vecs, dim=-1)

        sparse_vecs = self.sparse_embedding(last_hidden_state)
        output["sparse_vecs"] = sparse_vecs

        colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
        output["colbert_vecs"] = torch.nn.functional.normalize(colbert_vecs, dim=-1)

        return output


class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
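        """Declare the three output tensors and their dynamic axes for the ONNX graph."""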
        return OrderedDict(
            {
                "dense_vecs": {0: "batch_size", 1: "embedding"},
                "sparse_vecs": {0: "batch_size", 1: "token", 2: "weight"},
                "colbert_vecs": {0: "batch_size", 1: "token", 2: "embedding"},
            }
        )


def main(output: str, device: str = "cuda", optimize: str = "O4"):
    # Load the model and save its config/weights to the output directory
    model = BGEM3InferenceModel()
    model.save_pretrained(output)

    # Build the ONNX export config
    bgem3_onnx_config = BGEM3OnnxConfig(model.config)

    # Export to ONNX
    export(
        model,
        output=Path(output) / "model.onnx",
        config=bgem3_onnx_config,
        opset=bgem3_onnx_config.DEFAULT_ONNX_OPSET,
        device=device,
    )

    optimizer = ORTOptimizer.from_pretrained(output, file_names=["model.onnx"])
    optimization_config = AutoOptimizationConfig.with_optimization_level(optimization_level=optimize)
    optimization_config.disable_shape_inference = True
    if optimize == "O4":
        # O4 targets GPU inference: GPU-specific fusions, fp16 weights, and the
        # highest ONNX Runtime graph optimization level (99 = ORT_ENABLE_ALL).
        optimization_config.optimize_for_gpu = True
        optimization_config.fp16 = True
        optimization_config.optimization_level = 99
    optimizer.optimize(save_dir=output, optimization_config=optimization_config, file_suffix="")

    # Load the optimized model once as a smoke test under the target execution provider.
    ORTModelForCustomTasks.from_pretrained(
        output,
        provider="CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider",
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--optimize", type=str, choices=["O1", "O2", "O3", "O4"], default="O4")
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--push_to_hub_repo_id", type=str, default="JeremyHibiki/bge-m3-onnx")
    args = parser.parse_args()
    main(args.output, args.device, args.optimize)
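
# A minimal sketch of consuming the exported model with onnxruntime directly
# (paths and the chosen provider are illustrative; the tokenizer comes from the
# original checkpoint):
#
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
#   session = ort.InferenceSession("bge-m3-onnx/model.onnx", providers=["CPUExecutionProvider"])
#   feed = tokenizer("hello world", return_tensors="np")
#   # Outputs follow the order declared in BGEM3OnnxConfig.outputs.
#   dense_vecs, sparse_vecs, colbert_vecs = session.run(
#       None, {"input_ids": feed["input_ids"], "attention_mask": feed["attention_mask"]}
#   )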