lomit committed
Commit 6ca622d · 1 Parent(s): 3150b5b

Update app.py

Files changed (1)
  1. app.py +141 -0
app.py CHANGED
@@ -0,0 +1,141 @@
+ import io
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import soundfile as sf
+ import streamlit as st
+ import torch
+ import whisper
+
+
+ # pre-processing: file-like object input case
+ def trans_byte2arr(file_obj) -> np.ndarray:
+     # read the uploaded file object into a float32 array, then downmix to mono
+     arr_data, _ = sf.read(file=io.BytesIO(file_obj.read()), dtype="float32")
+     sig_data = merge_sig(arr_data)
+     return sig_data
+
+
+ def merge_sig(arr_data: np.ndarray) -> np.ndarray:
+     if arr_data.ndim == 2:
+         # stereo case: element-wise add the left and right channels
+         sig_data = arr_data.sum(axis=1)
+     elif arr_data.ndim > 2:
+         # more than two dimensions cannot be an audio signal
+         raise ValueError("this file is not an audio file")
+     else:
+         # already mono
+         return arr_data
+
+     return sig_data
+
+
+ # pre-processing: move the sample rate toward Whisper's expected 16 kHz
+ # by averaging blocks of consecutive samples
+ def audio_speed_reduce(sig_data: np.ndarray, sample_rate: int) -> np.ndarray:
+     if sample_rate > 16000:
+         reduce_size = sample_rate / 16000
+     elif sample_rate < 16000:
+         reduce_size = 16000 / sample_rate
+     else:
+         reduce_size = None
+
+     sig_data = merge_sig(sig_data)
+
+     if reduce_size is None:
+         # already 16 kHz; nothing to do
+         return sig_data
+
+     try:
+         audio = sig_data.reshape(-1, int(reduce_size)).mean(axis=1)
+     except ValueError:
+         # the length is not a multiple of the block size: drop the remainder
+         slice_size = len(sig_data) % int(reduce_size)
+         audio = sig_data[:-slice_size].reshape(-1, int(reduce_size)).mean(axis=1)
+
+     return audio
+
+
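+ # Worked example of the block averaging above: a 48 kHz input gives
+ # reduce_size = 48000 / 16000 = 3, so six samples [a, b, c, d, e, f]
+ # reshape into blocks [a, b, c] and [d, e, f], whose means become the
+ # two output samples, i.e. one output sample per three inputs (~16 kHz).
+
+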
+ def convert_byte_audio(byte_data: bytes) -> np.ndarray:
+     # decode the raw bytes into a float32 signal array
+     arr_data, sr = sf.read(file=io.BytesIO(byte_data), dtype="float32")
+
+     # downmix and downsample for Whisper
+     audio = audio_speed_reduce(arr_data, sr)
+     return audio
+
+
+ def get_language_cls(audio_arr: np.ndarray, model: torch.nn.Module) -> dict:
+     # pad or trim the data to the 30-second window Whisper expects
+     audio = whisper.pad_or_trim(audio_arr)
+
+     # make the log-Mel spectrogram and move it to the same device as the model
+     mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+     # detect the spoken language
+     _, probs = model.detect_language(mel)
+
+     return probs
+
+
+ def transcribe(audio: np.ndarray, model: torch.nn.Module, task: str = "transcribe") -> str:
+     base_option = dict(beam_size=5, best_of=5)
+
+     if task == "transcribe":
+         base_option = dict(task="transcribe", **base_option)
+     else:
+         base_option = dict(task="translate", **base_option)
+
+     result = model.transcribe(audio, **base_option)
+     return result["text"]
+
+
+ def load_model(model_name: str) -> torch.nn.Module:
+     # note: Streamlit reruns the script on every widget interaction, so the
+     # model is reloaded each time; caching the loaded model would avoid that
+     model = whisper.load_model(model_name)
+     return model
+
+
+ file_data = st.file_uploader("Upload your audio file")
+
+ if file_data is not None:
+     # read the uploaded file as raw bytes
+     bytes_data = file_data.getvalue()
+
+     audio_arr = convert_byte_audio(bytes_data)
+
+     # plot the waveform
+     fig, ax = plt.subplots()
+     ax.plot(audio_arr)
+     st.pyplot(fig)
+
+     st.audio(bytes_data)
+
+     model_option = [
+         "tiny",
+         "base",
+         "small",
+         "medium",
+         "large",
+     ]
+     selected_model_size = st.selectbox(
+         "Which model size do you want?", ["None"] + model_option
+     )
+
+     if selected_model_size in model_option:
+         model = load_model(selected_model_size)
+
+         lang_button = st.button("What is the language?")
+         if lang_button:
+             probs = get_language_cls(audio_arr=audio_arr, model=model)
+             st.write(f"Detected language: {max(probs, key=probs.get)}")
+
+         task_option = ["transcribe", "translate"]
+         translate_task = st.selectbox("What is your task?", ["None"] + task_option)
+
+         if translate_task != "None":
+             result = transcribe(audio=audio_arr, model=model, task=translate_task)
+             st.write(result)
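
To run the app locally (assuming the streamlit, openai-whisper, soundfile, and matplotlib packages are installed):

streamlit run app.py

A minimal sketch of the same Whisper calls outside Streamlit, useful as a quick sanity check; "sample.wav" is a hypothetical path, and whisper.load_audio requires ffmpeg on the PATH:

import whisper

model = whisper.load_model("tiny")  # smallest checkpoint, quickest to test
audio = whisper.load_audio("sample.wav")  # decodes and resamples to 16 kHz mono

# language detection, as in get_language_cls above
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio)).to(model.device)
_, probs = model.detect_language(mel)
print("language:", max(probs, key=probs.get))

# transcription with the same decoding options as transcribe above
print(model.transcribe(audio, task="transcribe", beam_size=5, best_of=5)["text"])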