Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -386,11 +386,11 @@ class Encoder(nn.Module):
|
|
386 |
nn.ReLU(inplace=True)
|
387 |
)
|
388 |
|
389 |
-
if config.encoder_combine_op == "concat":
|
390 |
-
self.embedding_input_size *= 2
|
391 |
-
self.combine = Concat()
|
392 |
-
else:
|
393 |
-
self.combine = Add()
|
394 |
|
395 |
# emb_gru
|
396 |
if config.spec_bins % 8 != 0:
|
@@ -430,18 +430,18 @@ class Encoder(nn.Module):
|
|
430 |
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
431 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
432 |
|
433 |
-
# feat_spec, shape: (batch_size, 2, time_steps, df_bins)
|
434 |
-
c0 = self.df_conv0(feat_spec)
|
435 |
-
c1 = self.df_conv1(c0)
|
436 |
-
# c0 shape: [batch_size, channels, time_steps, df_bins]
|
437 |
-
# c1 shape: [batch_size, channels, time_steps, df_bins // 2]
|
438 |
-
|
439 |
-
cemb = c1.permute(0, 2, 3, 1)
|
440 |
-
# cemb shape: [batch_size, time_steps, df_bins // 2, channels]
|
441 |
-
cemb = cemb.flatten(2)
|
442 |
-
# cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
|
443 |
-
cemb = self.df_fc_emb(cemb)
|
444 |
-
# cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
445 |
|
446 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
447 |
emb = e3.permute(0, 2, 3, 1)
|
@@ -449,9 +449,9 @@ class Encoder(nn.Module):
|
|
449 |
emb = emb.flatten(2)
|
450 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
451 |
|
452 |
-
emb = self.combine(emb, cemb)
|
453 |
-
# if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
|
454 |
-
# if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
455 |
|
456 |
emb, h = self.emb_gru.forward(emb, hidden_state)
|
457 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
@@ -460,7 +460,8 @@ class Encoder(nn.Module):
|
|
460 |
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
461 |
# lsnr shape: [batch_size, time_steps, 1]
|
462 |
|
463 |
-
return e0, e1, e2, e3, emb, c0, lsnr, h
|
|
|
464 |
|
465 |
|
466 |
class Decoder(nn.Module):
|
@@ -828,7 +829,8 @@ class SpectrumDfNet(nn.Module):
|
|
828 |
# # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
|
829 |
# spec = spec.detach()
|
830 |
|
831 |
-
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
|
|
|
832 |
|
833 |
mask = self.decoder.forward(emb, e3, e2, e1, e0)
|
834 |
# mask shape: [batch_size, 1, time_steps, spec_bins]
|
@@ -926,7 +928,7 @@ def main():
|
|
926 |
spec_complex = spec_complex[:, :-1, :]
|
927 |
|
928 |
output = model.forward(spec_complex)
|
929 |
-
print(output[1].shape)
|
930 |
return
|
931 |
|
932 |
|
|
|
386 |
nn.ReLU(inplace=True)
|
387 |
)
|
388 |
|
389 |
+
# if config.encoder_combine_op == "concat":
|
390 |
+
# self.embedding_input_size *= 2
|
391 |
+
# self.combine = Concat()
|
392 |
+
# else:
|
393 |
+
# self.combine = Add()
|
394 |
|
395 |
# emb_gru
|
396 |
if config.spec_bins % 8 != 0:
|
|
|
430 |
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
431 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
432 |
|
433 |
+
# # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
|
434 |
+
# c0 = self.df_conv0(feat_spec)
|
435 |
+
# c1 = self.df_conv1(c0)
|
436 |
+
# # c0 shape: [batch_size, channels, time_steps, df_bins]
|
437 |
+
# # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
|
438 |
+
#
|
439 |
+
# cemb = c1.permute(0, 2, 3, 1)
|
440 |
+
# # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
|
441 |
+
# cemb = cemb.flatten(2)
|
442 |
+
# # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
|
443 |
+
# cemb = self.df_fc_emb(cemb)
|
444 |
+
# # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
445 |
|
446 |
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
447 |
emb = e3.permute(0, 2, 3, 1)
|
|
|
449 |
emb = emb.flatten(2)
|
450 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
451 |
|
452 |
+
# emb = self.combine(emb, cemb)
|
453 |
+
# # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
|
454 |
+
# # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
455 |
|
456 |
emb, h = self.emb_gru.forward(emb, hidden_state)
|
457 |
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
|
|
460 |
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
461 |
# lsnr shape: [batch_size, time_steps, 1]
|
462 |
|
463 |
+
# return e0, e1, e2, e3, emb, c0, lsnr, h
|
464 |
+
return e0, e1, e2, e3, emb, lsnr, h
|
465 |
|
466 |
|
467 |
class Decoder(nn.Module):
|
|
|
829 |
# # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
|
830 |
# spec = spec.detach()
|
831 |
|
832 |
+
# e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
|
833 |
+
e0, e1, e2, e3, emb, lsnr, h = self.encoder.forward(feat_power, feat_spec)
|
834 |
|
835 |
mask = self.decoder.forward(emb, e3, e2, e1, e0)
|
836 |
# mask shape: [batch_size, 1, time_steps, spec_bins]
|
|
|
928 |
spec_complex = spec_complex[:, :-1, :]
|
929 |
|
930 |
output = model.forward(spec_complex)
|
931 |
+
print(output[1].shape)
|
932 |
return
|
933 |
|
934 |
|