Kr08 commited on
Commit
a5753ad
·
verified ·
1 Parent(s): 7d9c19a

Update app.py: added language detection module and subsequent forced decoder

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import streamlit as st
3
  import torchaudio as ta
4
 
@@ -43,7 +44,14 @@ submit_button = st.sidebar.button("Submit")
43
  # except sr.RequestError as e:
44
  # return f"Could not request results; {e}"
45
 
46
-
 
 
 
 
 
 
 
47
  if submit_button and uploaded_files is not None:
48
  st.write("Files uploaded successfully!")
49
 
@@ -62,13 +70,16 @@ if submit_button and uploaded_files is not None:
62
 
63
  input_features = processor(resampled_inp[0], sampling_rate=16000, return_tensors='pt').input_features
64
 
65
-
66
-
67
- ## Here Generate specific language!!!
68
- forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="translate")
69
-
 
70
 
71
  if task == "translate":
 
 
72
  predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
73
  else:
74
  predicted_ids = model.generate(input_features)
 
1
  import torch
2
+ import pickle
3
  import streamlit as st
4
  import torchaudio as ta
5
 
 
44
  # except sr.RequestError as e:
45
  # return f"Could not request results; {e}"
46
 
47
+ def detect_language(audio_file):
48
+ whisper_model = whisper.load_model("base")
49
+ mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
50
+ # detect the spoken language
51
+ _, probs = whisper_model.detect_language(mel)
52
+ print(f"Detected language: {max(probs[0], key=probs[0].get)}")
53
+ return max(probs[0], key=probs[0].get)
54
+
55
  if submit_button and uploaded_files is not None:
56
  st.write("Files uploaded successfully!")
57
 
 
70
 
71
  input_features = processor(resampled_inp[0], sampling_rate=16000, return_tensors='pt').input_features
72
 
73
+
74
+ lang = detect_language(input_features)
75
+
76
+ with open('languages.pkl', 'rb') as f:
77
+ lang_dict = pickle.load(f)
78
+ detected_language = lang_dict[lang]
79
 
80
  if task == "translate":
81
+ ## Here Generate specific language!!!
82
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language, task="translate")
83
  predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
84
  else:
85
  predicted_ids = model.generate(input_features)