import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta
from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Pick the compute device and matching dtype (GPU if available, otherwise CPU)
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

SAMPLING_RATE = 16000  # Whisper expects 16 kHz audio
print(f"{device} Active!")
# Load the Whisper model and processor for transcription/translation,
# and move the model to the device selected above
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
# Title of the app
st.title("Audio Player with Live Transcription")

# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")
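
# Streamlit reruns this script from the top on every widget interaction
# (including the Transcribe/Translate buttons below), so any result that must
# survive a click is kept in st.session_state rather than in plain locals.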
def detect_language(waveform, sampling_rate):
    # Identify the spoken language with the standalone openai-whisper model;
    # the waveform is expected to already be resampled to 16 kHz
    whisper_model = whisper.load_model("base")
    # Pad/trim to the 30-second window Whisper expects, then compute the log-mel spectrogram
    audio = whisper.pad_or_trim(waveform[0])
    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
    # Detect the spoken language; probs maps language codes to probabilities
    _, probs = whisper_model.detect_language(mel)
    detected = max(probs, key=probs.get)
    print(f"Detected language: {detected}")
    return detected
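
# Note: detect_language reloads the "base" checkpoint on every call, which is
# slow across Streamlit reruns. A minimal sketch of a cached loader, assuming
# Streamlit's st.cache_resource decorator is available (Streamlit >= 1.18):
#
#     @st.cache_resource
#     def get_whisper_model(name="base"):
#         return whisper.load_model(name)
#
# detect_language could then call get_whisper_model("base") instead.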
if submit_button and uploaded_files:
    # Load, resample, and language-detect each file once, then keep the
    # results in st.session_state so they survive the rerun triggered by
    # the Transcribe/Translate buttons below
    st.session_state['files'] = []
    for uploaded_file in uploaded_files:
        # Read the uploaded file data
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.read()))
        # Resample if necessary
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
        # Detect the language and store everything needed to render this file later
        st.session_state['files'].append({
            'name': uploaded_file.name,
            'type': uploaded_file.type,
            'bytes': uploaded_file.getvalue(),
            'waveform': waveform,
            'language': detect_language(waveform, SAMPLING_RATE),
        })

# Display each uploaded file with its detected language and an audio player
for i, entry in enumerate(st.session_state.get('files', [])):
    col1, col2 = st.columns([1, 3])  # Two columns: one for the player, one for the buttons
    with col1:
        st.write(f"**File name**: {entry['name']}")
        st.audio(BytesIO(entry['bytes']), format=entry['type'])
        st.write(f"**Detected Language**: {entry['language']}")
    with col2:
        if st.button(f"Transcribe {entry['name']}", key=f"transcribe_{i}"):
            # Transcription process: featurize this file's waveform and decode
            input_features = processor(entry['waveform'][0], sampling_rate=SAMPLING_RATE,
                                       return_tensors='pt').input_features.to(device)
            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in transcription:
                st.write(line)
        if st.button(f"Translate {entry['name']}", key=f"translate_{i}"):
            # Translation process: map the detected language code to its full
            # name, then force the decoder into translate mode
            input_features = processor(entry['waveform'][0], sampling_rate=SAMPLING_RATE,
                                       return_tensors='pt').input_features.to(device)
            with open('languages.pkl', 'rb') as f:
                lang_dict = pickle.load(f)
            forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_dict[entry['language']], task="translate")
            predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in translation:
                st.write(line)
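
# --- Sketch: producing languages.pkl (not part of the app) ---
# The pickle is assumed to map Whisper language codes (e.g. "fr") to language
# names (e.g. "french"), which get_decoder_prompt_ids accepts. openai-whisper
# ships such a mapping, so the file could plausibly be generated once with:
#
#     import pickle
#     from whisper.tokenizer import LANGUAGES
#     with open("languages.pkl", "wb") as f:
#         pickle.dump(dict(LANGUAGES), f)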