HoneyTian committed 3a53e4a · 1 Parent(s): 413efc4
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -16,9 +16,10 @@ tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length: 5120
+tsfm_max_length: 512
 tsfm_chunk_size: 4
 tsfm_num_left_chunks: 64
+tsfm_num_right_chunks: 2

 discriminator_dim: 32
 discriminator_in_channel: 2
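For reference, the new key sits next to the existing chunk settings. A minimal sketch (not part of the commit, assuming the example path above and standard PyYAML) of reading the updated file:

# Sketch: load the example config and check the new look-ahead key.
import yaml

with open("examples/nx_clean_unet/yaml/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["tsfm_max_length"] == 512
assert cfg["tsfm_num_left_chunks"] == 64
assert cfg["tsfm_num_right_chunks"] == 2   # right-context budget introduced in this commit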
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py CHANGED
@@ -25,8 +25,9 @@ class NXCleanUNetConfig(PretrainedConfig):
                  tsfm_num_blocks: int = 6,
                  tsfm_dropout_rate: float = 0.1,
                  tsfm_max_length: int = 1024,
-                 tsfm_chunk_size: int = 1,
+                 tsfm_chunk_size: int = 4,
                  tsfm_num_left_chunks: int = 128,
+                 tsfm_num_right_chunks: int = 2,

                  discriminator_dim: int = 16,
                  discriminator_in_channel: int = 2,
@@ -62,6 +63,7 @@ class NXCleanUNetConfig(PretrainedConfig):
         self.tsfm_max_length = tsfm_max_length
         self.tsfm_chunk_size = tsfm_chunk_size
         self.tsfm_num_left_chunks = tsfm_num_left_chunks
+        self.tsfm_num_right_chunks = tsfm_num_right_chunks

         self.discriminator_dim = discriminator_dim
         self.discriminator_in_channel = discriminator_in_channel
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -172,6 +172,9 @@ class NXCleanUNet(nn.Module):
             attention_heads=config.tsfm_attention_heads,
             num_blocks=config.tsfm_num_blocks,
             dropout_rate=config.tsfm_dropout_rate,
+            chunk_size=config.tsfm_chunk_size,
+            num_left_chunks=config.tsfm_num_left_chunks,
+            num_right_chunks=config.tsfm_num_right_chunks,
         )
         self.up_sampling = UpSampling(
             num_layers=config.down_sampling_num_layers,
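A hypothetical usage sketch (not from the commit): the three new keyword arguments are read from NXCleanUNetConfig (note the tsfm_ prefix on the config attributes) and forwarded into the TransformerEncoder built between the down-sampling and up-sampling stages.

from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNet

config = NXCleanUNetConfig(tsfm_chunk_size=4, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2)
model = NXCleanUNet(config)
# The encoder inside the model now builds its chunked attention mask with 64 chunks of
# left context and 2 chunks of right (look-ahead) context per query chunk.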
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py CHANGED
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn


-class MultiHeadAttention(nn.Module):
+class MultiHeadSelfAttention(nn.Module):
     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
         """
         :param n_head: int. the number of heads.
@@ -86,14 +86,12 @@ class MultiHeadAttention(nn.Module):
         return self.linear_out(x)  # (batch, time1, n_feat)

     def forward(self,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
+                x: torch.Tensor,
                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                 ) -> Tuple[torch.Tensor, torch.Tensor]:

-        q, k, v = self.forward_qkv(query, key, value)
+        q, k, v = self.forward_qkv(x, x, x)

         if cache.size(0) > 0:
             key_cache, value_cache = torch.split(
@@ -157,32 +155,40 @@ class RelativeMultiHeadSelfAttention(nn.Module):
     def forward_attention(self,
                           value: torch.Tensor,
                           scores: torch.Tensor,
-                          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+                          mask: torch.Tensor = None
                           ) -> torch.Tensor:
         """
         compute attention context vector.
-        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
-        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
-        :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
-            (batch_size, time1, time2), (0, 0, 0) means fake mask.
-        :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
-            weighted by the attention score (batch_size, time1, time2).
+        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
+        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
+        :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
+        :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
+            weighted by the attention score (batch_size, query_time_steps, key_time_steps).
         """
         n_batch = value.size(0)
-        if mask.size(2) > 0:  # time2 > 0
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            # For last chunk, time2 might be larger than scores.size(-1)
-            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)
+            # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
             scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
         else:
-            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores, dim=-1)
+            # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]

         p_attn = self.dropout(attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)

-        return self.linear_out(x)  # (batch, time1, n_feat)
+        x = torch.matmul(p_attn, value)
+        # x shape: [batch_size, n_head, query_time_steps, d_k]
+        x = x.transpose(1, 2)
+        # x shape: [batch_size, query_time_steps, n_head, d_k]
+
+        x = x.contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
+        # x shape: [batch_size, query_time_steps, n_head * d_k]
+        # x shape: [batch_size, query_time_steps, n_feat]
+
+        x = self.linear_out(x)
+        # x shape: [batch_size, query_time_steps, n_feat]
+        return x

     def relative_position_encoding(self, length: int) -> torch.Tensor:
         """
@@ -197,18 +203,16 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         return final_mat

     def forward(self,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                x: torch.Tensor,
+                mask: torch.Tensor = None,
+                cache: torch.Tensor = None
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
         # attention! self attention.

-        q, k, v = self.forward_qkv(query, key, value)
-        # q shape: [batch_size, self.h, time_steps, self.d_k]
+        q, k, v = self.forward_qkv(x, x, x)
+        # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]

-        if cache.size(0) > 0:
+        if cache is not None:
             key_cache, value_cache = torch.split(
                 cache, cache.size(-1) // 2, dim=-1)
             k = torch.cat([key_cache, k], dim=2)
@@ -217,11 +221,13 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
         new_cache = torch.cat((k, v), dim=-1)

+        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+
         # Compute relative position encoding
         q_length, k_length = q.size(2), k.size(2)
         relative_position = self.relative_position_encoding(k_length)

-        # During streaming inference, q_length and k_length differ.
         relative_position = relative_position[-q_length:]

         relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
@@ -229,11 +235,10 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, q_length, k_length, d_k)
         relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, q_length, k_length, d_k)

-        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
-
         relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
         # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+
+        # score
         scores = native_scores + relative_position_scores

         return self.forward_attention(v, scores, mask), new_cache
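A shape-check sketch for the refactored self-attention interface. The class name, constructor signature, and single-tensor forward come from the diff above; the concrete batch and sequence sizes are made up, and the sketch assumes the rest of the class is unchanged by this commit.

import torch
from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention

attention = MultiHeadSelfAttention(n_head=8, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 16, 256)        # [batch, time_steps, n_feat]

# forward now takes a single tensor instead of separate query/key/value tensors.
out, new_cache = attention.forward(x)
print(out.shape)                   # expected: torch.Size([2, 16, 256])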
toolbox/torchaudio/models/nx_clean_unet/transformer/mask.py CHANGED
@@ -25,6 +25,7 @@ def subsequent_chunk_mask(
         size: int,
         chunk_size: int,
         num_left_chunks: int = -1,
+        num_right_chunks: int = 0,
         device: torch.device = torch.device("cpu"),
 ) -> torch.Tensor:
     """
@@ -41,6 +42,7 @@ def subsequent_chunk_mask(
     :param size: int. size of mask.
     :param chunk_size: int. size of chunk.
     :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
+    :param num_right_chunks: int. number of right chunks.
     :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
     :return: torch.Tensor. mask
     """
@@ -51,7 +53,7 @@ def subsequent_chunk_mask(
             start = 0
         else:
             start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
-        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
         ret[i, start:ending] = True
     return ret

@@ -59,6 +61,12 @@ def subsequent_chunk_mask(
 def main():
     chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
     print(chunk_mask)
+
+    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
+    print(chunk_mask)
+
+    chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
+    print(chunk_mask)
     return

toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn

 from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
-from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadAttention, RelativeMultiHeadSelfAttention
+from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention


 class PositionwiseFeedForward(nn.Module):
@@ -87,7 +87,7 @@ class TransformerEncoderLayer(nn.Module):
         xt = self.norm1(x)

         x_att, new_att_cache = self.attention.forward(
-            xt, xt, xt, mask=mask, cache=attention_cache
+            xt, mask=mask, cache=attention_cache
         )
         x = x + self.dropout1(xt)
         xt = self.norm2(x)
@@ -112,6 +112,7 @@ class TransformerEncoder(nn.Module):
                  max_relative_position: int = 1024,
                  chunk_size: int = 1,
                  num_left_chunks: int = 128,
+                 num_right_chunks: int = 2,
                  ):
         super().__init__()
         self.input_size = input_size
@@ -120,6 +121,7 @@ class TransformerEncoder(nn.Module):
         self.max_relative_position = max_relative_position
         self.chunk_size = chunk_size
         self.num_left_chunks = num_left_chunks
+        self.num_right_chunks = num_right_chunks

         self.input_linear = nn.Linear(
             in_features=self.input_size,
@@ -155,7 +157,8 @@ class TransformerEncoder(nn.Module):
         chunk_masks = subsequent_chunk_mask(
             size=time_steps,
             chunk_size=self.chunk_size,
-            num_left_chunks=self.num_left_chunks
+            num_left_chunks=self.num_left_chunks,
+            num_right_chunks=self.num_right_chunks,
         )
         chunk_masks = chunk_masks.to(xs.device)
         # chunk_masks shape: [1, time_steps, time_steps]
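A sketch of how the encoder now derives its chunked attention mask, mirroring the call in the diff above. The time_steps value is a placeholder, and the trailing unsqueeze(0) is an assumption made so the printed shape matches the [1, time_steps, time_steps] comment; the actual broadcasting code is not shown in this hunk.

import torch
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask

time_steps = 16
chunk_masks = subsequent_chunk_mask(
    size=time_steps,
    chunk_size=4,
    num_left_chunks=64,
    num_right_chunks=2,
)
chunk_masks = chunk_masks.unsqueeze(0)   # assumed batch dimension
print(chunk_masks.shape)                 # torch.Size([1, 16, 16])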
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml CHANGED
@@ -10,23 +10,24 @@ hop_size: 80
 # For example, 2**5=32 means that 32 samples correspond to one time step after down-sampling,
 # so one step is 32/sample_rate = 0.004 s.
 # Thus tsfm_chunk_size=4 is 16 ms and tsfm_chunk_size=8 is 32 ms.
-# Assuming we look 1 second to the left each time:
-# tsfm_chunk_size=1, tsfm_num_left_chunks: 256
-# tsfm_chunk_size=4, tsfm_num_left_chunks: 64
-# tsfm_chunk_size=8, tsfm_num_left_chunks: 32
+# Assuming we look 1 second to the left and about 30 ms to the right each time:
+# tsfm_chunk_size=1, tsfm_num_left_chunks=256, tsfm_num_right_chunks=8
+# tsfm_chunk_size=4, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
+# tsfm_chunk_size=8, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
 down_sampling_num_layers: 5
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2

-tsfm_hidden_size: 64
-tsfm_attention_heads: 4
+tsfm_hidden_size: 256
+tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length: 5120
+tsfm_max_length: 512
 tsfm_chunk_size: 4
 tsfm_num_left_chunks: 64
+tsfm_num_right_chunks: 2

 discriminator_dim: 32
 discriminator_in_channel: 2
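A back-of-the-envelope check of the latency budget quoted in the comment above (sample_rate assumed to be 8000 Hz, consistent with hop_size 80 and the 0.004 s step mentioned in the config):

sample_rate = 8000
step_seconds = 2 ** 5 / sample_rate            # 32 samples per step after 5 down-sampling layers
for chunk_size, left, right in [(1, 256, 8), (4, 64, 2), (8, 32, 1)]:
    left_ms = chunk_size * left * step_seconds * 1000
    right_ms = chunk_size * right * step_seconds * 1000
    print(chunk_size, f"left={left_ms:.0f}ms", f"right={right_ms:.0f}ms")
# Every row prints left=1024ms and right=32ms, i.e. roughly the
# "1 s left / ~30 ms right" budget described in the comment.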