lpw commited on
Commit
0ea264d
·
1 Parent(s): e6f99c2

Update audio_pipe.py

Browse files
Files changed (1) hide show
  1. audio_pipe.py +2 -2
audio_pipe.py CHANGED
@@ -106,7 +106,7 @@ class SpeechToSpeechPipeline():
106
  [self.tts_model], tts_cfg
107
  )
108
 
109
- def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
110
  """
111
  Args:
112
  inputs (:obj:`np.array`):
@@ -121,7 +121,7 @@ class SpeechToSpeechPipeline():
121
  This can be the name of the instruments for audio source separation
122
  or some annotation for speech enhancement. The length must be `C'`.
123
  """
124
- _inputs = torchaudio.load(inputs)
125
  sample, text = None, None
126
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
127
  sample = S2THubInterface.get_model_input(self.task, _inputs)
 
106
  [self.tts_model], tts_cfg
107
  )
108
 
109
+ def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
110
  """
111
  Args:
112
  inputs (:obj:`np.array`):
 
121
  This can be the name of the instruments for audio source separation
122
  or some annotation for speech enhancement. The length must be `C'`.
123
  """
124
+ _inputs = torch.from_numpy(inputs).unsqueeze(0)
125
  sample, text = None, None
126
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
127
  sample = S2THubInterface.get_model_input(self.task, _inputs)