HoneyTian committed
Commit 3a9b74d · 1 Parent(s): 43453fa
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -387,11 +387,11 @@ class Encoder(nn.Module):
             nn.ReLU(inplace=True)
         )
 
-        # if config.encoder_combine_op == "concat":
-        #     self.embedding_input_size *= 2
-        #     self.combine = Concat()
-        # else:
-        #     self.combine = Add()
+        if config.encoder_combine_op == "concat":
+            self.embedding_input_size *= 2
+            self.combine = Concat()
+        else:
+            self.combine = Add()
 
         # emb_gru
         if config.spec_bins % 8 != 0:
@@ -431,18 +431,18 @@ class Encoder(nn.Module):
         # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
         # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
 
-        # # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
-        # c0 = self.df_conv0(feat_spec)
-        # c1 = self.df_conv1(c0)
-        # # c0 shape: [batch_size, channels, time_steps, df_bins]
-        # # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
-        #
-        # cemb = c1.permute(0, 2, 3, 1)
-        # # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
-        # cemb = cemb.flatten(2)
-        # # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
-        # cemb = self.df_fc_emb(cemb)
-        # # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
+        # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
+        c0 = self.df_conv0(feat_spec)
+        c1 = self.df_conv1(c0)
+        # c0 shape: [batch_size, channels, time_steps, df_bins]
+        # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
+
+        cemb = c1.permute(0, 2, 3, 1)
+        # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
+        cemb = cemb.flatten(2)
+        # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
+        cemb = self.df_fc_emb(cemb)
+        # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
         # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
         emb = e3.permute(0, 2, 3, 1)
@@ -450,7 +450,7 @@ class Encoder(nn.Module):
         emb = emb.flatten(2)
         # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
-        # emb = self.combine(emb, cemb)
+        emb = self.combine(emb, cemb)
         # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
         # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
@@ -461,8 +461,7 @@ class Encoder(nn.Module):
         lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
         # lsnr shape: [batch_size, time_steps, 1]
 
-        # return e0, e1, e2, e3, emb, c0, lsnr, h
-        return e0, e1, e2, e3, emb, lsnr, h
+        return e0, e1, e2, e3, emb, c0, lsnr, h
 
 
 class Decoder(nn.Module):
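
The re-enabled combine step relies on Concat and Add helper modules defined elsewhere in modeling_spectrum_dfnet.py, and on df_conv0/df_conv1 layers built in the Encoder's __init__. The sketch below is a minimal, hedged illustration of the behaviour implied by the shape comments in the diff: concatenation along the last dimension doubles the embedding width (hence embedding_input_size *= 2), addition leaves it unchanged, and df_conv1 halves df_bins. The convolution hyperparameters (kernel size, stride, channel count) and the tensor sizes are illustrative assumptions, not the values used in the repository.

import torch
import torch.nn as nn


class Concat(nn.Module):
    """Combine two embeddings by concatenating along the feature (last) dimension."""

    def forward(self, emb: torch.Tensor, cemb: torch.Tensor) -> torch.Tensor:
        # [batch_size, time_steps, d] x 2 -> [batch_size, time_steps, 2 * d]
        return torch.cat([emb, cemb], dim=-1)


class Add(nn.Module):
    """Combine two embeddings element-wise; the output keeps the input shape."""

    def forward(self, emb: torch.Tensor, cemb: torch.Tensor) -> torch.Tensor:
        return emb + cemb


if __name__ == "__main__":
    # Dummy shape check of the re-enabled deep-filtering branch.
    # Hypothetical sizes; df_conv0 is assumed to keep df_bins and df_conv1 to
    # halve it via a stride of 2 on the bin axis.
    batch_size, time_steps, df_bins, channels = 2, 10, 96, 16
    feat_spec = torch.randn(batch_size, 2, time_steps, df_bins)
    df_conv0 = nn.Conv2d(2, channels, kernel_size=(1, 3), padding=(0, 1))
    df_conv1 = nn.Conv2d(channels, channels, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))

    c0 = df_conv0(feat_spec)
    c1 = df_conv1(c0)
    print(c0.shape)  # [batch_size, channels, time_steps, df_bins]
    print(c1.shape)  # [batch_size, channels, time_steps, df_bins // 2]

    cemb = c1.permute(0, 2, 3, 1).flatten(2)
    print(cemb.shape)  # [batch_size, time_steps, df_bins // 2 * channels]

    # emb is created here with the same width as cemb only to exercise the
    # combine modules; in the model, cemb first passes through df_fc_emb.
    emb = torch.randn(batch_size, time_steps, cemb.shape[-1])
    print(Concat()(emb, cemb).shape)  # last dimension doubled
    print(Add()(emb, cemb).shape)     # shape unchanged

When "concat" is selected, the GRU that consumes emb sees twice the feature width, which is why the constructor scales embedding_input_size by 2 in that branch.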