HoneyTian committed · Commit 55d487a · 1 Parent(s): ce1e2dc
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py CHANGED
@@ -135,7 +135,10 @@ class CausalConv2d(nn.Module):
         return x, new_cache
 
 
-class CausalConvTranspose2d(nn.Module):
+class CausalConvTranspose2dErrorCase(nn.Module):
+    """
+    Incorrect caching method.
+    """
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
@@ -148,7 +151,7 @@ class CausalConvTranspose2d(nn.Module):
                  norm_layer: str = "batch_norm_2d",
                  activation_layer: str = "relu",
                  ):
-        super(CausalConvTranspose2d, self).__init__()
+        super(CausalConvTranspose2dErrorCase, self).__init__()
 
         kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
 
@@ -198,7 +201,7 @@ class CausalConvTranspose2d(nn.Module):
         else:
             self.activation = nn.Identity()
 
-    def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
+    def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
         """
         :param inputs: shape: [b, c, t, f]
         :param cache: shape: [b, c, lookback, f];
@@ -228,6 +231,101 @@ class CausalConvTranspose2d(nn.Module):
         return x, new_cache
 
 
+class CausalConvTranspose2d(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Iterable[int]],
+                 fstride: int = 1,
+                 dilation: int = 1,
+                 pad_f_dim: bool = True,
+                 bias: bool = True,
+                 separable: bool = False,
+                 norm_layer: str = "batch_norm_2d",
+                 activation_layer: str = "relu",
+                 ):
+        super(CausalConvTranspose2d, self).__init__()
+
+        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        if pad_f_dim:
+            fpad = kernel_size[1] // 2
+        else:
+            fpad = 0
+
+        # for the last 2 dims, pad (left, right, top, bottom).
+        self.lookback = kernel_size[0] - 1
+        if self.lookback > 0:
+            self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
+        else:
+            self.tpad = nn.Identity()
+
+        groups = math.gcd(in_channels, out_channels) if separable else 1
+        if groups == 1:
+            separable = False
+
+        self.convt = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size[0] - 1, fpad + dilation - 1),
+            output_padding=(0, fpad),
+            stride=(1, fstride),  # stride over time is always 1
+            dilation=(1, dilation),  # dilation over time is always 1
+            groups=groups,
+            bias=bias,
+        )
+
+        if separable:
+            self.convp = nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=1,
+                bias=False,
+            )
+        else:
+            self.convp = nn.Identity()
+
+        if norm_layer is not None:
+            norm_layer = norm_layer_dict[norm_layer]
+            self.norm = norm_layer(out_channels)
+        else:
+            self.norm = nn.Identity()
+
+        if activation_layer is not None:
+            activation_layer = activation_layer_dict[activation_layer]
+            self.activation = activation_layer()
+        else:
+            self.activation = nn.Identity()
+
+    def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
+        """
+        :param inputs: shape: [b, c, t, f]
+        :param cache: shape: [b, c, lookback, f];
+        :return:
+        """
+        x = inputs
+
+        # x shape: [b, c, t, f]
+        x = self.convt(x)
+        # x shape: [b, c, t - lookback, f]
+
+        if cache is None:
+            x = self.tpad(x)
+        else:
+            x = torch.concat(tensors=[cache, x], dim=2)
+
+        new_cache = None
+        if self.lookback > 0:
+            new_cache = x[:, :, -self.lookback:, :]
+
+        x = self.convp(x)
+        x = self.norm(x)
+        x = self.activation(x)
+
+        return x, new_cache
+
+
 class GroupedLinear(nn.Module):
 
     def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
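A note on the arithmetic behind the new class. With `stride=1` along time and time padding of `kernel_size[0] - 1`, PyTorch's `ConvTranspose2d` returns `(t - 1) - 2*(k - 1) + (k - 1) + 1 = t - (k - 1)` frames, which is why `lookback = kernel_size[0] - 1` frames must be supplied either by the zero pad (first call) or by the cache (subsequent calls). A minimal, self-contained check using plain `torch.nn` (the sizes below are illustrative, not from the commit):

```python
import torch
import torch.nn as nn

b, c, t, f = 1, 4, 100, 32
k = 3                    # kernel size along the time axis
lookback = k - 1         # frames supplied by the zero pad or the cache

convt = nn.ConvTranspose2d(
    in_channels=c,
    out_channels=c,
    kernel_size=(k, 1),
    padding=(k - 1, 0),  # time padding as in CausalConvTranspose2d
    stride=(1, 1),
)

y = convt(torch.randn(b, c, t, f))

# (t - 1) - 2*(k - 1) + (k - 1) + 1 == t - (k - 1)
assert y.shape[2] == t - lookback
print(y.shape)  # torch.Size([1, 4, 98, 32])
```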
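And a sketch of how the streaming interface is meant to be driven: feed the input in chunks and thread `new_cache` from each call into the next, so the zero pad is applied only on the first chunk. This assumes the repo root is on `PYTHONPATH` and that the default `norm_layer`/`activation_layer` registry entries resolve; the chunk length and tensor sizes are illustrative:

```python
import torch

from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import CausalConvTranspose2d

# fstride=2: frequency upsampling; also keeps output_padding below the stride
layer = CausalConvTranspose2d(in_channels=4, out_channels=4, kernel_size=3, fstride=2)
layer.eval()

x = torch.randn(1, 4, 100, 32)  # [b, c, t, f]

with torch.no_grad():
    cache = None
    outputs = []
    for chunk in torch.split(x, 25, dim=2):  # 4 chunks of 25 frames each
        y, cache = layer(chunk, cache=cache)
        outputs.append(y)
    y_stream = torch.concat(outputs, dim=2)

print(y_stream.shape)  # [1, 4, 100, 64]: time length preserved, freq upsampled by fstride
```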