import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Speech-to-text: a small Whisper checkpoint keeps latency low for live use.
model_name = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
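
# Whisper auto-detects the spoken language by default. To pin the language and
# task instead, a minimal sketch (exact API depends on your transformers
# version; not wired into the app below):
#
#   forced_ids = processor.get_decoder_prompt_ids(language="japanese", task="transcribe")
#   predicted_ids = model.generate(input_features, forced_decoder_ids=forced_ids)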

# Spoken-language ID: VoxLingua107 ECAPA-TDNN classifier (107 languages).
language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa")
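
# Quick standalone sanity check of the classifier (a sketch; "sample.wav" is a
# placeholder path, not a file shipped with this app):
#
#   signal = language_id.load_audio("sample.wav")
#   prediction = language_id.classify_batch(signal)
#   print(prediction[3])  # human-readable label, e.g. ['ja: Japanese']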

# Mutable state shared across calls: accumulated result rows and leftover audio
# that has not yet filled a whole chunk.
data = []
current_chunk = []

# VoxLingua107 class index -> language name.
index_to_lang = {
    0: 'Abkhazian', 1: 'Afrikaans', 2: 'Amharic', 3: 'Arabic', 4: 'Assamese',
    5: 'Azerbaijani', 6: 'Bashkir', 7: 'Belarusian', 8: 'Bulgarian', 9: 'Bengali',
    10: 'Tibetan', 11: 'Breton', 12: 'Bosnian', 13: 'Catalan', 14: 'Cebuano',
    15: 'Czech', 16: 'Welsh', 17: 'Danish', 18: 'German', 19: 'Greek',
    20: 'English', 21: 'Esperanto', 22: 'Spanish', 23: 'Estonian', 24: 'Basque',
    25: 'Persian', 26: 'Finnish', 27: 'Faroese', 28: 'French', 29: 'Galician',
    30: 'Guarani', 31: 'Gujarati', 32: 'Manx', 33: 'Hausa', 34: 'Hawaiian',
    35: 'Hindi', 36: 'Croatian', 37: 'Haitian', 38: 'Hungarian', 39: 'Armenian',
    40: 'Interlingua', 41: 'Indonesian', 42: 'Icelandic', 43: 'Italian', 44: 'Hebrew',
    45: 'Japanese', 46: 'Javanese', 47: 'Georgian', 48: 'Kazakh', 49: 'Central Khmer',
    50: 'Kannada', 51: 'Korean', 52: 'Latin', 53: 'Luxembourgish', 54: 'Lingala',
    55: 'Lao', 56: 'Lithuanian', 57: 'Latvian', 58: 'Malagasy', 59: 'Maori',
    60: 'Macedonian', 61: 'Malayalam', 62: 'Mongolian', 63: 'Marathi', 64: 'Malay',
    65: 'Maltese', 66: 'Burmese', 67: 'Nepali', 68: 'Dutch', 69: 'Norwegian Nynorsk',
    70: 'Norwegian', 71: 'Occitan', 72: 'Panjabi', 73: 'Polish', 74: 'Pushto',
    75: 'Portuguese', 76: 'Romanian', 77: 'Russian', 78: 'Sanskrit', 79: 'Scots',
    80: 'Sindhi', 81: 'Sinhala', 82: 'Slovak', 83: 'Slovenian', 84: 'Shona',
    85: 'Somali', 86: 'Albanian', 87: 'Serbian', 88: 'Sundanese', 89: 'Swedish',
    90: 'Swahili', 91: 'Tamil', 92: 'Telugu', 93: 'Tajik', 94: 'Thai',
    95: 'Turkmen', 96: 'Tagalog', 97: 'Turkish', 98: 'Tatar', 99: 'Ukrainian',
    100: 'Urdu', 101: 'Uzbek', 102: 'Vietnamese', 103: 'Waray', 104: 'Yiddish',
    105: 'Yoruba', 106: 'Chinese'
}
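
# The indices above follow VoxLingua107's own label ordering. If in doubt, the
# mapping can be rebuilt from the loaded classifier instead of hard-coding it
# (a sketch, assuming speechbrain exposes the encoder as hparams.label_encoder;
# its values look like 'ja: Japanese'):
#
#   index_to_lang = {i: language_id.hparams.label_encoder.ind2lab[i] for i in range(107)}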

# Class indices for the two languages this app distinguishes explicitly.
lang_index_JA_EN = {
    'ja': 45,
    'en': 20,
}

SAMPLING_RATE = 16000   # both models expect 16 kHz input
CHUNK_DURATION = 5      # seconds of audio per processed chunk
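
# One full chunk is SAMPLING_RATE * CHUNK_DURATION = 80_000 samples (5 s at 16 kHz).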


def normalize_audio(audio):
    # Peak-normalize to [-1, 1]; guard against division by zero on silence.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
    return audio
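
# e.g. normalize_audio(np.array([0.1, -0.2], dtype=np.float32))
#      -> array([ 0.5, -1. ], dtype=float32)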


def resample_audio(audio, orig_sr, target_sr=16000):
    # torchaudio's Resample expects float input; convert up front so the rest of
    # the pipeline always sees float32 (the classifier rejects float64 tensors).
    audio = audio.astype(np.float32)
    if orig_sr != target_sr:
        print(f"Resampling audio from {orig_sr} Hz to {target_sr} Hz")
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        audio = resampler(torch.from_numpy(audio).unsqueeze(0)).squeeze(0).numpy()
    return audio
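
# Usage sketch (illustrative, not part of the app): a 1 s, 44.1 kHz tone comes
# back as 16 000 samples.
#
#   t = np.linspace(0, 1, 44100, endpoint=False)
#   tone = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
#   resample_audio(tone, 44100).shape  # -> (16000,)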


def process_audio(audio):
    # Gradio delivers (sample_rate, numpy_array). Accumulate audio across calls,
    # then emit one result row per full CHUNK_DURATION-second chunk.
    global data, current_chunk
    sr, audio_data = audio
    print(f"process_audio: sr={sr}, shape={audio_data.shape}, dtype={audio_data.dtype}")

    audio_data = resample_audio(audio_data, sr, target_sr=SAMPLING_RATE)
    audio_data = normalize_audio(audio_data)
    audio_sec = 0

    current_chunk.append(audio_data)
    total_chunk = np.concatenate(current_chunk)

    while len(total_chunk) >= SAMPLING_RATE * CHUNK_DURATION:
        chunk = total_chunk[:SAMPLING_RATE * CHUNK_DURATION]
        total_chunk = total_chunk[SAMPLING_RATE * CHUNK_DURATION:]
        audio_sec += CHUNK_DURATION

        print(f"Processing audio chunk of length {len(chunk)}")
        volume_norm = np.linalg.norm(chunk) / np.sqrt(len(chunk))  # RMS volume of the chunk
        length = len(chunk) / SAMPLING_RATE

        # classify_batch returns (log-posteriors, best score, best index, labels);
        # [0][0] is the per-class log-posterior vector for this single chunk.
        lang_guess = language_id.classify_batch(torch.from_numpy(chunk).unsqueeze(0))

        ja_prob = lang_guess[0][0][lang_index_JA_EN['ja']].item()
        en_prob = lang_guess[0][0][lang_index_JA_EN['en']].item()
        ja_en = 'ja' if ja_prob > en_prob else 'en'

        # Top-3 candidates across all 107 languages.
        top3_indices = torch.topk(lang_guess[0], 3, dim=1, largest=True).indices[0]
        top3_languages = [index_to_lang[idx.item()] for idx in top3_indices]

        # Whisper transcription of the chunk.
        input_features = processor(chunk, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features.to(device)
        predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(transcription)

        data.append({
            "Time": audio_sec,
            "Length (s)": length,
            "Volume": volume_norm,
            "Japanese_English": f"{ja_en} ({ja_prob:.2f}, {en_prob:.2f})",
            "Language": top3_languages,
            "Text": transcription,
        })

        df = pd.DataFrame(data)
        yield (SAMPLING_RATE, chunk), df

    # Keep the unprocessed remainder for the next call.
    current_chunk = [total_chunk]
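
# Manual invocation sketch (outside Gradio, hypothetical input): ~12 s of noise
# yields two 5 s rows and leaves ~2 s buffered for the next call.
#
#   fake = (SAMPLING_RATE, np.random.randn(SAMPLING_RATE * 12).astype(np.float32))
#   for (sr, chunk), df in process_audio(fake):
#       print(df.tail(1))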


inputs = gr.Audio(sources=["microphone", "upload"], type="numpy")
outputs = [
    gr.Audio(type="numpy"),
    # Headers match the keys appended in process_audio above.
    gr.DataFrame(headers=["Time", "Length (s)", "Volume", "Japanese_English", "Language", "Text"]),
]

demo = gr.Interface(
    fn=process_audio,
    inputs=inputs,
    outputs=outputs,
    live=True,
    title="Real-time Audio Processing",
    description="Speak into the microphone and see real-time audio processing results."
)


if __name__ == "__main__":
    demo.launch()