update
- examples/nx_clean_unet/yaml/config.yaml +2 -1
- toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +3 -1
- toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +3 -0
- toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py +38 -33
- toolbox/torchaudio/models/nx_clean_unet/transformer/mask.py +9 -1
- toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py +6 -3
- toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +8 -7
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -16,9 +16,10 @@ tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length:
+tsfm_max_length: 512
 tsfm_chunk_size: 4
 tsfm_num_left_chunks: 64
+tsfm_num_right_chunks: 2
 
 discriminator_dim: 32
 discriminator_in_channel: 2
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py
CHANGED
@@ -25,8 +25,9 @@ class NXCleanUNetConfig(PretrainedConfig):
         tsfm_num_blocks: int = 6,
         tsfm_dropout_rate: float = 0.1,
         tsfm_max_length: int = 1024,
-        tsfm_chunk_size: int =
+        tsfm_chunk_size: int = 4,
         tsfm_num_left_chunks: int = 128,
+        tsfm_num_right_chunks: int = 2,
 
         discriminator_dim: int = 16,
         discriminator_in_channel: int = 2,
@@ -62,6 +63,7 @@ class NXCleanUNetConfig(PretrainedConfig):
         self.tsfm_max_length = tsfm_max_length
         self.tsfm_chunk_size = tsfm_chunk_size
         self.tsfm_num_left_chunks = tsfm_num_left_chunks
+        self.tsfm_num_right_chunks = tsfm_num_right_chunks
 
         self.discriminator_dim = discriminator_dim
         self.discriminator_in_channel = discriminator_in_channel
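Taken together, these two hunks add a right-context knob next to the existing left-context one. A minimal usage sketch of the new fields (keyword names, attribute names and defaults are the ones visible in the diff; it assumes the constructor's remaining arguments also have defaults):

    from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig

    # Only the chunking-related arguments are overridden; everything else keeps its default.
    config = NXCleanUNetConfig(
        tsfm_chunk_size=4,
        tsfm_num_left_chunks=64,
        tsfm_num_right_chunks=2,
    )
    print(config.tsfm_chunk_size, config.tsfm_num_left_chunks, config.tsfm_num_right_chunks)  # 4 64 2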
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -172,6 +172,9 @@ class NXCleanUNet(nn.Module):
             attention_heads=config.tsfm_attention_heads,
             num_blocks=config.tsfm_num_blocks,
             dropout_rate=config.tsfm_dropout_rate,
+            chunk_size=config.chunk_size,
+            num_left_chunks=config.num_left_chunks,
+            num_right_chunks=config.num_right_chunks,
         )
         self.up_sampling = UpSampling(
             num_layers=config.down_sampling_num_layers,
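Worth flagging: the configuration class in the previous diff stores these values as tsfm_chunk_size, tsfm_num_left_chunks and tsfm_num_right_chunks, while the keyword arguments added here read config.chunk_size, config.num_left_chunks and config.num_right_chunks. Unless the config also exposes un-prefixed aliases, the call would presumably need the tsfm_ prefix, e.g.:

            chunk_size=config.tsfm_chunk_size,
            num_left_chunks=config.tsfm_num_left_chunks,
            num_right_chunks=config.tsfm_num_right_chunks,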
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py
CHANGED
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 
-class MultiHeadAttention(nn.Module):
+class MultiHeadSelfAttention(nn.Module):
     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
         """
         :param n_head: int. the number of heads.
@@ -86,14 +86,12 @@ class MultiHeadAttention(nn.Module):
         return self.linear_out(x)  # (batch, time1, n_feat)
 
     def forward(self,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
+                x: torch.Tensor,
                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        q, k, v = self.forward_qkv(query, key, value)
+        q, k, v = self.forward_qkv(x, x, x)
 
         if cache.size(0) > 0:
             key_cache, value_cache = torch.split(
@@ -157,32 +155,40 @@ class RelativeMultiHeadSelfAttention(nn.Module):
     def forward_attention(self,
                           value: torch.Tensor,
                           scores: torch.Tensor,
-                          mask: torch.Tensor =
+                          mask: torch.Tensor = None
                           ) -> torch.Tensor:
         """
         compute attention context vector.
-        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head,
-        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head,
-        :param mask: torch.Tensor. mask. shape=(batch_size, 1,
-
-
-             weighted by the attention score (batch_size, time1, time2).
+        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
+        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
+        :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
+        :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
+             weighted by the attention score (batch_size, query_time_steps, key_time_steps).
         """
         n_batch = value.size(0)
-        if mask
-            mask = mask.unsqueeze(1).eq(0)
-            #
-            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)
+            # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
             scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
         else:
-            attn = torch.softmax(scores, dim=-1)
+            attn = torch.softmax(scores, dim=-1)
+        # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
 
         p_attn = self.dropout(attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
 
-        return self.linear_out(x)  # (batch, time1, n_feat)
+        x = torch.matmul(p_attn, value)
+        # x shape: [batch_size, n_head, query_time_steps, d_k]
+        x = x.transpose(1, 2)
+        # x shape: [batch_size, query_time_steps, n_head, d_k]
+
+        x = x.contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
+        # x shape: [batch_size, query_time_steps, n_head * d_k]
+        # x shape: [batch_size, query_time_steps, n_feat]
+
+        x = self.linear_out(x)
+        # x shape: [batch_size, query_time_steps, n_feat]
+        return x
 
     def relative_position_encoding(self, length: int) -> torch.Tensor:
         """
@@ -197,18 +203,16 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         return final_mat
 
     def forward(self,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                x: torch.Tensor,
+                mask: torch.Tensor = None,
+                cache: torch.Tensor = None
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
         # attention! self attention.
 
-        q, k, v = self.forward_qkv(query, key, value)
-        # q shape: [batch_size, self.h,
+        q, k, v = self.forward_qkv(x, x, x)
+        # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
 
-        if cache.size(0) > 0:
+        if cache is not None:
             key_cache, value_cache = torch.split(
                 cache, cache.size(-1) // 2, dim=-1)
             k = torch.cat([key_cache, k], dim=2)
@@ -217,11 +221,13 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
         new_cache = torch.cat((k, v), dim=-1)
 
+        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+
         # Compute relative position encoding
         q_length, k_length = q.size(2), k.size(2)
         relative_position = self.relative_position_encoding(k_length)
 
-        # During streaming inference, q_length and k_length differ.
         relative_position = relative_position[-q_length:]
 
         relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
@@ -229,11 +235,10 @@ class RelativeMultiHeadSelfAttention(nn.Module):
         relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, q_length, k_length, d_k)
         relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, q_length, k_length, d_k)
 
-        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
-
         relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
         # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
+
+        # score
         scores = native_scores + relative_position_scores
 
         return self.forward_attention(v, scores, mask), new_cache
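The rewritten forward_attention now spells out every intermediate shape in comments. As a standalone illustration of the same masking arithmetic (toy tensors and sizes, not repo code):

    import torch

    n_batch, n_head, q_len, k_len, d_k = 1, 2, 3, 5, 4
    scores = torch.randn(n_batch, n_head, q_len, k_len)    # raw attention scores
    value = torch.randn(n_batch, n_head, k_len, d_k)        # transformed values
    mask = torch.ones(n_batch, q_len, k_len, dtype=torch.bool)
    mask[..., -2:] = False                                   # pretend the last two key steps are invisible

    m = mask.unsqueeze(1).eq(0)                              # [n_batch, 1, q_len, k_len], True where masked
    scores = scores.masked_fill(m, -float("inf"))
    attn = torch.softmax(scores, dim=-1).masked_fill(m, 0.0)
    x = torch.matmul(attn, value)                            # [n_batch, n_head, q_len, d_k]
    x = x.transpose(1, 2).contiguous().view(n_batch, -1, n_head * d_k)  # [n_batch, q_len, n_head * d_k]
    print(x.shape)                                           # torch.Size([1, 3, 8])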
toolbox/torchaudio/models/nx_clean_unet/transformer/mask.py
CHANGED
@@ -25,6 +25,7 @@ def subsequent_chunk_mask(
         size: int,
         chunk_size: int,
         num_left_chunks: int = -1,
+        num_right_chunks: int = 0,
         device: torch.device = torch.device("cpu"),
 ) -> torch.Tensor:
     """
@@ -41,6 +42,7 @@ def subsequent_chunk_mask(
     :param size: int. size of mask.
     :param chunk_size: int. size of chunk.
     :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
+    :param num_right_chunks: int. number of right chunks.
     :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
     :return: torch.Tensor. mask
     """
@@ -51,7 +53,7 @@ def subsequent_chunk_mask(
             start = 0
         else:
             start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
-        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
         ret[i, start:ending] = True
     return ret
 
@@ -59,6 +61,12 @@ def subsequent_chunk_mask(
 def main():
     chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
     print(chunk_mask)
+
+    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
+    print(chunk_mask)
+
+    chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
+    print(chunk_mask)
     return
 
 
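The new num_right_chunks argument simply extends the visible window by a whole number of chunks to the right. A standalone re-implementation of the updated rule (for num_left_chunks >= 0; it mirrors the diff rather than importing the repo) makes the resulting look-ahead visible:

    import torch

    def chunk_mask(size: int, chunk_size: int, num_left_chunks: int, num_right_chunks: int) -> torch.Tensor:
        ret = torch.zeros(size, size, dtype=torch.bool)
        for i in range(size):
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
            ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
            ret[i, start:ending] = True
        return ret

    print(chunk_mask(8, 2, 2, 1).int())
    # Row 0 (chunk 0) attends to columns 0-3: its own chunk plus one chunk of look-ahead.
    # Row 6 (chunk 3) attends to columns 2-7: two left chunks, its own chunk, and nothing beyond the edge.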
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
-from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import
+from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
 
 
 class PositionwiseFeedForward(nn.Module):
@@ -87,7 +87,7 @@ class TransformerEncoderLayer(nn.Module):
         xt = self.norm1(x)
 
         x_att, new_att_cache = self.attention.forward(
-            xt,
+            xt, mask=mask, cache=attention_cache
         )
         x = x + self.dropout1(xt)
         xt = self.norm2(x)
@@ -112,6 +112,7 @@ class TransformerEncoder(nn.Module):
         max_relative_position: int = 1024,
         chunk_size: int = 1,
         num_left_chunks: int = 128,
+        num_right_chunks: int = 2,
     ):
         super().__init__()
         self.input_size = input_size
@@ -120,6 +121,7 @@ class TransformerEncoder(nn.Module):
         self.max_relative_position = max_relative_position
         self.chunk_size = chunk_size
         self.num_left_chunks = num_left_chunks
+        self.num_right_chunks = num_right_chunks
 
         self.input_linear = nn.Linear(
             in_features=self.input_size,
@@ -155,7 +157,8 @@ class TransformerEncoder(nn.Module):
         chunk_masks = subsequent_chunk_mask(
             size=time_steps,
             chunk_size=self.chunk_size,
-            num_left_chunks=self.num_left_chunks
+            num_left_chunks=self.num_left_chunks,
+            num_right_chunks=self.num_right_chunks,
         )
         chunk_masks = chunk_masks.to(xs.device)
         # chunk_masks shape: [1, time_steps, time_steps]
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -10,23 +10,24 @@ hop_size: 80
 # For example, 2**5=32 means that 32 samples collapse into one time step after down-sampling,
 # so one step is 32/sample_rate = 0.004 seconds.
 # Then tsfm_chunk_size=4 corresponds to 16ms and tsfm_chunk_size=8 to 32ms.
-# 假设每次向左看1
-# tsfm_chunk_size=1, tsfm_num_left_chunks
-# tsfm_chunk_size=4, tsfm_num_left_chunks
-# tsfm_chunk_size=8, tsfm_num_left_chunks
+# Assume each step looks 1 second to the left and 30ms to the right, then:
+# tsfm_chunk_size=1, tsfm_num_left_chunks=256, tsfm_num_right_chunks=8
+# tsfm_chunk_size=4, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
+# tsfm_chunk_size=8, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
 down_sampling_num_layers: 5
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
-tsfm_hidden_size:
-tsfm_attention_heads:
+tsfm_hidden_size: 256
+tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length:
+tsfm_max_length: 512
 tsfm_chunk_size: 4
 tsfm_num_left_chunks: 64
+tsfm_num_right_chunks: 2
 
 discriminator_dim: 32
 discriminator_in_channel: 2
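A quick check of the numbers in the new comment (sample_rate = 8000 follows from the existing "32/sample_rate = 0.004 s" line; the rest is plain arithmetic):

    step_ms = 32 / 8000 * 1000        # 4.0 ms per down-sampled time step (2**5 = 32 samples)
    chunk_ms = 4 * step_ms            # tsfm_chunk_size: 4        -> 16 ms per chunk
    left_ms = 64 * chunk_ms           # tsfm_num_left_chunks: 64  -> 1024 ms of left context
    right_ms = 2 * chunk_ms           # tsfm_num_right_chunks: 2  -> 32 ms of look-ahead
    print(step_ms, chunk_ms, left_ms, right_ms)   # 4.0 16.0 1024.0 32.0

The other two rows work out the same way (256 x 4 ms and 8 x 4 ms, 32 x 32 ms and 1 x 32 ms), matching the "1 second left, about 30 ms right" rule of thumb.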