pyf98 commited on
Commit
bf9aad7
1 Parent(s): 732201b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import librosa
3
+
4
+ from espnet2.bin.s2t_inference import Speech2Text
5
+ from espnet2.bin.s2t_inference_language import Speech2Text as Speech2Lang
6
+
7
+
8
# Model checkpoint on the Hugging Face Hub; both helpers below load from it.
model_name_or_path = "espnet/owsm_v3"
device = "cuda"

# ASR/ST decoder. Initialized with the English category token and beam size 5;
# predict() overwrites both category_id and beam size per request.
speech2text = Speech2Text.from_pretrained(
    model_name_or_path,
    device=device,
    category_sym="<eng>",
    beam_size=5,
)

# Spoken-language identification front-end of the same checkpoint (top-1 only),
# used by predict() when the user selects 'Unknown' as the language.
speech2lang = Speech2Lang.from_pretrained(
    model_name_or_path,
    device=device,
    nbest=1,
)
23
+
24
# Language codes accepted by the model and their display names (parallel lists).
iso_codes = ['abk', 'afr', 'amh', 'ara', 'asm', 'ast', 'aze', 'bak', 'bas', 'bel', 'ben', 'bos', 'bre', 'bul', 'cat', 'ceb', 'ces', 'chv', 'ckb', 'cmn', 'cnh', 'cym', 'dan', 'deu', 'dgd', 'div', 'ell', 'eng', 'epo', 'est', 'eus', 'fas', 'fil', 'fin', 'fra', 'frr', 'ful', 'gle', 'glg', 'grn', 'guj', 'hat', 'hau', 'heb', 'hin', 'hrv', 'hsb', 'hun', 'hye', 'ibo', 'ina', 'ind', 'isl', 'ita', 'jav', 'jpn', 'kab', 'kam', 'kan', 'kat', 'kaz', 'kea', 'khm', 'kin', 'kir', 'kmr', 'kor', 'lao', 'lav', 'lga', 'lin', 'lit', 'ltz', 'lug', 'luo', 'mal', 'mar', 'mas', 'mdf', 'mhr', 'mkd', 'mlt', 'mon', 'mri', 'mrj', 'mya', 'myv', 'nan', 'nep', 'nld', 'nno', 'nob', 'npi', 'nso', 'nya', 'oci', 'ori', 'orm', 'ory', 'pan', 'pol', 'por', 'pus', 'quy', 'roh', 'ron', 'rus', 'sah', 'sat', 'sin', 'skr', 'slk', 'slv', 'sna', 'snd', 'som', 'sot', 'spa', 'srd', 'srp', 'sun', 'swa', 'swe', 'swh', 'tam', 'tat', 'tel', 'tgk', 'tgl', 'tha', 'tig', 'tir', 'tok', 'tpi', 'tsn', 'tuk', 'tur', 'twi', 'uig', 'ukr', 'umb', 'urd', 'uzb', 'vie', 'vot', 'wol', 'xho', 'yor', 'yue', 'zho', 'zul']
lang_names = ['Abkhazian', 'Afrikaans', 'Amharic', 'Arabic', 'Assamese', 'Asturian', 'Azerbaijani', 'Bashkir', 'Basa (Cameroon)', 'Belarusian', 'Bengali', 'Bosnian', 'Breton', 'Bulgarian', 'Catalan', 'Cebuano', 'Czech', 'Chuvash', 'Central Kurdish', 'Mandarin Chinese', 'Hakha Chin', 'Welsh', 'Danish', 'German', 'Dagaari Dioula', 'Dhivehi', 'Modern Greek (1453-)', 'English', 'Esperanto', 'Estonian', 'Basque', 'Persian', 'Filipino', 'Finnish', 'French', 'Northern Frisian', 'Fulah', 'Irish', 'Galician', 'Guarani', 'Gujarati', 'Haitian', 'Hausa', 'Hebrew', 'Hindi', 'Croatian', 'Upper Sorbian', 'Hungarian', 'Armenian', 'Igbo', 'Interlingua (International Auxiliary Language Association)', 'Indonesian', 'Icelandic', 'Italian', 'Javanese', 'Japanese', 'Kabyle', 'Kamba (Kenya)', 'Kannada', 'Georgian', 'Kazakh', 'Kabuverdianu', 'Khmer', 'Kinyarwanda', 'Kirghiz', 'Northern Kurdish', 'Korean', 'Lao', 'Latvian', 'Lungga', 'Lingala', 'Lithuanian', 'Luxembourgish', 'Ganda', 'Luo (Kenya and Tanzania)', 'Malayalam', 'Marathi', 'Masai', 'Moksha', 'Eastern Mari', 'Macedonian', 'Maltese', 'Mongolian', 'Maori', 'Western Mari', 'Burmese', 'Erzya', 'Min Nan Chinese', 'Nepali (macrolanguage)', 'Dutch', 'Norwegian Nynorsk', 'Norwegian Bokmål', 'Nepali (individual language)', 'Pedi', 'Nyanja', 'Occitan (post 1500)', 'Oriya (macrolanguage)', 'Oromo', 'Odia', 'Panjabi', 'Polish', 'Portuguese', 'Pushto', 'Ayacucho Quechua', 'Romansh', 'Romanian', 'Russian', 'Yakut', 'Santali', 'Sinhala', 'Saraiki', 'Slovak', 'Slovenian', 'Shona', 'Sindhi', 'Somali', 'Southern Sotho', 'Spanish', 'Sardinian', 'Serbian', 'Sundanese', 'Swahili (macrolanguage)', 'Swedish', 'Swahili (individual language)', 'Tamil', 'Tatar', 'Telugu', 'Tajik', 'Tagalog', 'Thai', 'Tigre', 'Tigrinya', 'Toki Pona', 'Tok Pisin', 'Tswana', 'Turkmen', 'Turkish', 'Twi', 'Uighur', 'Ukrainian', 'Umbundu', 'Urdu', 'Uzbek', 'Vietnamese', 'Votic', 'Wolof', 'Xhosa', 'Yoruba', 'Yue Chinese', 'Chinese', 'Zulu']

