Commit: update
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -387,11 +387,11 @@ class Encoder(nn.Module):
             nn.ReLU(inplace=True)
         )
 
-
-
-
-
-
+        if config.encoder_combine_op == "concat":
+            self.embedding_input_size *= 2
+            self.combine = Concat()
+        else:
+            self.combine = Add()
 
         # emb_gru
         if config.spec_bins % 8 != 0:
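The new branch in __init__ above assumes Concat and Add combine modules; they are not part of this hunk, so the following is only a minimal sketch of what such modules would look like in modeling_spectrum_dfnet.py, consistent with the later call self.combine(emb, cemb) and with the shape comments in the forward hunks below (concatenation along the last dimension doubles the feature width, addition leaves it unchanged):

import torch
import torch.nn as nn


class Concat(nn.Module):
    # Hypothetical combine op: join the two embeddings along the feature (last) dim.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # [b, t, f] and [b, t, f] -> [b, t, 2 * f]
        return torch.cat([x, y], dim=-1)


class Add(nn.Module):
    # Hypothetical combine op: element-wise sum, feature width unchanged.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # [b, t, f] and [b, t, f] -> [b, t, f]
        return x + y

Doubling self.embedding_input_size only in the "concat" branch is consistent with this: the emb_gru set up just below presumably consumes features of that width, which doubles when the two embeddings are concatenated rather than summed.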
@@ -431,18 +431,18 @@ class Encoder(nn.Module):
         # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
         # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
 
-        #
-
-
-        #
-        #
-
-
-        #
-
-        #
-
-        #
+        # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
+        c0 = self.df_conv0(feat_spec)
+        c1 = self.df_conv1(c0)
+        # c0 shape: [batch_size, channels, time_steps, df_bins]
+        # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
+
+        cemb = c1.permute(0, 2, 3, 1)
+        # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
+        cemb = cemb.flatten(2)
+        # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
+        cemb = self.df_fc_emb(cemb)
+        # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
         # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
         emb = e3.permute(0, 2, 3, 1)
@@ -450,7 +450,7 @@ class Encoder(nn.Module):
         emb = emb.flatten(2)
         # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
-
+        emb = self.combine(emb, cemb)
         # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
         # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
@@ -461,8 +461,7 @@ class Encoder(nn.Module):
         lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
         # lsnr shape: [batch_size, time_steps, 1]
 
-
-        return e0, e1, e2, e3, emb, lsnr, h
+        return e0, e1, e2, e3, emb, c0, lsnr, h
 
 
 class Decoder(nn.Module):
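As a quick sanity check of the shape comments in the forward hunks above, here is a standalone sketch of the new cemb branch with dummy tensors. The convolutions and the linear layer are stand-ins for df_conv0, df_conv1 and df_fc_emb (which are defined in the model's __init__, outside this diff), and the concrete sizes (batch 2, 16 channels, 100 frames, 96 df_bins, 256 spec_bins) are illustrative only:

import torch
import torch.nn as nn

batch_size, channels, time_steps, df_bins, spec_bins = 2, 16, 100, 96, 256

# Stand-ins: the first conv keeps df_bins, the second halves it, the linear maps to spec_bins // 4 * channels.
df_conv0 = nn.Conv2d(2, channels, kernel_size=1)
df_conv1 = nn.Conv2d(channels, channels, kernel_size=(1, 2), stride=(1, 2))
df_fc_emb = nn.Linear(df_bins // 2 * channels, spec_bins // 4 * channels)

feat_spec = torch.randn(batch_size, 2, time_steps, df_bins)
c0 = df_conv0(feat_spec)       # [2, 16, 100, 96]  -> [batch_size, channels, time_steps, df_bins]
c1 = df_conv1(c0)              # [2, 16, 100, 48]  -> [batch_size, channels, time_steps, df_bins // 2]
cemb = c1.permute(0, 2, 3, 1)  # [2, 100, 48, 16]  -> [batch_size, time_steps, df_bins // 2, channels]
cemb = cemb.flatten(2)         # [2, 100, 768]     -> [batch_size, time_steps, df_bins // 2 * channels]
cemb = df_fc_emb(cemb)         # [2, 100, 1024]    -> [batch_size, time_steps, spec_bins // 4 * channels]
print(c0.shape, c1.shape, cemb.shape)

The final width matches the spec_dim // 4 * channels width of emb, which is what allows the Add variant of self.combine to sum the two embeddings directly, while the Concat variant doubles the width as noted in the "if concat" comment.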