Spaces: Runtime error
Update app.py
app.py
CHANGED
@@ -0,0 +1,141 @@
import io

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import streamlit as st
import torch
import whisper


# pre-process: decode an uploaded file object into a mono float32 signal
def trans_byte2arr(file_obj):
    arr_data, _ = sf.read(file=io.BytesIO(file_obj.read()), dtype="float32")
    return merge_sig(arr_data)


def merge_sig(arr_data: np.ndarray) -> np.ndarray:
    if arr_data.ndim == 2:
        # stereo file: add the left and right channels into one signal
        return arr_data.sum(axis=1)
    if arr_data.ndim > 2:
        raise ValueError("this file is not an audio file")
    # already mono
    return arr_data


# pre-process: crude decimation towards Whisper's 16 kHz input rate by
# block-averaging groups of samples (only exact when the sample rate is an
# integer multiple of 16 kHz)
def audio_speed_reduce(sig_data: np.ndarray, sample_rate: int) -> np.ndarray:
    if sample_rate > 16000:
        reduce_size = int(sample_rate / 16000)
    elif sample_rate < 16000:
        reduce_size = int(16000 / sample_rate)
    else:
        reduce_size = None

    sig_data = merge_sig(sig_data)

    if reduce_size is None:
        # already 16 kHz, nothing to do
        return sig_data

    try:
        audio = sig_data.reshape(-1, reduce_size).mean(axis=1)
    except ValueError:
        # the signal length is not divisible by reduce_size: drop the tail first
        slice_size = len(sig_data) % reduce_size
        audio = sig_data[:-slice_size].reshape(-1, reduce_size).mean(axis=1)

    return audio


def convert_byte_audio(byte_data: bytes) -> np.ndarray:
    # decode the raw bytes into a float32 signal and its sample rate
    arr_data, sr = sf.read(file=io.BytesIO(byte_data), dtype="float32")

    # downmix and reduce towards 16 kHz
    audio = audio_speed_reduce(arr_data, sr)
    return audio


def get_language_cls(audio_arr: np.ndarray, model: torch.nn.Module):
    # pad or trim the signal to Whisper's 30-second window
    audio = whisper.pad_or_trim(audio_arr)

    # make the log-Mel spectrogram and move it to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    _, probs = model.detect_language(mel)
    return probs


def transcribe(audio: np.ndarray, model: torch.nn.Module, task: str = "transcribe") -> str:
    base_option = dict(beam_size=5, best_of=5)

    if task == "transcribe":
        base_option = dict(task="transcribe", **base_option)
    else:
        base_option = dict(task="translate", **base_option)

    result = model.transcribe(audio, **base_option)
    return result["text"]


def load_model(model_name: str) -> torch.nn.Module:
    model = whisper.load_model(model_name)
    return model


file_data = st.file_uploader("Upload your audio file")


if file_data is not None:
    # read the uploaded file as bytes
    bytes_data = file_data.getvalue()

    audio_arr = convert_byte_audio(bytes_data)

    # plot the waveform
    fig, ax = plt.subplots()
    ax.plot(audio_arr)
    st.pyplot(fig)

    st.audio(bytes_data)

    model_option = [
        "tiny",
        "base",
        "small",
        "medium",
        "large",
    ]
    selected_model_size = st.selectbox(
        "Which model size do you want?", ["None"] + model_option
    )

    if selected_model_size in model_option:
        model = load_model(selected_model_size)

        lang_button = st.button("What is the language?")
        if lang_button:
            probs = get_language_cls(audio_arr=audio_arr, model=model)
            st.write(f"Detected language: {max(probs, key=probs.get)}")

        task_option = ["transcribe", "translate"]
        translate_task = st.selectbox("What is your task?", ["None"] + task_option)

        if translate_task != "None":
            result = transcribe(audio=audio_arr, model=model, task=translate_task)
            st.write(result)
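For reference, a minimal sketch of the same transcription flow using Whisper's own audio loader, which hands decoding and resampling to 16 kHz mono over to ffmpeg and so avoids the integer-multiple assumption baked into audio_speed_reduce. This snippet is not part of the commit; the file name sample.wav and a local ffmpeg install are assumptions.

# Standalone sketch (not part of app.py): let Whisper load and resample the file.
import whisper

model = whisper.load_model("tiny")
# load_audio decodes the file with ffmpeg and resamples it to 16 kHz mono float32
audio = whisper.load_audio("sample.wav")
result = model.transcribe(audio, beam_size=5, best_of=5, task="transcribe")
print(result["text"])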