HoneyTian commited on
Commit
23dc849
·
1 Parent(s): d0b38bd
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -50,8 +50,7 @@ class DownSampling(nn.Module):
50
  super(DownSampling, self).__init__()
51
  self.num_layers = num_layers
52
 
53
- self.down_sampling_block_list = list()
54
-
55
  for idx in range(self.num_layers):
56
  down_sampling_block = DownSamplingBlock(
57
  in_channels=in_channels,
@@ -62,6 +61,8 @@ class DownSampling(nn.Module):
62
  self.down_sampling_block_list.append(down_sampling_block)
63
  in_channels = hidden_channels
64
 
 
 
65
  def forward(self, x: torch.Tensor):
66
  # x shape: [batch_size, channels, num_samples]
67
  for down_sampling_block in self.down_sampling_block_list:
@@ -111,8 +112,7 @@ class UpSampling(nn.Module):
111
  super(UpSampling, self).__init__()
112
  self.num_layers = num_layers
113
 
114
- self.up_sampling_block_list = list()
115
-
116
  for idx in range(self.num_layers-1):
117
  up_sampling_block = UpSamplingBlock(
118
  out_channels=hidden_channels,
@@ -131,6 +131,8 @@ class UpSampling(nn.Module):
131
  do_relu=False,
132
  )
133
  self.up_sampling_block_list.append(up_sampling_block)
 
 
134
 
135
  def forward(self, x: torch.Tensor):
136
  # x shape: [batch_size, channels, num_samples]
 
50
  super(DownSampling, self).__init__()
51
  self.num_layers = num_layers
52
 
53
+ down_sampling_block_list = list()
 
54
  for idx in range(self.num_layers):
55
  down_sampling_block = DownSamplingBlock(
56
  in_channels=in_channels,
 
61
  self.down_sampling_block_list.append(down_sampling_block)
62
  in_channels = hidden_channels
63
 
64
+ self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
65
+
66
  def forward(self, x: torch.Tensor):
67
  # x shape: [batch_size, channels, num_samples]
68
  for down_sampling_block in self.down_sampling_block_list:
 
112
  super(UpSampling, self).__init__()
113
  self.num_layers = num_layers
114
 
115
+ up_sampling_block_list = list()
 
116
  for idx in range(self.num_layers-1):
117
  up_sampling_block = UpSamplingBlock(
118
  out_channels=hidden_channels,
 
131
  do_relu=False,
132
  )
133
  self.up_sampling_block_list.append(up_sampling_block)
134
+ self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
135
+
136
 
137
  def forward(self, x: torch.Tensor):
138
  # x shape: [batch_size, channels, num_samples]