HoneyTian committed on
Commit
91e3fb3
·
1 Parent(s): decba93
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -386,11 +386,11 @@ class Encoder(nn.Module):
386
  nn.ReLU(inplace=True)
387
  )
388
 
389
- if config.encoder_combine_op == "concat":
390
- self.embedding_input_size *= 2
391
- self.combine = Concat()
392
- else:
393
- self.combine = Add()
394
 
395
  # emb_gru
396
  if config.spec_bins % 8 != 0:
@@ -430,18 +430,18 @@ class Encoder(nn.Module):
430
  # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
431
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
432
 
433
- # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
434
- c0 = self.df_conv0(feat_spec)
435
- c1 = self.df_conv1(c0)
436
- # c0 shape: [batch_size, channels, time_steps, df_bins]
437
- # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
438
-
439
- cemb = c1.permute(0, 2, 3, 1)
440
- # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
441
- cemb = cemb.flatten(2)
442
- # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
443
- cemb = self.df_fc_emb(cemb)
444
- # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
445
 
446
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
447
  emb = e3.permute(0, 2, 3, 1)
@@ -449,9 +449,9 @@ class Encoder(nn.Module):
449
  emb = emb.flatten(2)
450
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
451
 
452
- emb = self.combine(emb, cemb)
453
- # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
454
- # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
455
 
456
  emb, h = self.emb_gru.forward(emb, hidden_state)
457
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
@@ -460,7 +460,8 @@ class Encoder(nn.Module):
460
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
461
  # lsnr shape: [batch_size, time_steps, 1]
462
 
463
- return e0, e1, e2, e3, emb, c0, lsnr, h
 
464
 
465
 
466
  class Decoder(nn.Module):
@@ -828,7 +829,8 @@ class SpectrumDfNet(nn.Module):
828
  # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
829
  # spec = spec.detach()
830
 
831
- e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
832
 
833
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
834
  # mask shape: [batch_size, 1, time_steps, spec_bins]
@@ -926,7 +928,7 @@ def main():
926
  spec_complex = spec_complex[:, :-1, :]
927
 
928
  output = model.forward(spec_complex)
929
- print(output[0].shape)
930
  return
931
 
932
 
 
386
  nn.ReLU(inplace=True)
387
  )
388
 
389
+ # if config.encoder_combine_op == "concat":
390
+ # self.embedding_input_size *= 2
391
+ # self.combine = Concat()
392
+ # else:
393
+ # self.combine = Add()
394
 
395
  # emb_gru
396
  if config.spec_bins % 8 != 0:
 
430
  # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
431
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
432
 
433
+ # # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
434
+ # c0 = self.df_conv0(feat_spec)
435
+ # c1 = self.df_conv1(c0)
436
+ # # c0 shape: [batch_size, channels, time_steps, df_bins]
437
+ # # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
438
+ #
439
+ # cemb = c1.permute(0, 2, 3, 1)
440
+ # # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
441
+ # cemb = cemb.flatten(2)
442
+ # # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
443
+ # cemb = self.df_fc_emb(cemb)
444
+ # # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
445
 
446
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
447
  emb = e3.permute(0, 2, 3, 1)
 
449
  emb = emb.flatten(2)
450
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
451
 
452
+ # emb = self.combine(emb, cemb)
453
+ # # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
454
+ # # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
455
 
456
  emb, h = self.emb_gru.forward(emb, hidden_state)
457
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
460
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
461
  # lsnr shape: [batch_size, time_steps, 1]
462
 
463
+ # return e0, e1, e2, e3, emb, c0, lsnr, h
464
+ return e0, e1, e2, e3, emb, lsnr, h
465
 
466
 
467
  class Decoder(nn.Module):
 
829
  # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
830
  # spec = spec.detach()
831
 
832
+ # e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
833
+ e0, e1, e2, e3, emb, lsnr, h = self.encoder.forward(feat_power, feat_spec)
834
 
835
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
836
  # mask shape: [batch_size, 1, time_steps, spec_bins]
 
928
  spec_complex = spec_complex[:, :-1, :]
929
 
930
  output = model.forward(spec_complex)
931
+ print(output[1].shape)
932
  return
933
 
934