Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py
CHANGED
@@ -95,17 +95,11 @@ class DecoderBlock(nn.Module):
|
|
95 |
def forward(self, inputs: torch.Tensor):
|
96 |
# inputs shape: [batch_size, channel, num_samples]
|
97 |
x = self.conv(inputs)
|
98 |
-
print(x._version)
|
99 |
-
|
100 |
x = self.glu(x)
|
101 |
-
print(x._version)
|
102 |
-
|
103 |
x = self.convt(x)
|
104 |
-
|
105 |
-
|
106 |
-
# x = F.relu(x)
|
107 |
# x = self.relu(x)
|
108 |
-
# print(x._version)
|
109 |
|
110 |
return x
|
111 |
|
@@ -258,7 +252,7 @@ class CleanUNet(nn.Module):
|
|
258 |
# decoder
|
259 |
for i, upsampling_block in enumerate(self.decoder):
|
260 |
skip_i = skip_connections[i]
|
261 |
-
x
|
262 |
x = upsampling_block(x)
|
263 |
|
264 |
x = x[:, :, :L] * std
|
|
|
95 |
def forward(self, inputs: torch.Tensor):
|
96 |
# inputs shape: [batch_size, channel, num_samples]
|
97 |
x = self.conv(inputs)
|
|
|
|
|
98 |
x = self.glu(x)
|
|
|
|
|
99 |
x = self.convt(x)
|
100 |
+
if self.do_relu:
|
101 |
+
x = F.relu(x)
|
|
|
102 |
# x = self.relu(x)
|
|
|
103 |
|
104 |
return x
|
105 |
|
|
|
252 |
# decoder
|
253 |
for i, upsampling_block in enumerate(self.decoder):
|
254 |
skip_i = skip_connections[i]
|
255 |
+
x = x + skip_i[:, :, :x.shape[-1]]
|
256 |
x = upsampling_block(x)
|
257 |
|
258 |
x = x[:, :, :L] * std
|