|
from __future__ import annotations |
|
|
|
import os |
|
from collections import OrderedDict |
|
from pathlib import Path |
|
from typing import Dict |
|
|
|
import torch |
|
from huggingface_hub import snapshot_download |
|
from optimum.exporters.onnx import export |
|
from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig |
|
from optimum.onnxruntime import ORTModelForCustomTasks, ORTOptimizer |
|
from optimum.onnxruntime.configuration import AutoOptimizationConfig |
|
from torch import Tensor |
|
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig |
|
|
|
|
|
class BGEM3InferenceModel(PreTrainedModel):
    """Single-graph wrapper producing BGE-M3 dense, sparse and ColBERT outputs.

    Loads the encoder with ``AutoModel`` from a Hub snapshot, plus the two extra
    linear heads stored in that snapshot as ``sparse_linear.pt`` and
    ``colbert_linear.pt``. ``forward`` runs all three heads, so the whole pipeline
    can be exported as one ONNX model.
    """

    config_class = XLMRobertaConfig

    base_model_prefix = "BGEM3InferenceModel"

    model_tags = ["BAAI/bge-m3"]

    def __init__(self, model_name: str = "BAAI/bge-m3"):
        """Download *model_name* and build the encoder and both projection heads.

        Args:
            model_name: Hub repository id passed to ``snapshot_download``.
        """
        # PreTrainedModel.__init__ requires a config object; the real one is assigned below.
        super().__init__(PretrainedConfig())

        # Local directory of the downloaded snapshot (kept separate from the repo id).
        model_dir = snapshot_download(repo_id=model_name)

        self.config = AutoConfig.from_pretrained(model_dir)

        self.model = AutoModel.from_pretrained(model_dir)

        hidden_size = self.model.config.hidden_size

        # Sparse head: one scalar weight per token.
        self.sparse_linear = torch.nn.Linear(
            in_features=hidden_size,
            out_features=1,
        )
        self.sparse_linear.load_state_dict(self._load_head_state_dict(model_dir, "sparse_linear.pt"))

        # ColBERT head: hidden_size -> hidden_size per-token projection.
        self.colbert_linear = torch.nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
        )
        self.colbert_linear.load_state_dict(self._load_head_state_dict(model_dir, "colbert_linear.pt"))

    @staticmethod
    def _load_head_state_dict(model_dir: str, file_name: str) -> Dict[str, Tensor]:
        """Load the state dict stored in ``model_dir/file_name`` onto the CPU.

        ``weights_only=True`` restricts unpickling to tensors and plain containers,
        so a tampered checkpoint fetched from the Hub cannot execute code while
        loading. This requires torch >= 1.13.
        """
        return torch.load(os.path.join(model_dir, file_name), map_location="cpu", weights_only=True)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the hidden state at position 0 of dim 1 (the first token) for each sequence."""
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return non-negative per-token weights: ReLU of the 1-unit ``sparse_linear`` projection.

        Assumes *last_hidden_state* is (batch, seq_len, hidden); the result is then
        (batch, seq_len, 1).
        """
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(self, last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
        """Project every token except the first through ``colbert_linear``, zeroing masked positions.

        The mask is sliced with the same ``[:, 1:]`` offset as the hidden states so
        the two stay aligned. It is then broadcast over the feature dim and cast to float.
        """
        with torch.no_grad():
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
            return colbert_vecs * attention_mask[:, 1:][:, :, None].float()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Encode a batch and return all three BGE-M3 representations.

        Returns:
            A dict with keys:

            - ``"dense_vecs"``: first-token vectors, L2-normalised over the last dim.
            - ``"sparse_vecs"``: ReLU token weights (not normalised).
            - ``"colbert_vecs"``: masked per-token vectors excluding the first token,
              L2-normalised over the last dim.
        """
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

            output = {}

            dense_vecs = self.dense_embedding(last_hidden_state)
            output["dense_vecs"] = torch.nn.functional.normalize(dense_vecs, dim=-1)

            sparse_vecs = self.sparse_embedding(last_hidden_state)
            output["sparse_vecs"] = sparse_vecs

            colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
            output["colbert_vecs"] = torch.nn.functional.normalize(colbert_vecs, dim=-1)

        return output
|
|
|
|
|
class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
    """ONNX export configuration for ``BGEM3InferenceModel``.

    Inputs are inherited unchanged from ``XLMRobertaOnnxConfig``. This subclass only
    declares the three named outputs returned by the model's ``forward`` and their
    dynamic axes.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Map each output name to its {axis index: symbolic dimension name} table.

        The axis names become ONNX dimension variables, and a shared name declares
        equal sizes across tensors. ``sparse_vecs`` covers every token, but
        ``colbert_vecs`` drops the first token, so their token axes use different
        symbols.
        """
        return OrderedDict(
            {
                "dense_vecs": {0: "batch_size", 1: "embedding"},
                "sparse_vecs": {0: "batch_size", 1: "token", 2: "weight"},
                # One token shorter than sparse_vecs, so it must not reuse "token".
                "colbert_vecs": {0: "batch_size", 1: "colbert_token", 2: "embedding"},
            }
        )
|
|
|
|
|
def main(output: str, device: str = "cuda", optimize: str = "O4"):
    """Export BGE-M3 to ONNX in *output*, optimise it with ONNX Runtime, then reload it.

    Args:
        output: Target directory. It receives the saved PyTorch model, the exported
            ``model.onnx`` and the optimised graph. The optimiser writes to the same
            directory with an empty file suffix.
        device: ``"cuda"`` or anything else (treated as CPU). Used for the export and
            to choose the execution provider of the final reload.
        optimize: Optimisation level name handed to
            ``AutoOptimizationConfig.with_optimization_level`` (e.g. ``"O1"``-``"O4"``).
    """
    model = BGEM3InferenceModel()
    model.save_pretrained(output)

    onnx_config = BGEM3OnnxConfig(model.config)
    onnx_path = Path(output) / "model.onnx"
    export(
        model,
        output=onnx_path,
        config=onnx_config,
        opset=onnx_config.DEFAULT_ONNX_OPSET,
        device=device,
    )

    optimizer = ORTOptimizer.from_pretrained(output, file_names=["model.onnx"])

    opt_config = AutoOptimizationConfig.with_optimization_level(optimization_level=optimize)
    # NOTE(review): shape inference is disabled; the reason is not visible in this script.
    opt_config.disable_shape_inference = True
    if optimize == "O4":
        # For O4, explicitly target GPU with fp16 and raise the graph level to 99.
        opt_config.optimize_for_gpu = True
        opt_config.fp16 = True
        opt_config.optimization_level = 99
    optimizer.optimize(save_dir=output, optimization_config=opt_config, file_suffix="")

    # Reload the optimised model as a smoke test; the returned object is discarded.
    if device == "cuda":
        provider = "CUDAExecutionProvider"
    else:
        provider = "CPUExecutionProvider"
    ORTModelForCustomTasks.from_pretrained(
        output,
        provider=provider,
    )
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Export BAAI/bge-m3 to an optimised ONNX model.")
    # main() writes every artefact here, so the directory must be given explicitly.
    parser.add_argument("--output", type=str, required=True, help="Directory for the exported model.")
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--optimize", type=str, choices=["O1", "O2", "O3", "O4"], default="O4")
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--push_to_hub_repo_id", type=str, default="JeremyHibiki/bge-m3-onnx")
    args = parser.parse_args()

    main(args.output, args.device, args.optimize)

    # Upload the whole export directory when requested.
    # create_repo(exist_ok=True) makes re-runs safe when the repo already exists.
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub_repo_id, exist_ok=True)
        api.upload_folder(folder_path=args.output, repo_id=args.push_to_hub_repo_id)
|
|