preetam8 commited on
Commit
be03dff
·
1 Parent(s): 594d8c7

small and pipeline

Browse files
Files changed (1) hide show
  1. app.py +22 -22
app.py CHANGED
@@ -13,35 +13,35 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
  target_language = "fr"
14
 
15
  # load speech translation checkpoint
16
- # asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
17
- whisper_model_name = "openai/whisper-medium"
18
- whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
19
- whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
20
- decoder_ids = whisper_processor.get_decoder_prompt_ids(language=target_language, task="transcribe")
21
 
22
  # load text-to-speech checkpoint
23
  model = VitsModel.from_pretrained("facebook/mms-tts-fra")
24
  tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-fra")
25
 
26
 
27
- # def translate(audio):
28
- # outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": target_language})
29
- # return outputs["text"]
30
-
31
  def translate(audio):
32
- if isinstance(audio, str):
33
- # Account for recorded audio
34
- audio = {
35
- "path": audio,
36
- "sampling_rate": 16_000,
37
- "array": librosa.load(audio, sr=16_000)[0]
38
- }
39
- elif audio["sampling_rate"] != 16_000:
40
- audio["array"] = librosa.resample(audio["array"], audio["sampling_rate"], 16_000)
41
- input_features = whisper_processor(audio["array"], sampling_rate=16000, return_tensors="pt").input_features
42
- predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=decoder_ids)
43
- translated_text = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
44
- return translated_text
 
 
 
 
45
 
46
 
47
  def synthesise(text):
 
13
  target_language = "fr"
14
 
15
  # load speech translation checkpoint
16
+ asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
17
+ # whisper_model_name = "openai/whisper-small"
18
+ # whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
19
+ # whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
20
+ # decoder_ids = whisper_processor.get_decoder_prompt_ids(language=target_language, task="transcribe")
21
 
22
  # load text-to-speech checkpoint
23
  model = VitsModel.from_pretrained("facebook/mms-tts-fra")
24
  tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-fra")
25
 
26
 
 
 
 
 
27
  def translate(audio):
28
+ outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": target_language})
29
+ return outputs["text"]
30
+
31
+ # def translate(audio):
32
+ # if isinstance(audio, str):
33
+ # # Account for recorded audio
34
+ # audio = {
35
+ # "path": audio,
36
+ # "sampling_rate": 16_000,
37
+ # "array": librosa.load(audio, sr=16_000)[0]
38
+ # }
39
+ # elif audio["sampling_rate"] != 16_000:
40
+ # audio["array"] = librosa.resample(audio["array"], audio["sampling_rate"], 16_000)
41
+ # input_features = whisper_processor(audio["array"], sampling_rate=16000, return_tensors="pt").input_features
42
+ # predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=decoder_ids)
43
+ # translated_text = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
44
+ # return translated_text
45
 
46
 
47
  def synthesise(text):