HoneyTian commited on
Commit
8059ca7
·
1 Parent(s): 4297745
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -95,17 +95,11 @@ class DecoderBlock(nn.Module):
95
  def forward(self, inputs: torch.Tensor):
96
  # inputs shape: [batch_size, channel, num_samples]
97
  x = self.conv(inputs)
98
- print(x._version)
99
-
100
  x = self.glu(x)
101
- print(x._version)
102
-
103
  x = self.convt(x)
104
- print(x._version)
105
- # if self.do_relu:
106
- # x = F.relu(x)
107
  # x = self.relu(x)
108
- # print(x._version)
109
 
110
  return x
111
 
@@ -258,7 +252,7 @@ class CleanUNet(nn.Module):
258
  # decoder
259
  for i, upsampling_block in enumerate(self.decoder):
260
  skip_i = skip_connections[i]
261
- x += skip_i[:, :, :x.shape[-1]]
262
  x = upsampling_block(x)
263
 
264
  x = x[:, :, :L] * std
 
95
  def forward(self, inputs: torch.Tensor):
96
  # inputs shape: [batch_size, channel, num_samples]
97
  x = self.conv(inputs)
 
 
98
  x = self.glu(x)
 
 
99
  x = self.convt(x)
100
+ if self.do_relu:
101
+ x = F.relu(x)
 
102
  # x = self.relu(x)
 
103
 
104
  return x
105
 
 
252
  # decoder
253
  for i, upsampling_block in enumerate(self.decoder):
254
  skip_i = skip_connections[i]
255
+ x = x + skip_i[:, :, :x.shape[-1]]
256
  x = upsampling_block(x)
257
 
258
  x = x[:, :, :L] * std