HoneyTian committed on
Commit
61c260b
·
1 Parent(s): 46c2bb3
examples/clean_unet_aishell/run.sh CHANGED
@@ -14,8 +14,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
14
 
15
  sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
- --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
18
- --max_count 10000
19
 
20
 
21
  END
 
14
 
15
  sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
18
 
19
 
20
  END
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -76,6 +76,32 @@ def padding(x, D, K, S):
76
  return x
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  class CleanUNet(nn.Module):
80
  """
81
  CleanUNet architecture.
@@ -134,18 +160,39 @@ class CleanUNet(nn.Module):
134
 
135
  if i == 0:
136
  # no relu at end
137
- self.decoder.append(nn.Sequential(
138
- nn.Conv1d(channels_h, channels_h * 2, 1),
139
- nn.GLU(dim=1),
140
- nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
 
 
141
  ))
142
  else:
143
- self.decoder.insert(0, nn.Sequential(
144
- nn.Conv1d(channels_h, channels_h * 2, 1),
145
- nn.GLU(dim=1),
146
- nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
- # nn.ReLU(inplace=False)
148
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  channels_output = channels_h
150
 
151
  # double H but keep below max_H
 
76
  return x
77
 
78
 
79
class DecoderBlock(nn.Module):
    """
    One CleanUNet decoder stage: 1x1 Conv1d -> GLU -> ConvTranspose1d, with an
    optional trailing ReLU.

    The 1x1 convolution doubles the channel count so that the GLU (which halves
    channels along dim=1) restores ``channels_h`` before the transposed
    convolution upsamples to ``channels_output``.

    :param channels_h: number of input (hidden) channels.
    :param channels_output: number of output channels after upsampling.
    :param kernel_size: kernel size of the transposed convolution.
    :param stride: stride (upsampling factor) of the transposed convolution.
    :param do_relu: whether to apply a final ReLU; the last decoder stage
        (the first one appended, ``i == 0`` in CleanUNet) disables it.
    """
    def __init__(self,
                 channels_h: int,
                 channels_output: int,
                 kernel_size: int,
                 stride: int,
                 do_relu: bool = True,
                 ):
        super(DecoderBlock, self).__init__()
        self.do_relu = do_relu

        self.conv = nn.Conv1d(channels_h, channels_h * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.convt = nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
        self.relu = nn.ReLU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs shape: [batch_size, channels_h, num_samples]
        x = self.conv(inputs)
        # GLU halves the channel dim back to channels_h: a * sigmoid(b)
        x = self.glu(x)
        x = self.convt(x)
        if self.do_relu:
            x = self.relu(x)
        return x
103
+
104
+
105
  class CleanUNet(nn.Module):
106
  """
107
  CleanUNet architecture.
 
160
 
161
  if i == 0:
162
  # no relu at end
163
+ self.decoder.append(DecoderBlock(
164
+ channels_h=channels_h,
165
+ channels_output=channels_output,
166
+ kernel_size=kernel_size,
167
+ stride=stride,
168
+ do_relu=False,
169
  ))
170
  else:
171
+ self.decoder.insert(
172
+ index=0,
173
+ module=DecoderBlock(
174
+ channels_h=channels_h,
175
+ channels_output=channels_output,
176
+ kernel_size=kernel_size,
177
+ stride=stride,
178
+ do_relu=True,
179
+ )
180
+ )
181
+
182
+ # if i == 0:
183
+ # # no relu at end
184
+ # self.decoder.append(nn.Sequential(
185
+ # nn.Conv1d(channels_h, channels_h * 2, 1),
186
+ # nn.GLU(dim=1),
187
+ # nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
188
+ # ))
189
+ # else:
190
+ # self.decoder.insert(0, nn.Sequential(
191
+ # nn.Conv1d(channels_h, channels_h * 2, 1),
192
+ # nn.GLU(dim=1),
193
+ # nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
194
+ # # nn.ReLU(inplace=False)
195
+ # ))
196
  channels_output = channels_h
197
 
198
  # double H but keep below max_H