HoneyTian commited on
Commit
6713e7b
·
1 Parent(s): 1e78a70
toolbox/torchaudio/models/mpnet/inference_mpnet.py CHANGED
@@ -59,10 +59,9 @@ class InferenceMPNet(object):
59
  noisy_audio = noisy_audio.unsqueeze(dim=0)
60
 
61
  # noisy_audio shape: [batch_size, n_samples]
62
- noisy_audio = self.enhancement_by_tensor(noisy_audio)
63
- noisy_audio = noisy_audio[0]
64
-
65
- return noisy_audio.cpu().numpy()
66
 
67
  def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
68
  if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
@@ -82,6 +81,7 @@ class InferenceMPNet(object):
82
  )
83
  enhanced_audio = audio_g.detach()
84
 
 
85
  return enhanced_audio
86
 
87
  def main():
 
59
  noisy_audio = noisy_audio.unsqueeze(dim=0)
60
 
61
  # noisy_audio shape: [batch_size, n_samples]
62
+ enhanced_audio = self.enhancement_by_tensor(noisy_audio)
63
+ # noisy_audio shape: [n_samples,]
64
+ return enhanced_audio.cpu().numpy()
 
65
 
66
  def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
67
  if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
 
81
  )
82
  enhanced_audio = audio_g.detach()
83
 
84
+ enhanced_audio = enhanced_audio[0]
85
  return enhanced_audio
86
 
87
  def main():