Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py
CHANGED
@@ -76,34 +76,6 @@ def padding(x, D, K, S):
|
|
76 |
return x
|
77 |
|
78 |
|
79 |
-
class DecoderBlock(nn.Module):
|
80 |
-
def __init__(self,
|
81 |
-
channels_h: int,
|
82 |
-
channels_output: int,
|
83 |
-
kernel_size: int,
|
84 |
-
stride: int,
|
85 |
-
do_relu: bool = True,
|
86 |
-
):
|
87 |
-
super(DecoderBlock, self).__init__()
|
88 |
-
self.do_relu = do_relu
|
89 |
-
|
90 |
-
self.conv = nn.Conv1d(channels_h, channels_h * 2, 1)
|
91 |
-
self.glu = nn.GLU(dim=1)
|
92 |
-
self.convt = nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
|
93 |
-
# self.relu = nn.ReLU()
|
94 |
-
|
95 |
-
def forward(self, inputs: torch.Tensor):
|
96 |
-
# inputs shape: [batch_size, channel, num_samples]
|
97 |
-
x = self.conv(inputs)
|
98 |
-
x = self.glu(x)
|
99 |
-
x = self.convt(x)
|
100 |
-
if self.do_relu:
|
101 |
-
x = F.relu(x)
|
102 |
-
# x = self.relu(x)
|
103 |
-
|
104 |
-
return x
|
105 |
-
|
106 |
-
|
107 |
class CleanUNet(nn.Module):
|
108 |
"""
|
109 |
CleanUNet architecture.
|
@@ -162,39 +134,18 @@ class CleanUNet(nn.Module):
|
|
162 |
|
163 |
if i == 0:
|
164 |
# no relu at end
|
165 |
-
self.decoder.append(
|
166 |
-
channels_h
|
167 |
-
|
168 |
-
kernel_size
|
169 |
-
stride=stride,
|
170 |
-
do_relu=False,
|
171 |
))
|
172 |
else:
|
173 |
-
self.decoder.insert(
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
stride=stride,
|
180 |
-
do_relu=True,
|
181 |
-
)
|
182 |
-
)
|
183 |
-
|
184 |
-
# if i == 0:
|
185 |
-
# # no relu at end
|
186 |
-
# self.decoder.append(nn.Sequential(
|
187 |
-
# nn.Conv1d(channels_h, channels_h * 2, 1),
|
188 |
-
# nn.GLU(dim=1),
|
189 |
-
# nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
|
190 |
-
# ))
|
191 |
-
# else:
|
192 |
-
# self.decoder.insert(0, nn.Sequential(
|
193 |
-
# nn.Conv1d(channels_h, channels_h * 2, 1),
|
194 |
-
# nn.GLU(dim=1),
|
195 |
-
# nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
196 |
-
# # nn.ReLU(inplace=False)
|
197 |
-
# ))
|
198 |
channels_output = channels_h
|
199 |
|
200 |
# double H but keep below max_H
|
|
|
76 |
return x
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
class CleanUNet(nn.Module):
|
80 |
"""
|
81 |
CleanUNet architecture.
|
|
|
134 |
|
135 |
if i == 0:
|
136 |
# no relu at end
|
137 |
+
self.decoder.append(nn.Sequential(
|
138 |
+
nn.Conv1d(channels_h, channels_h * 2, 1),
|
139 |
+
nn.GLU(dim=1),
|
140 |
+
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
|
|
|
|
|
141 |
))
|
142 |
else:
|
143 |
+
self.decoder.insert(0, nn.Sequential(
|
144 |
+
nn.Conv1d(channels_h, channels_h * 2, 1),
|
145 |
+
nn.GLU(dim=1),
|
146 |
+
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
147 |
+
nn.ReLU()
|
148 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
channels_output = channels_h
|
150 |
|
151 |
# double H but keep below max_H
|