HoneyTian committed · Commit 0f4cf3f · 1 Parent(s): d983ee9
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py CHANGED
@@ -108,7 +108,7 @@ class MultiHeadedAttention(nn.Module):
         return self.forward_attention(v, scores, mask), new_cache
 
 
-class RelativeMultiHeadedAttention(nn.Module):
+class RelativeMultiHeadSelfAttention(nn.Module):
 
     def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
         """
@@ -203,30 +203,37 @@ class RelativeMultiHeadedAttention(nn.Module):
                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: self attention.
 
         q, k, v = self.forward_qkv(query, key, value)
+        # q shape: [batch_size, self.h, time_steps, self.d_k]
 
         if cache.size(0) > 0:
             key_cache, value_cache = torch.split(
                 cache, cache.size(-1) // 2, dim=-1)
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
-        # NOTE: We do cache slicing in encoder.forward_chunk, since it's
-        # non-trivial to calculate `next_cache_start` here.
 
-        # new_cache shape: [batch_size, self.h, time_steps, self.d_v * 2]
+        # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
         new_cache = torch.cat((k, v), dim=-1)
 
         # Compute relative position encoding
-        length = q.size(2)
-        relative_position = self.relative_position_encoding(length)
-        relative_position_k = self.relative_position_k[relative_position.view(-1)].view(length, length, -1)
+        q_length, k_length = q.size(2), k.size(2)
+        relative_position = self.relative_position_encoding(k_length)
+
+        # During streaming inference, q_length differs from k_length.
+        relative_position = relative_position[-q_length:]
 
-        relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, length, length, d_k)
-        relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, length, length, d_k)
+        relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
+
+        relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, q_length, k_length, d_k)
+        relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, q_length, k_length, d_k)
 
         native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+
         relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
+        # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
         scores = native_scores + relative_position_scores
 
         return self.forward_attention(v, scores, mask), new_cache
@@ -235,12 +242,13 @@ class RelativeMultiHeadedAttention(nn.Module):
 def main():
     rel_attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
 
-    # x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
-
-    x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
-    cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
-
-    xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
+    x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
+    xt, new_cache = rel_attention.forward(x, x, x)
+
+    # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
+    # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
+    # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
+
     print(xt.shape)
     print(new_cache.shape)
     return
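
The new code asks `self.relative_position_encoding(k_length)` for a (k_length, k_length) index table and then keeps only the last `q_length` rows, so that during streaming inference, where cached keys make `k_length` larger than `q_length`, each remaining row still holds the distances from one of the newest queries to every key. The body of that method is outside this hunk; the sketch below is only a hypothetical reconstruction of the usual clipped relative-distance table, assuming it clamps distances to the `max_relative_position` passed to `__init__` and shifts them to non-negative indices into an embedding table such as `self.relative_position_k`.

import torch


def relative_position_encoding(length: int, max_relative_position: int = 5120) -> torch.Tensor:
    # Hypothetical sketch: entry [i, j] = clip(j - i, -max, +max) + max,
    # a non-negative index into a (2 * max_relative_position + 1, d_k) embedding table.
    positions = torch.arange(length)
    distance = positions[None, :] - positions[:, None]   # (length, length), key index minus query index
    distance = torch.clamp(distance, -max_relative_position, max_relative_position)
    return distance + max_relative_position


# Streaming case: one new query frame attending over 4 keys (3 cached + 1 new).
q_length, k_length = 1, 4
relative_position = relative_position_encoding(k_length)    # (4, 4)
relative_position = relative_position[-q_length:]           # rows for the newest queries -> (1, 4)
print(relative_position.shape)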
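
The relative score term in the diff relies on batched matmul broadcasting: `q.unsqueeze(3)` is (batch, head, q_length, 1, d_k) and `relative_position_k.transpose(-2, -1)` is (batch, head, q_length, d_k, k_length), so their product is (batch, head, q_length, 1, k_length), and `squeeze(3)` brings it to the same (batch, head, q_length, k_length) shape as `native_scores`. A standalone shape check, with made-up sizes that are not taken from the repo:

import math

import torch

batch, n_head, q_length, k_length, d_k = 2, 4, 3, 10, 64

q = torch.randn(batch, n_head, q_length, d_k)
relative_position_k = torch.randn(batch, n_head, q_length, k_length, d_k)

# (batch, head, q_length, 1, d_k) @ (batch, head, q_length, d_k, k_length)
# -> (batch, head, q_length, 1, k_length) -> squeeze(3) -> (batch, head, q_length, k_length)
relative_position_scores = torch.matmul(
    q.unsqueeze(3), relative_position_k.transpose(-2, -1)
).squeeze(3) / math.sqrt(d_k)

print(relative_position_scores.shape)   # torch.Size([2, 4, 3, 10])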
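
The commented-out lines in `main()` exercise the cache path: one new frame (`time_steps = 1`) plus a cache holding the 199 earlier key/value frames, shaped [1, 4, 199, 128] because `d_k = n_feat / n_head = 64` and the cache stores `d_k * 2` values per head. Below is a minimal sketch of how a caller might drive that path frame by frame, assuming the constructor and `forward` signature shown in this diff, the renamed class `RelativeMultiHeadSelfAttention` importable from this file, and the usual [batch, time, n_feat] output of `forward_attention`; as the removed comment noted, trimming the ever-growing cache is left to the caller (e.g. `encoder.forward_chunk`).

import torch

# Assumes the repo root is on PYTHONPATH.
from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import RelativeMultiHeadSelfAttention

attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
attention.eval()

cache = torch.zeros((0, 0, 0, 0))   # empty cache for the first frame
outputs = []
with torch.no_grad():
    for _ in range(200):
        chunk = torch.ones(size=(1, 1, 256), dtype=torch.float32)   # [batch, time=1, n_feat]
        out, cache = attention.forward(chunk, chunk, chunk, cache=cache)
        outputs.append(out)

y = torch.cat(outputs, dim=1)
print(y.shape)        # expected [1, 200, 256]
print(cache.shape)    # expected [1, 4, 200, 128] = [batch, n_head, time_steps, d_k * 2]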