HoneyTian commited on
Commit
c99ce11
·
1 Parent(s): 351c010
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -368,6 +368,7 @@ class Encoder(nn.Module):
368
  kernel_size=config.conv_kernel_size_input,
369
  bias=False,
370
  separable=True,
 
371
  )
372
  self.df_conv1 = CausalConv2d(
373
  in_channels=config.conv_channels,
@@ -386,11 +387,11 @@ class Encoder(nn.Module):
386
  nn.ReLU(inplace=True)
387
  )
388
 
389
- # if config.encoder_combine_op == "concat":
390
- # self.embedding_input_size *= 2
391
- # self.combine = Concat()
392
- # else:
393
- # self.combine = Add()
394
 
395
  # emb_gru
396
  if config.spec_bins % 8 != 0:
@@ -430,18 +431,18 @@ class Encoder(nn.Module):
430
  # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
431
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
432
 
433
- # # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
434
- # c0 = self.df_conv0(feat_spec)
435
- # c1 = self.df_conv1(c0)
436
- # # c0 shape: [batch_size, channels, time_steps, df_bins]
437
- # # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
438
- #
439
- # cemb = c1.permute(0, 2, 3, 1)
440
- # # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
441
- # cemb = cemb.flatten(2)
442
- # # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
443
- # cemb = self.df_fc_emb(cemb)
444
- # # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
445
 
446
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
447
  emb = e3.permute(0, 2, 3, 1)
@@ -449,7 +450,7 @@ class Encoder(nn.Module):
449
  emb = emb.flatten(2)
450
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
451
 
452
- # emb = self.combine(emb, cemb)
453
  # # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
454
  # # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
455
 
 
368
  kernel_size=config.conv_kernel_size_input,
369
  bias=False,
370
  separable=True,
371
+ fstride=1,
372
  )
373
  self.df_conv1 = CausalConv2d(
374
  in_channels=config.conv_channels,
 
387
  nn.ReLU(inplace=True)
388
  )
389
 
390
+ if config.encoder_combine_op == "concat":
391
+ self.embedding_input_size *= 2
392
+ self.combine = Concat()
393
+ else:
394
+ self.combine = Add()
395
 
396
  # emb_gru
397
  if config.spec_bins % 8 != 0:
 
431
  # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
432
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
433
 
434
+ # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
435
+ c0 = self.df_conv0(feat_spec)
436
+ c1 = self.df_conv1(c0)
437
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
438
+ # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
439
+
440
+ cemb = c1.permute(0, 2, 3, 1)
441
+ # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
442
+ cemb = cemb.flatten(2)
443
+ # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
444
+ cemb = self.df_fc_emb(cemb)
445
+ # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
446
 
447
  # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
448
  emb = e3.permute(0, 2, 3, 1)
 
450
  emb = emb.flatten(2)
451
  # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
452
 
453
+ emb = self.combine(emb, cemb)
454
  # # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
455
  # # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
456