# Supported task tokens and their UI labels (parallel lists).
task_codes = ['asr', 'st_ara', 'st_cat', 'st_ces', 'st_cym', 'st_deu', 'st_eng', 'st_est', 'st_fas', 'st_fra', 'st_ind', 'st_ita', 'st_jpn', 'st_lav', 'st_mon', 'st_nld', 'st_por', 'st_ron', 'st_rus', 'st_slv', 'st_spa', 'st_swe', 'st_tam', 'st_tur', 'st_vie', 'st_zho']
task_names = ['Automatic Speech Recognition', 'Translate to Arabic', 'Translate to Catalan', 'Translate to Czech', 'Translate to Welsh', 'Translate to German', 'Translate to English', 'Translate to Estonian', 'Translate to Persian', 'Translate to French', 'Translate to Indonesian', 'Translate to Italian', 'Translate to Japanese', 'Translate to Latvian', 'Translate to Mongolian', 'Translate to Dutch', 'Translate to Portuguese', 'Translate to Romanian', 'Translate to Russian', 'Translate to Slovenian', 'Translate to Spanish', 'Translate to Swedish', 'Translate to Tamil', 'Translate to Turkish', 'Translate to Vietnamese', 'Translate to Chinese']

# Dropdown label -> token code. 'Unknown' is inserted first so it appears at
# the top of the language dropdown (its code 'none' means auto-detect);
# everything else is alphabetized by display name.
lang2code = {'Unknown': 'none'}
for _lang_name, _lang_code in sorted(zip(lang_names, iso_codes)):
    lang2code[_lang_name] = _lang_code

task2code = dict(sorted(zip(task_names, task_codes)))

