|
|
|
# 模型介绍 |
|
|
|
Built mainly for retrieval and semantic-matching tasks; in my own tests it outperforms most current embedding models.
|
|
|
Supports multiple embedding dimensions: 256, 768, 1024, 1563, 1792, 2048, 4096
|
|
|
Supports Chinese-English cross-lingual search, but its English representations are weaker than its Chinese ones.
|
|
|
# 模型目录结构 |
|
|
|
The layout is simple: a standard SentenceTransformer directory plus a set of `2_Dense_{dims}` folders, where `dims` is the final embedding dimension.
|
|
|
For example, the `2_Dense_256` folder stores the Linear weights that project the embedding down to 256 dimensions; see the sections below for how to use them, and the sketch right below for a quick way to inspect one of these folders.
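
If you want to check which dimension a given folder provides, here is a minimal sketch, assuming the standard SentenceTransformer Dense-module layout (a `config.json` next to the `pytorch_model.bin`) and using the same `{MODEL_PATH}` placeholder as the usage section below:

```python
import json
import os

# Minimal sketch: peek at one 2_Dense_{dims} folder, assuming the standard
# SentenceTransformer Dense-module layout (config.json next to pytorch_model.bin).
model_dir = "{MODEL_PATH}"  # placeholder for the local model directory
dense_dir = os.path.join(model_dir, "2_Dense_256")

with open(os.path.join(dense_dir, "config.json")) as f:
    config = json.load(f)

print(config)  # should report out_features = 256 for this folder
```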
|
|
|
# 模型使用方法 |
|
|
|
It can be loaded directly with SentenceTransformer, or used via the transformers library:
|
|
|
```python
import os

import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize

# Texts to encode
texts = ["通用向量编码", "hello world", "支持中英互搜,不建议纯英文场景使用"]
# Model directory
model_dir = "{MODEL_PATH}"

#### Method 1: SentenceTransformer
# NOTE: the default output is 4096-dimensional. For any other dimension, copy the files
# from the matching 2_Dense_{dims} folder into the 2_Dense folder to overwrite it.
model = SentenceTransformer(model_dir)
vectors = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
print(vectors.shape)
print(vectors[:, :4])

#### Method 2: the transformers library
# NOTE: this code loads the Linear weights that match vector_dim; set vector_dim as needed.
vector_dim = 4096
model = AutoModel.from_pretrained(model_dir).eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir)
vector_linear = torch.nn.Linear(in_features=model.config.hidden_size, out_features=vector_dim)
vector_linear_dict = {
    k.replace("linear.", ""): v
    for k, v in torch.load(os.path.join(model_dir, f"2_Dense_{vector_dim}/pytorch_model.bin")).items()
}
vector_linear.load_state_dict(vector_linear_dict)
with torch.no_grad():
    input_data = tokenizer(texts, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    attention_mask = input_data["attention_mask"]
    last_hidden_state = model(**input_data)[0]
    # Mean pooling over non-padding tokens, then project to vector_dim and L2-normalize
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    vectors = normalize(vector_linear(vectors).cpu().numpy())
print(vectors.shape)
print(vectors[:, :4])
```
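
For Method 1, the dimension switch mentioned in the comment is just a file copy. A minimal sketch, assuming the directory layout described above; the target dimension's files simply overwrite the default `2_Dense` module:

```python
import os
import shutil

from sentence_transformers import SentenceTransformer

model_dir = "{MODEL_PATH}"  # placeholder for the local model directory
dims = 1024                 # pick any supported dimension

# Overwrite the default 2_Dense module with the Linear weights for `dims`
shutil.copytree(
    os.path.join(model_dir, f"2_Dense_{dims}"),
    os.path.join(model_dir, "2_Dense"),
    dirs_exist_ok=True,
)

model = SentenceTransformer(model_dir)
print(model.encode(["通用向量编码"]).shape)  # expect (1, 1024)
```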
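
Since the model targets retrieval and semantic matching, one usage note: with `normalize_embeddings=True` the vectors are unit-length, so a plain dot product already gives the cosine similarity. A small sketch reusing the SentenceTransformer `model` loaded in Method 1 above (the query and documents here are made-up examples):

```python
import numpy as np

query = ["怎么把向量降到256维?"]                         # made-up query
docs = ["2_Dense_256文件夹存放降维到256维的Linear权重",   # made-up candidate documents
        "hello world"]

q_vec = model.encode(query, convert_to_numpy=True, normalize_embeddings=True)
d_vec = model.encode(docs, convert_to_numpy=True, normalize_embeddings=True)

scores = q_vec @ d_vec.T                # cosine similarity, since vectors are normalized
print(scores)
print(docs[int(np.argmax(scores))])     # the best-matching document
```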