Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -368,6 +368,7 @@ class Encoder(nn.Module):
|
|
368 |
kernel_size=config.conv_kernel_size_input,
|
369 |
bias=False,
|
370 |
separable=True,
|
|
|
371 |
)
|
372 |
self.df_conv1 = CausalConv2d(
|
373 |
in_channels=config.conv_channels,
|
@@ -386,11 +387,11 @@ class Encoder(nn.Module):
|
|
386 |
nn.ReLU(inplace=True)
|
387 |
)
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
|
395 |
# emb_gru
|
396 |
if config.spec_bins % 8 != 0:
|
@@ -430,18 +431,18 @@ class Encoder(nn.Module):
|
|
430 |
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
431 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
432 |
|
433 |
-
#
|
434 |
-
|
435 |
-
|
436 |
-
#
|
437 |
-
#
|
438 |
-
|
439 |
-
|
440 |
-
#
|
441 |
-
|
442 |
-
#
|
443 |
-
|
444 |
-
#
|
445 |
|
446 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
447 |
emb = e3.permute(0, 2, 3, 1)
|
@@ -449,7 +450,7 @@ class Encoder(nn.Module):
|
|
449 |
emb = emb.flatten(2)
|
450 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
451 |
|
452 |
-
|
453 |
# # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
|
454 |
# # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
455 |
|
|
|
368 |
kernel_size=config.conv_kernel_size_input,
|
369 |
bias=False,
|
370 |
separable=True,
|
371 |
+
fstride=1,
|
372 |
)
|
373 |
self.df_conv1 = CausalConv2d(
|
374 |
in_channels=config.conv_channels,
|
|
|
387 |
nn.ReLU(inplace=True)
|
388 |
)
|
389 |
|
390 |
+
if config.encoder_combine_op == "concat":
|
391 |
+
self.embedding_input_size *= 2
|
392 |
+
self.combine = Concat()
|
393 |
+
else:
|
394 |
+
self.combine = Add()
|
395 |
|
396 |
# emb_gru
|
397 |
if config.spec_bins % 8 != 0:
|
|
|
431 |
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
432 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
433 |
|
434 |
+
# feat_spec, shape: (batch_size, 2, time_steps, df_bins)
|
435 |
+
c0 = self.df_conv0(feat_spec)
|
436 |
+
c1 = self.df_conv1(c0)
|
437 |
+
# c0 shape: [batch_size, channels, time_steps, df_bins]
|
438 |
+
# c1 shape: [batch_size, channels, time_steps, df_bins // 2]
|
439 |
+
|
440 |
+
cemb = c1.permute(0, 2, 3, 1)
|
441 |
+
# cemb shape: [batch_size, time_steps, df_bins // 2, channels]
|
442 |
+
cemb = cemb.flatten(2)
|
443 |
+
# cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
|
444 |
+
cemb = self.df_fc_emb(cemb)
|
445 |
+
# cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
446 |
|
447 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
448 |
emb = e3.permute(0, 2, 3, 1)
|
|
|
450 |
emb = emb.flatten(2)
|
451 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
452 |
|
453 |
+
emb = self.combine(emb, cemb)
|
454 |
# # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
|
455 |
# # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
456 |
|