Update modeling_bge_m3.py
modeling_bge_m3.py  +11 -25
@@ -42,6 +42,7 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
 
         self.init_weights()
 
+    # Copied from FlagEmbedding
     def dense_embedding(self, hidden_state, mask):
         if self.sentence_pooling_method == "cls":
             return hidden_state[:, 0]
@@ -50,6 +51,7 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
         d = mask.sum(axis=1, keepdim=True).float()
         return s / d
 
+    # Copied from FlagEmbedding
     def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = False):
         token_weights = torch.relu(self.sparse_linear(hidden_state))
         if not return_embedding:
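Aside (not part of the diff): the hunk above only shows the tail of the `mean` branch of `dense_embedding` (`d = mask.sum(...)`, `return s / d`). A minimal standalone sketch of that masked mean pooling, with assumed shapes and an assumed definition of `s` (only `d` and `s / d` appear in the diff), could look like this:

import torch

# Sketch of masked mean pooling (the "mean" branch of dense_embedding).
# Assumed shapes: hidden_state [batch, seq, dim]; mask [batch, seq], 1 for real tokens.
def masked_mean_pooling(hidden_state: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)  # sum of real-token vectors
    d = mask.sum(dim=1, keepdim=True).float()                        # number of real tokens per sequence
    return s / d                                                     # [batch, dim]

hidden_state = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(masked_mean_pooling(hidden_state, mask).shape)  # torch.Size([2, 8])

Dividing by the per-sequence token count rather than the full sequence length keeps padding positions from diluting the sentence embedding.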
@@ -69,11 +71,13 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
         sparse_embedding[:, unused_tokens] *= 0.0
         return sparse_embedding
 
+    # Copied from FlagEmbedding
     def colbert_embedding(self, last_hidden_state, mask):
         colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
         colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
         return colbert_vecs
 
+    # Modified from FlagEmbedding
     def _process_token_weights(self, token_weights, input_ids, mask):
         token_weights = token_weights.squeeze(-1)
         # conver to dict
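Aside (not part of the diff): the two lines in `colbert_embedding` drop the CLS position and zero out the vectors of padding tokens through broadcasting. A tiny sketch with made-up shapes and values:

import torch

# Illustration of the mask handling in colbert_embedding (made-up shapes/values).
last_hidden_state = torch.randn(1, 5, 4)          # [batch, seq, dim]
mask = torch.tensor([[1, 1, 1, 0, 0]])            # CLS + 2 real tokens + 2 padding positions
vecs = last_hidden_state[:, 1:]                   # drop the CLS position -> [1, 4, 4]
vecs = vecs * mask[:, 1:][:, :, None].float()     # padding rows become all-zero vectors
print(vecs[0, -2:])                               # the two padding positions are now zero vectors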
@@ -81,50 +85,32 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
         unused_tokens = self.config.unused_tokens
         unused_tokens = torch.tensor(unused_tokens, device=input_ids.device)
 
-        #
+        # Get valid matrix
         valid_indices = ~torch.isin(input_ids, unused_tokens)
-        #
+        # w>0
        valid_indices = (valid_indices & (token_weights > 0)).bool()
-        # 结合 attention mask,获取有效的 token 的索引
         valid_indices = (valid_indices & mask).bool()
 
         for i, valid in enumerate(valid_indices):
             result = defaultdict(int)
 
-            #
+            # Get valid weight and ids
             valid_weights = token_weights[i][valid]
             valid_ids = input_ids[i][valid]
 
-            #
+            # Get unique token
             unique_ids, inverse_indices = torch.unique(valid_ids, return_inverse=True)
 
-            #
+            # Get max weight for each token
             for i in range(unique_ids.shape[0]):
                 id_mask = inverse_indices == i
                 result[str(unique_ids[i].item())] = valid_weights[id_mask].max().item()
 
             all_result.append(result)
-
-        # for w, idx, num in zip(token_weights, input_ids, tokens_num):
-        #     r = defaultdict(int)
-        #     token_weight = w[:num]
-        #     idx = idx[:num]
-
-        #     for t_w, t_idx in zip(token_weight, idx):
-        #         if t_idx.item() not in unused_tokens:
-        #             t_idx = str(t_idx.item())
-        #             if t_w > r[t_idx]:
-        #                 r[t_idx] = t_w.item()
-
-        #     result.append(r)
-
-        #         if idx not in unused_tokens and w > 0:
-        #             idx = str(idx)
-        #             # w = int(w)
-        #             if w > result[idx]:
-        #                 result[idx] = w
+
         return all_result
 
+    # Copied from FlagEmbedding
     def _process_colbert_vecs(self, colbert_vecs, tokens_num) -> List[torch.Tensor]:
         # delte the vectors of padding tokens
         vecs = []
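Aside (not part of the diff): the rewritten `_process_token_weights` keeps, per sequence, only tokens that are not in `unused_tokens`, have a positive weight, and are not masked out, then records the maximum weight observed for each unique token id. A standalone sketch of that per-sequence aggregation, with made-up ids, weights, and unused-token list:

import torch
from collections import defaultdict

# Illustration of the per-sequence aggregation in _process_token_weights.
# All concrete ids, weights, and the unused-token list are made-up example values.
token_weights = torch.tensor([0.9, 0.2, 0.0, 0.7, 0.4])
input_ids = torch.tensor([101, 2023, 102, 2023, 0])
mask = torch.tensor([1, 1, 1, 1, 0]).bool()
unused_tokens = torch.tensor([101, 102, 0])      # e.g. special/padding token ids (illustrative)

# Keep tokens that are not "unused", have weight > 0, and are not padding.
valid = ~torch.isin(input_ids, unused_tokens) & (token_weights > 0) & mask
valid_weights = token_weights[valid]
valid_ids = input_ids[valid]

# For each unique kept token id, store the maximum weight it received.
result = defaultdict(int)
unique_ids, inverse_indices = torch.unique(valid_ids, return_inverse=True)
for i in range(unique_ids.shape[0]):
    id_mask = inverse_indices == i
    result[str(unique_ids[i].item())] = valid_weights[id_mask].max().item()

print(dict(result))  # id 2023 keeps its larger weight (~0.7)

`return_inverse=True` gives, for each kept position, the index of its unique id, so the per-id maximum can be taken with a boolean mask instead of a per-token dictionary lookup.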