HoneyTian commited on
Commit
5055ff3
·
1 Parent(s): f91bf4a
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -133,7 +133,6 @@ class UpSampling(nn.Module):
133
  up_sampling_block_list.append(up_sampling_block)
134
  self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
135
 
136
-
137
  def forward(self, x: torch.Tensor):
138
  # x shape: [batch_size, channels, num_samples]
139
  for up_sampling_block in self.up_sampling_block_list:
 
133
  up_sampling_block_list.append(up_sampling_block)
134
  self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
135
 
 
136
  def forward(self, x: torch.Tensor):
137
  # x shape: [batch_size, channels, num_samples]
138
  for up_sampling_block in self.up_sampling_block_list:
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -467,6 +467,7 @@ class TransformerEncoder(nn.Module):
467
  chunk_size=self.chunk_size,
468
  num_left_chunks=self.num_left_chunks
469
  )
 
470
  # chunk_masks shape: [1, time_steps, time_steps]
471
  chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
472
  # chunk_masks shape: [batch_size, time_steps, time_steps]
 
467
  chunk_size=self.chunk_size,
468
  num_left_chunks=self.num_left_chunks
469
  )
470
+ chunk_masks = chunk_masks.to(xs.device)
471
  # chunk_masks shape: [1, time_steps, time_steps]
472
  chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
473
  # chunk_masks shape: [batch_size, time_steps, time_steps]