# Reverse mapping used to report the detected language back to the UI.
code2lang = {_code: _name for _name, _code in lang2code.items()}
36
+
37
+
38
# Copied from Whisper utils (validation hardened: assert -> ValueError).
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Render a time offset in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Non-negative offset in seconds.
        always_include_hours: Emit the hours field even when it is zero.
        decimal_marker: Separator between seconds and milliseconds.

    Returns:
        The formatted timestamp, e.g. ``01:02.500`` or ``01:01:02.500``.

    Raises:
        ValueError: If ``seconds`` is negative.
    """
    # A plain `assert` would be stripped under `python -O`; validate explicitly.
    if seconds < 0:
        raise ValueError("non-negative timestamp expected")
    milliseconds = round(seconds * 1000.0)

    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    seconds, milliseconds = divmod(milliseconds, 1_000)

    # Hours are omitted for offsets under one hour unless explicitly requested.
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )
58
+
59
+
60
def predict(audio_path, src_lang: str, task: str, beam_size, long_form: bool, text_prev: str):
    """Run language identification plus ASR/ST on one uploaded audio file.

    Args:
        audio_path: Path to the audio file (anything librosa can load).
        src_lang: Display name of the spoken language, or 'Unknown' to detect it.
        task: Display name of the task (ASR or a translation target).
        beam_size: Beam size used for decoding.
        long_form: Decode the whole recording in 30 s segments instead of only
            the first 30 s (falls back to short-form on failure).
        text_prev: Optional text prompt the generation is conditioned on.

    Returns:
        Tuple of (language display name, decoded text).
    """
    # Reconfigure the shared, module-level model objects for this request.
    speech2text.task_id = speech2text.converter.token2id[f'<{task2code[task]}>']
    speech2text.beam_search.beam_size = int(beam_size)

    # Our model is trained on 30s and 16kHz
    _sr = 16000
    _dur = 30
    speech, rate = librosa.load(audio_path, sr=_sr)  # speech has shape (len,); resample to 16k Hz

    # Detect language using the first 30s of speech
    lang_code = lang2code[src_lang]
    if lang_code == 'none':
        lang_code = speech2lang(
            librosa.util.fix_length(speech, size=(_sr * _dur))
        )[0][0].strip()[1:-1]  # [1:-1] strips the '<' '>' of the language token
    speech2text.category_id = speech2text.converter.token2id[f'<{lang_code}>']

    # ASR or ST
    if long_form:  # speech will be padded in decode_long()
        try:
            speech2text.maxlenratio = 0.0
            utts = speech2text.decode_long(
                speech,
                segment_sec=_dur,
                fs=_sr,
                condition_on_prev_text=False,
                init_text=text_prev,
                start_time="<0.00>",
                end_time_threshold="<29.50>",
            )

            text = []
            for t1, t2, res in utts:
                text.append(f"[{format_timestamp(seconds=t1)} --> {format_timestamp(seconds=t2)}] {res}")
            text = '\n'.join(text)

            return code2lang[lang_code], text
        except Exception as e:
            # A bare `except:` here would also swallow KeyboardInterrupt and
            # SystemExit; catch Exception only, and surface the cause while
            # keeping the original best-effort fallback to short-form decoding.
            print(f"Long-form decoding error: {e!r}")
            print("An exception occurred in long-form decoding. Falling back to short-form decoding (only first 30s)")

    # Negative value presumably sets an absolute max output length in tokens
    # (~15 tokens per second, capped at 450) — see espnet s2t_inference.
    speech2text.maxlenratio = -min(450, int((len(speech) / rate) * 15))
    speech = librosa.util.fix_length(speech, size=(_sr * _dur))
    text = speech2text(speech, text_prev)[0][3]

    return code2lang[lang_code], text
105
+
106
+
107
# Gradio UI: audio plus decoding options in; detected language and text out.
_inputs = [
    gr.Audio(type="filepath", label="Speech Input"),
    gr.Dropdown(choices=list(lang2code), value="English", label="Language", info="Language of input speech. Select 'Unknown' (1st option) to detect it automatically."),
    gr.Dropdown(choices=list(task2code), value="Automatic Speech Recognition", label="Task", info="Task to perform on input speech."),
    gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Beam Size", info="Beam size used in beam search."),
    gr.Checkbox(label="Long Form (Experimental)", info="Whether to perform long-form decoding (experimental feature)."),
    gr.Text(label="Text Prompt (Optional)", info="Generation will be conditioned on this prompt if provided"),
]
_outputs = [
    gr.Text(label="Predicted Language", info="Language identification is performed if language is unknown."),
    gr.Text(label="Predicted Text", info="Best hypothesis, without timestamps."),
]
demo = gr.Interface(
    predict,
    inputs=_inputs,
    outputs=_outputs,
)
122
+
123
+
124
if __name__ == "__main__":
    # Hide the auto-generated API docs page; pass debug=True locally if needed.
    demo.launch(show_api=False)