EasonYao commited on
Commit
1eaba88
·
verified ·
1 Parent(s): e9c77ac

Update README.md

Browse files

Refactored the convert_embeddings_to_weights script to be self-contained and executable by integrating the embeddings generated from the first code snippet.

Files changed (1) hide show
  1. README.md +23 -1
README.md CHANGED
@@ -58,10 +58,31 @@ print(scores.tolist())
58
 
59
  #### Convert embeddings to lexical weights
60
  ```python
 
 
61
  from collections import OrderedDict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def convert_embeddings_to_weights(embeddings, tokenizer):
63
  values, indices = torch.sort(embeddings, dim=-1, descending=True)
64
-
65
  token2weight = []
66
  for i in range(embeddings.size(0)):
67
  token2weight.append(OrderedDict())
@@ -77,6 +98,7 @@ def convert_embeddings_to_weights(embeddings, tokenizer):
77
 
78
  token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
79
  print(token2weight[0])
 
80
  # OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
81
  ```
82
 
 
58
 
59
  #### Convert embeddings to lexical weights
60
  ```python
61
+ import torch
62
+ from transformers import AutoTokenizer, AutoModel
63
  from collections import OrderedDict
64
+
65
+ queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
66
+ documents = [
67
+ '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
68
+ '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
69
+ ]
70
+ input_texts = queries + documents
71
+
72
+ tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
73
+ model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
74
+ model.eval()
75
+
76
+ max_length = 512
77
+
78
+ input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
79
+
80
+ with torch.no_grad():
81
+ embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)
82
+
83
  def convert_embeddings_to_weights(embeddings, tokenizer):
84
  values, indices = torch.sort(embeddings, dim=-1, descending=True)
85
+
86
  token2weight = []
87
  for i in range(embeddings.size(0)):
88
  token2weight.append(OrderedDict())
 
98
 
99
  token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
100
  print(token2weight[0])
101
+
102
  # OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
103
  ```
104