Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -133,7 +133,6 @@ class UpSampling(nn.Module):
|
|
133 |
up_sampling_block_list.append(up_sampling_block)
|
134 |
self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
|
135 |
|
136 |
-
|
137 |
def forward(self, x: torch.Tensor):
|
138 |
# x shape: [batch_size, channels, num_samples]
|
139 |
for up_sampling_block in self.up_sampling_block_list:
|
|
|
133 |
up_sampling_block_list.append(up_sampling_block)
|
134 |
self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
|
135 |
|
|
|
136 |
def forward(self, x: torch.Tensor):
|
137 |
# x shape: [batch_size, channels, num_samples]
|
138 |
for up_sampling_block in self.up_sampling_block_list:
|
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -467,6 +467,7 @@ class TransformerEncoder(nn.Module):
|
|
467 |
chunk_size=self.chunk_size,
|
468 |
num_left_chunks=self.num_left_chunks
|
469 |
)
|
|
|
470 |
# chunk_masks shape: [1, time_steps, time_steps]
|
471 |
chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
|
472 |
# chunk_masks shape: [batch_size, time_steps, time_steps]
|
|
|
467 |
chunk_size=self.chunk_size,
|
468 |
num_left_chunks=self.num_left_chunks
|
469 |
)
|
470 |
+
chunk_masks = chunk_masks.to(xs.device)
|
471 |
# chunk_masks shape: [1, time_steps, time_steps]
|
472 |
chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
|
473 |
# chunk_masks shape: [batch_size, time_steps, time_steps]
|