HoneyTian commited on
Commit
873103c
·
1 Parent(s): cf7ad48

add frcrn model

Browse files
examples/frcrn/yaml/config.yaml CHANGED
@@ -16,14 +16,14 @@ seed: 1234
16
 
17
  sample_rate: 8000
18
  segment_size: 32000
19
- nfft: 512
20
- win_size: 512
21
- hop_size: 256
22
  win_type: hann
23
 
24
  use_complex_networks: true
25
- model_depth: 20
26
- model_complexity: 45
27
 
28
  min_snr_db: -10
29
  max_snr_db: 20
 
16
 
17
  sample_rate: 8000
18
  segment_size: 32000
19
+ nfft: 128
20
+ win_size: 128
21
+ hop_size: 64
22
  win_type: hann
23
 
24
  use_complex_networks: true
25
+ model_depth: 10
26
+ model_complexity: -1
27
 
28
  min_snr_db: -10
29
  max_snr_db: 20
toolbox/torchaudio/models/frcrn/unet.py CHANGED
@@ -340,13 +340,22 @@ class UNet(nn.Module):
340
 
341
  def main():
342
  # [batch_size, 1, freq_bins, time_steps, 2]
343
- x = torch.rand(size=(1, 1, 257, 2000, 2))
344
- # x = torch.rand(size=(1, 1, 256, 2000, 2))
345
- # x = torch.rand(size=(1, 1, 255, 2000, 2))
 
 
 
 
 
 
 
 
 
346
  unet = UNet(
347
  in_channels=1,
348
- model_complexity=45,
349
- model_depth=20,
350
  use_complex_networks=True
351
  )
352
  print(unet)
 
340
 
341
  def main():
342
  # [batch_size, 1, freq_bins, time_steps, 2]
343
+ # x = torch.rand(size=(1, 1, 257, 2000, 2))
344
+ # unet = UNet(
345
+ # in_channels=1,
346
+ # model_complexity=45,
347
+ # model_depth=20,
348
+ # use_complex_networks=True
349
+ # )
350
+ # print(unet)
351
+ # result = unet.forward(x)
352
+ # print(result.shape)
353
+
354
+ x = torch.rand(size=(1, 1, 65, 2000, 2))
355
  unet = UNet(
356
  in_channels=1,
357
+ model_complexity=-1,
358
+ model_depth=10,
359
  use_complex_networks=True
360
  )
361
  print(unet)