update
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py
CHANGED
@@ -108,7 +108,7 @@ class MultiHeadedAttention(nn.Module):
         return self.forward_attention(v, scores, mask), new_cache


-class RelativeMultiHeadedAttention(nn.Module):
+class RelativeMultiHeadSelfAttention(nn.Module):

    def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
        """
@@ -203,30 +203,37 @@ class RelativeMultiHeadedAttention(nn.Module):
             mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
             cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # attention! self attention.

         q, k, v = self.forward_qkv(query, key, value)
+        # q shape: [batch_size, self.h, time_steps, self.d_k]

         if cache.size(0) > 0:
             key_cache, value_cache = torch.split(
                 cache, cache.size(-1) // 2, dim=-1)
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
-        # NOTE: We do cache slicing in encoder.forward_chunk, since it's
-        # non-trivial to calculate `next_cache_start` here.

-        # new_cache shape: [batch_size, self.h, time_steps, self.
+        # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
         new_cache = torch.cat((k, v), dim=-1)

         # Compute relative position encoding
-
-        relative_position = self.relative_position_encoding(
-
+        q_length, k_length = q.size(2), k.size(2)
+        relative_position = self.relative_position_encoding(k_length)
+
+        # During streaming inference, q_length and k_length differ.
+        relative_position = relative_position[-q_length:]

-        relative_position_k = relative_position_k.
-
+        relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
+
+        relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)    # (1, 1, q_length, k_length, d_k)
+        relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)    # (batch, head, q_length, k_length, d_k)

         native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+
         relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
+        # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
         scores = native_scores + relative_position_scores

         return self.forward_attention(v, scores, mask), new_cache
@@ -235,12 +242,13 @@ class RelativeMultiHeadedAttention(nn.Module):
 def main():
     rel_attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)

-
+    x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
+    xt, new_cache = rel_attention.forward(x, x, x)

-    x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
-    cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
+    # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
+    # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
+    # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)

-    xt, new_cache = rel_attention.forward(x, x, x, cache=cache)

     print(xt.shape)
     print(new_cache.shape)
     return
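The new lines index into self.relative_position_k via self.relative_position_encoding(k_length), but the diff does not show how those members are defined in __init__. Below is a minimal sketch of the usual construction, assuming a learned embedding table over relative distances clipped to max_relative_position; the class name RelativePositionSketch and all shapes here are assumptions for illustration, not code from this repository.

import torch
import torch.nn as nn


class RelativePositionSketch(nn.Module):
    """Hypothetical stand-in for the relative-position members referenced in the diff."""

    def __init__(self, d_k: int, max_relative_position: int = 5120):
        super().__init__()
        self.max_relative_position = max_relative_position
        # One learned vector per clipped relative distance; distances in
        # [-max_relative_position, +max_relative_position] map to rows [0, 2 * max_relative_position].
        self.relative_position_k = nn.Parameter(torch.randn(2 * max_relative_position + 1, d_k))

    def relative_position_encoding(self, length: int) -> torch.Tensor:
        # (length, length) matrix of relative distances (j - i), clipped and shifted
        # so it can be used directly as row indices into self.relative_position_k.
        positions = torch.arange(length)
        distance = positions[None, :] - positions[:, None]
        distance = torch.clamp(distance, -self.max_relative_position, self.max_relative_position)
        return distance + self.max_relative_position

Under that assumption, self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1) in the new code yields a (q_length, k_length, d_k) tensor of relative key embeddings, and the relative_position[-q_length:] slice keeps only the rows for the current queries, which is what the streaming comment in the diff refers to.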
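The commented-out lines in main() point at the streaming path, where a single new frame attends over cached keys and values. The following is a rough usage sketch of that loop, assuming the attention module from the diff with n_feat=256 and n_head=4, so d_k = 64 and the cache's last dimension is d_k * 2 = 128, matching the (1, 4, 199, 128) cache in the comments. The helper name streaming_sketch is made up for illustration and is not part of the repository.

import torch


def streaming_sketch(attention, frames: torch.Tensor):
    # Feed frames one step at a time, carrying the (k, v) cache between calls.
    # `attention` is assumed to behave like the attention class in the diff:
    # forward(q, k, v, cache=...) -> (output, new_cache).
    total_steps = frames.size(1)
    cache = torch.zeros((0, 0, 0, 0))        # empty cache on the first step, as in the default argument
    outputs = []
    for t in range(total_steps):
        x = frames[:, t:t + 1, :]            # (batch, 1, n_feat), so q_length == 1
        y, cache = attention.forward(x, x, x, cache=cache)
        outputs.append(y)                    # cache now holds [batch, h, t + 1, d_k * 2]
    return torch.cat(outputs, dim=1), cache

Because k_length grows with the cache while q_length stays at 1, the relative_position[-q_length:] slice added in this commit is what keeps the relative-position lookup consistent between the one-shot call on x of shape (1, 200, 256) and this step-by-step variant.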