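"""Streamlit app for uploading audio files, detecting their language with
openai-whisper, and transcribing or translating them in chunks with the
Hugging Face openai/whisper-small checkpoint."""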
import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta
from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Set up device and dtype
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
SAMPLING_RATE = 16000
CHUNK_LENGTH_S = 20  # 20 seconds per chunk
# Load Whisper model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small", torch_dtype=torch_dtype
).to(device)
# Title of the app
st.title("Audio Player with Live Transcription")
# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")
# Session state to hold data
if 'audio_files' not in st.session_state:
    st.session_state.audio_files = []
    st.session_state.transcriptions = {}
    st.session_state.translations = {}
    st.session_state.detected_languages = []
    st.session_state.waveforms = []
def detect_language(waveform):
    """Detect the spoken language of a waveform with the openai-whisper model."""
    # Note: reloading the model on every call is wasteful; consider caching it
    # (e.g. with st.cache_resource) in a real deployment.
    whisper_model = whisper.load_model("small")
    # Downmix to mono and pad/trim to the 30-second window Whisper expects.
    trimmed_audio = whisper.pad_or_trim(waveform.mean(dim=0))
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    # For a single (unbatched) mel, detect_language returns one dict mapping
    # language code -> probability, so take the argmax directly.
    detected_lang = max(probs, key=probs.get)
    return detected_lang
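# Whisper operates on fixed 30-second windows, so longer recordings are split
# into CHUNK_LENGTH_S-second chunks and the per-chunk outputs are concatenated.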
def process_long_audio(waveform, sampling_rate, task="transcribe", language=None):
    """Transcribe or translate a waveform by processing it in fixed-length chunks."""
    input_length = waveform.shape[1]
    chunk_length = int(CHUNK_LENGTH_S * sampling_rate)
    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]

    results = []
    for chunk in chunks:
        # The feature extractor expects mono audio, so use the first channel.
        input_features = processor(
            chunk[0], sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(device, torch_dtype)

        with torch.no_grad():
            if task == "translate":
                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            else:
                generated_ids = model.generate(input_features)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results.extend(transcription)

    return " ".join(results)
# Process uploaded files
if submit_button and uploaded_files:
    st.session_state.audio_files = uploaded_files
    st.session_state.transcriptions = {}
    st.session_state.translations = {}
    st.session_state.detected_languages = []
    st.session_state.waveforms = []

    for uploaded_file in uploaded_files:
        # getvalue() copies the bytes without consuming the buffer, so the
        # same file can still be played back with st.audio below.
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        # Whisper models expect 16 kHz input; resample anything else.
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
        st.session_state.waveforms.append(waveform)
        detected_language = detect_language(waveform)
        st.session_state.detected_languages.append(detected_language)
# Display uploaded files and options
if st.session_state.get('audio_files'):
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        col1, col2 = st.columns([1, 3])

        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
            st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")

        with col2:
            if st.button(f"Transcribe {uploaded_file.name}"):
                with st.spinner("Transcribing..."):
                    transcription = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE)
                    st.session_state.transcriptions[i] = transcription

            if st.session_state.transcriptions.get(i):
                st.write("**Transcription**:")
                st.write(st.session_state.transcriptions[i])

            if st.button(f"Translate {uploaded_file.name}"):
                with st.spinner("Translating..."):
                    # languages.pkl maps Whisper language codes (e.g. "fr") to the
                    # full names (e.g. "french") that the processor accepts.
                    with open('languages.pkl', 'rb') as f:
                        lang_dict = pickle.load(f)
                    detected_language_name = lang_dict[st.session_state.detected_languages[i]]
                    translation = process_long_audio(
                        st.session_state.waveforms[i], SAMPLING_RATE,
                        task="translate", language=detected_language_name,
                    )
                    st.session_state.translations[i] = translation

            if st.session_state.translations.get(i):
                st.write("**Translation**:")
                st.write(st.session_state.translations[i])