Barani1-t committed on
Commit
9caff82
·
1 Parent(s): d7e755b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -23
app.py CHANGED
@@ -11,37 +11,25 @@ target_dtype = np.int16
11
  max_range = np.iinfo(target_dtype).max
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
- # load speech translation checkpoint
15
-
16
- asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
17
-
18
- # load text-to-speech checkpoint and speaker embeddings
19
- processor = SpeechT5Processor.from_pretrained("sanchit-gandhi/speecht5_tts_vox_nl")
20
-
21
- model = SpeechT5ForTextToSpeech.from_pretrained("sanchit-gandhi/speecht5_tts_vox_nl").to(device)
22
-
23
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
24
-
25
- embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
26
- speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
27
 
28
  model_mms = VitsModel.from_pretrained("facebook/mms-tts-nld")
29
  tokenizer_mms = VitsTokenizer.from_pretrained("facebook/mms-tts-nld")
30
 
31
-
32
- def translate(audio):
33
- outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"language": "nl","task": "transcribe"})
34
- return outputs["text"]
35
 
36
 
37
- def synthesise_speechT5(text):
38
- inputs = processor(text=text, padding='max_length', truncation=True,max_length=600,return_tensors="pt")
39
- print(inputs)
40
- speech = model.generate_speech(inputs["input_ids"].to(device), speaker_embeddings.to(device),vocoder=vocoder)
41
- return speech.cpu()
42
 
43
  def synthesise(text):
44
- inputs = tokenizer_mms(text, return_tensors="pt")
 
45
  input_ids = inputs["input_ids"]
46
  with torch.no_grad():
47
  outputs = model_mms(input_ids)
 
11
  max_range = np.iinfo(target_dtype).max
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
# Text-to-speech: Meta MMS VITS model and tokenizer for Dutch ("nld").
# NOTE(review): kept on CPU here — only the Whisper model is moved to `device`.
model_mms = VitsModel.from_pretrained("facebook/mms-tts-nld")
tokenizer_mms = VitsTokenizer.from_pretrained("facebook/mms-tts-nld")
17
 
18
# Speech recognition: Whisper base checkpoint, processor + model on `device`.
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
# Force decoding as Dutch transcription (not translation to English).
forced_decoder_ids = processor.get_decoder_prompt_ids(language="nl", task="transcribe")
# Sampling rate Whisper's feature extractor expects (16 kHz for Whisper checkpoints).
sampling_rate = processor.feature_extractor.sampling_rate
22
 
23
 
24
def translate(audio):
    """Transcribe Dutch speech to text with Whisper.

    Despite the name, the forced decoder ids configure a *transcription*
    ("nl" / "transcribe"), not a translation.

    Returns the list produced by ``processor.batch_decode`` (one string per
    batch item) — downstream code indexes it with ``[0]``, so the list
    return type is part of the contract.
    """
    features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
    generated_ids = model.generate(features.to(device), forced_decoder_ids=forced_decoder_ids)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
29
 
30
  def synthesise(text):
31
+ print(text)
32
+ inputs = tokenizer_mms(text[0], return_tensors="pt")
33
  input_ids = inputs["input_ids"]
34
  with torch.no_grad():
35
  outputs = model_mms(input_ids)