Update README.md
Browse files
Refactored the convert_embeddings_to_weights example to be self-contained and executable by including the embedding-generation code from the first code snippet.
README.md
CHANGED
@@ -58,10 +58,31 @@ print(scores.tolist())
|
|
58 |
|
59 |
#### Convert embeddings to lexical weights
|
60 |
```python
|
|
|
|
|
61 |
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
def convert_embeddings_to_weights(embeddings, tokenizer):
|
63 |
values, indices = torch.sort(embeddings, dim=-1, descending=True)
|
64 |
-
|
65 |
token2weight = []
|
66 |
for i in range(embeddings.size(0)):
|
67 |
token2weight.append(OrderedDict())
|
@@ -77,6 +98,7 @@ def convert_embeddings_to_weights(embeddings, tokenizer):
|
|
77 |
|
78 |
token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
|
79 |
print(token2weight[0])
|
|
|
80 |
# OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
|
81 |
```
|
82 |
|
|
|
58 |
|
59 |
#### Convert embeddings to lexical weights
|
60 |
```python
|
61 |
+
import torch
|
62 |
+
from transformers import AutoTokenizer, AutoModel
|
63 |
from collections import OrderedDict
|
64 |
+
|
65 |
+
queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
|
66 |
+
documents = [
|
67 |
+
'电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
|
68 |
+
'掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
|
69 |
+
]
|
70 |
+
input_texts = queries + documents
|
71 |
+
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
|
73 |
+
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
|
74 |
+
model.eval()
|
75 |
+
|
76 |
+
max_length = 512
|
77 |
+
|
78 |
+
input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)
|
82 |
+
|
83 |
def convert_embeddings_to_weights(embeddings, tokenizer):
|
84 |
values, indices = torch.sort(embeddings, dim=-1, descending=True)
|
85 |
+
|
86 |
token2weight = []
|
87 |
for i in range(embeddings.size(0)):
|
88 |
token2weight.append(OrderedDict())
|
|
|
98 |
|
99 |
token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
|
100 |
print(token2weight[0])
|
101 |
+
|
102 |
# OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
|
103 |
```
|
104 |
|