Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -50,8 +50,7 @@ class DownSampling(nn.Module):
|
|
50 |
super(DownSampling, self).__init__()
|
51 |
self.num_layers = num_layers
|
52 |
|
53 |
-
|
54 |
-
|
55 |
for idx in range(self.num_layers):
|
56 |
down_sampling_block = DownSamplingBlock(
|
57 |
in_channels=in_channels,
|
@@ -62,6 +61,8 @@ class DownSampling(nn.Module):
|
|
62 |
self.down_sampling_block_list.append(down_sampling_block)
|
63 |
in_channels = hidden_channels
|
64 |
|
|
|
|
|
65 |
def forward(self, x: torch.Tensor):
|
66 |
# x shape: [batch_size, channels, num_samples]
|
67 |
for down_sampling_block in self.down_sampling_block_list:
|
@@ -111,8 +112,7 @@ class UpSampling(nn.Module):
|
|
111 |
super(UpSampling, self).__init__()
|
112 |
self.num_layers = num_layers
|
113 |
|
114 |
-
|
115 |
-
|
116 |
for idx in range(self.num_layers-1):
|
117 |
up_sampling_block = UpSamplingBlock(
|
118 |
out_channels=hidden_channels,
|
@@ -131,6 +131,8 @@ class UpSampling(nn.Module):
|
|
131 |
do_relu=False,
|
132 |
)
|
133 |
self.up_sampling_block_list.append(up_sampling_block)
|
|
|
|
|
134 |
|
135 |
def forward(self, x: torch.Tensor):
|
136 |
# x shape: [batch_size, channels, num_samples]
|
|
|
50 |
super(DownSampling, self).__init__()
|
51 |
self.num_layers = num_layers
|
52 |
|
53 |
+
down_sampling_block_list = list()
|
|
|
54 |
for idx in range(self.num_layers):
|
55 |
down_sampling_block = DownSamplingBlock(
|
56 |
in_channels=in_channels,
|
|
|
61 |
self.down_sampling_block_list.append(down_sampling_block)
|
62 |
in_channels = hidden_channels
|
63 |
|
64 |
+
self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
|
65 |
+
|
66 |
def forward(self, x: torch.Tensor):
|
67 |
# x shape: [batch_size, channels, num_samples]
|
68 |
for down_sampling_block in self.down_sampling_block_list:
|
|
|
112 |
super(UpSampling, self).__init__()
|
113 |
self.num_layers = num_layers
|
114 |
|
115 |
+
up_sampling_block_list = list()
|
|
|
116 |
for idx in range(self.num_layers-1):
|
117 |
up_sampling_block = UpSamplingBlock(
|
118 |
out_channels=hidden_channels,
|
|
|
131 |
do_relu=False,
|
132 |
)
|
133 |
self.up_sampling_block_list.append(up_sampling_block)
|
134 |
+
self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
|
135 |
+
|
136 |
|
137 |
def forward(self, x: torch.Tensor):
|
138 |
# x shape: [batch_size, channels, num_samples]
|