Update README.md
README.md CHANGED
@@ -6,7 +6,7 @@ tags:
   - transformers
 ---
 
-## INF
+## <u>INF</u> <u>W</u>ord-level <u>S</u>parse <u>E</u>mbedding (INF-WSE)
 
 **INF-WSE** is a series of word-level sparse embedding models developed by [INFLY TECH](https://www.infly.cn/en).
 These models are optimized to generate sparse, high-dimensional text embeddings that excel in capturing the most
@@ -29,7 +29,7 @@ relevant information for search and retrieval, particularly in Chinese text.
 
 ### Transformers
 
-#### Infer
+#### Infer embeddings
 ```python
 import torch
 from transformers import AutoTokenizer, AutoModel
@@ -58,31 +58,10 @@ print(scores.tolist())
 
 #### Convert embeddings to lexical weights
 ```python
-import torch
-from transformers import AutoTokenizer, AutoModel
 from collections import OrderedDict
-
-queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
-documents = [
-    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
-    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
-]
-input_texts = queries + documents
-
-tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
-model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
-model.eval()
-
-max_length = 512
-
-input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
-
-with torch.no_grad():
-    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)
-
 def convert_embeddings_to_weights(embeddings, tokenizer):
     values, indices = torch.sort(embeddings, dim=-1, descending=True)
-
+
     token2weight = []
     for i in range(embeddings.size(0)):
         token2weight.append(OrderedDict())
@@ -97,14 +76,14 @@ def convert_embeddings_to_weights(embeddings, tokenizer):
     return token2weight
 
 token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
-print(token2weight[0])
-
-# OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
+print(token2weight[1])
+# OrderedDict([('掌上', 3.4572525024414062), ('电脑', 2.6253132820129395), ('是', 2.0787220001220703), ('什么', 1.2899624109268188)])
 ```
 
 ## Evaluation
 
 ### C-MTEB Retrieval task
+
 ([Chinese Massive Text Embedding Benchmark](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB))
 
 Metric: nDCG@10
@@ -114,3 +93,5 @@ Metric: nDCG@10
 | [BM25-zh](https://github.com/castorini/pyserini) | - | 25.39 | 13.70 | **86.66** | 13.68 | 11.49 | 15.48 | 6.56 | 29.53 | 25.98 |
 | [bge-m3-sparse](https://huggingface.co/BAAI/bge-m3) | 512 | 29.94 | **24.50** | 76.16 | 22.12 | 17.62 | 27.52 | 9.78 | **37.69** | 24.12 |
 | **inf-wse-v1-base-zh** | 512 | **32.83** | 20.51 | 76.40 | **36.77** | **19.97** | **28.61** | **13.32** | 36.81 | **30.25** |
+
+All results, except for BM25, are measured by building the sparse index via [Qdrant](https://github.com/qdrant/qdrant).
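
The hunks above truncate the body of the `#### Infer embeddings` snippet. Below is a minimal end-to-end sketch assembled from the calls that are visible in this diff (the `infly/inf-wse-v1-base-zh` model id, `trust_remote_code`, `use_fast=False`, and the `return_sparse=False` forward call); the final scoring step is an assumption, a plain query-document dot product, since the original scoring lines are elided here.

```python
# Minimal sketch, not the verbatim README snippet: model loading, tokenization and the
# return_sparse=False forward call are copied from lines visible in this diff; the
# dot-product scoring at the end is an assumption and may differ from the elided lines.
import torch
from transformers import AutoTokenizer, AutoModel

queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
documents = [
    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
model.eval()

input_batch = tokenizer(input_texts, padding=True, max_length=512, truncation=True, return_tensors="pt")
with torch.no_grad():
    # return_sparse=False yields a dense [batch, vocab_size] tensor of per-token weights
    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)

# Assumed scoring: relevance as the overlap between query and document token weights.
scores = embeddings[:len(queries)] @ embeddings[len(queries):].T
print(scores.tolist())
```

Feeding the resulting `embeddings` tensor to `convert_embeddings_to_weights` (shown in the diff) then yields the per-token lexical weights printed in the example output above.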
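
The added evaluation note states that all non-BM25 results were measured by building the sparse index with Qdrant. As a rough illustration only (the README itself does not show this code), the sketch below reuses `embeddings`, `queries`, and `documents` from the previous sketch and pushes the non-zero token weights into a Qdrant sparse index; the collection name `inf_wse_demo` and vector name `text` are invented for the example, and the sparse-vector API follows the documented `qdrant-client` interface (1.7+), which may differ in other client versions.

```python
# Rough illustration only (not from the README): index the document-side token weights
# from the previous sketch in Qdrant and query them with a query-side vector.
import torch
from qdrant_client import QdrantClient, models

def to_sparse(vec: torch.Tensor) -> models.SparseVector:
    # Keep only non-zero dimensions: index = token id, value = token weight.
    idx = vec.nonzero(as_tuple=True)[0]
    return models.SparseVector(indices=idx.tolist(), values=vec[idx].tolist())

client = QdrantClient(":memory:")  # in-process instance, enough for a demo
client.create_collection(
    collection_name="inf_wse_demo",   # hypothetical name for this example
    vectors_config={},                # sparse vectors only, no dense vectors
    sparse_vectors_config={"text": models.SparseVectorParams()},
)

# `embeddings` rows 0..len(queries)-1 are the queries, the remaining rows the documents.
client.upsert(
    collection_name="inf_wse_demo",
    points=[
        models.PointStruct(id=i,
                           vector={"text": to_sparse(embeddings[len(queries) + i])},
                           payload={"text": doc})
        for i, doc in enumerate(documents)
    ],
)

hits = client.search(
    collection_name="inf_wse_demo",
    query_vector=models.NamedSparseVector(name="text", vector=to_sparse(embeddings[0])),
    limit=2,
)
for hit in hits:
    print(hit.score, hit.payload["text"])
```

Storing only the non-zero (token id, weight) pairs is what makes the index sparse; scoring is then a dot product over the token ids shared by query and document.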