HoneyTian commited on
Commit
b06a791
·
1 Parent(s): 8059ca7
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -76,34 +76,6 @@ def padding(x, D, K, S):
76
  return x
77
 
78
 
79
- class DecoderBlock(nn.Module):
80
- def __init__(self,
81
- channels_h: int,
82
- channels_output: int,
83
- kernel_size: int,
84
- stride: int,
85
- do_relu: bool = True,
86
- ):
87
- super(DecoderBlock, self).__init__()
88
- self.do_relu = do_relu
89
-
90
- self.conv = nn.Conv1d(channels_h, channels_h * 2, 1)
91
- self.glu = nn.GLU(dim=1)
92
- self.convt = nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
93
- # self.relu = nn.ReLU()
94
-
95
- def forward(self, inputs: torch.Tensor):
96
- # inputs shape: [batch_size, channel, num_samples]
97
- x = self.conv(inputs)
98
- x = self.glu(x)
99
- x = self.convt(x)
100
- if self.do_relu:
101
- x = F.relu(x)
102
- # x = self.relu(x)
103
-
104
- return x
105
-
106
-
107
  class CleanUNet(nn.Module):
108
  """
109
  CleanUNet architecture.
@@ -162,39 +134,18 @@ class CleanUNet(nn.Module):
162
 
163
  if i == 0:
164
  # no relu at end
165
- self.decoder.append(DecoderBlock(
166
- channels_h=channels_h,
167
- channels_output=channels_output,
168
- kernel_size=kernel_size,
169
- stride=stride,
170
- do_relu=False,
171
  ))
172
  else:
173
- self.decoder.insert(
174
- index=0,
175
- module=DecoderBlock(
176
- channels_h=channels_h,
177
- channels_output=channels_output,
178
- kernel_size=kernel_size,
179
- stride=stride,
180
- do_relu=True,
181
- )
182
- )
183
-
184
- # if i == 0:
185
- # # no relu at end
186
- # self.decoder.append(nn.Sequential(
187
- # nn.Conv1d(channels_h, channels_h * 2, 1),
188
- # nn.GLU(dim=1),
189
- # nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
190
- # ))
191
- # else:
192
- # self.decoder.insert(0, nn.Sequential(
193
- # nn.Conv1d(channels_h, channels_h * 2, 1),
194
- # nn.GLU(dim=1),
195
- # nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
196
- # # nn.ReLU(inplace=False)
197
- # ))
198
  channels_output = channels_h
199
 
200
  # double H but keep below max_H
 
76
  return x
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  class CleanUNet(nn.Module):
80
  """
81
  CleanUNet architecture.
 
134
 
135
  if i == 0:
136
  # no relu at end
137
+ self.decoder.append(nn.Sequential(
138
+ nn.Conv1d(channels_h, channels_h * 2, 1),
139
+ nn.GLU(dim=1),
140
+ nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
 
 
141
  ))
142
  else:
143
+ self.decoder.insert(0, nn.Sequential(
144
+ nn.Conv1d(channels_h, channels_h * 2, 1),
145
+ nn.GLU(dim=1),
146
+ nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
+ nn.ReLU()
148
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  channels_output = channels_h
150
 
151
  # double H but keep below max_H