preetam8 commited on
Commit
2e545c4
·
1 Parent(s): d614113

asr model approach

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import torch
6
 
7
  from transformers import VitsModel, VitsTokenizer, pipeline
 
8
 
9
 
10
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -12,16 +13,35 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
  target_language = "fr"
13
 
14
  # load speech translation checkpoint
15
- asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
 
 
 
 
16
 
17
  # load text-to-speech checkpoint
18
  model = VitsModel.from_pretrained("facebook/mms-tts-fra")
19
  tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-fra")
20
 
21
 
 
 
 
 
22
  def translate(audio):
23
- outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": target_language})
24
- return outputs["text"]
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def synthesise(text):
 
5
  import torch
6
 
7
  from transformers import VitsModel, VitsTokenizer, pipeline
8
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
9
 
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
13
  target_language = "fr"
14
 
15
  # load speech translation checkpoint
16
+ # asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
17
+ whisper_model_name = "openai/whisper-base"
18
+ whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
19
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
20
+ decoder_ids = whisper_processor.get_decoder_prompt_ids(language=target_language, task="transcribe")
21
 
22
  # load text-to-speech checkpoint
23
  model = VitsModel.from_pretrained("facebook/mms-tts-fra")
24
  tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-fra")
25
 
26
 
27
+ # def translate(audio):
28
+ # outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": target_language})
29
+ # return outputs["text"]
30
+
31
  def translate(audio):
32
+ if isinstance(audio, str):
33
+ # Account for recorded audio
34
+ audio = {
35
+ "path": audio,
36
+ "sampling_rate": 16_000,
37
+ "array": librosa.load(audio, sr=16_000)[0]
38
+ }
39
+ elif audio["sampling_rate"] != 16_000:
40
+ audio["array"] = librosa.resample(audio["array"], audio["sampling_rate"], 16_000)
41
+ input_features = whisper_processor(audio["array"], sampling_rate=16000, return_tensors="pt").input_features
42
+ predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=decoder_ids)
43
+ translated_text = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
44
+ return translated_text
45
 
46
 
47
  def synthesise(text):