import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta
from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration
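
# Choose the compute device and precision: fp16 on GPU, fp32 on CPU.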
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
SAMPLING_RATE = 16000  # Whisper models expect 16 kHz audio
task = "transcribe"  # default task; only used by the commented-out single-pass flow below
print(f"{device} Active!")
# load Whisper model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", torch_dtype=torch_dtype).to(device)
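# The standalone `whisper` package is used only for language detection below;
# transcription and translation both run through the HF whisper-small model.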
# Title of the app
st.title("Audio Player with Live Transcription")
# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")
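
# Streamlit reruns this whole script on every widget interaction, so any
# state that must survive a button press is kept in st.session_state below.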
# def transcribe_audio(audio_data):
#     recognizer = sr.Recognizer()
#     with sr.AudioFile(audio_data) as source:
#         audio = recognizer.record(source)
#     try:
#         # Transcribe the audio using the Google Web Speech API
#         transcription = recognizer.recognize_google(audio)
#         return transcription
#     except sr.UnknownValueError:
#         return "Unable to transcribe the audio."
#     except sr.RequestError as e:
#         return f"Could not request results; {e}"
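
# Identify the spoken language of a waveform with the standalone openai-whisper package.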
def detect_language(waveform):
    # Load the small standalone Whisper model just for language ID
    whisper_model = whisper.load_model("base")
    # Whisper expects a 30 s, 16 kHz mono clip: downmix and pad/trim to fit
    trimmed_audio = whisper.pad_or_trim(waveform.mean(dim=0))
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    # Detect the spoken language and return the most probable language code
    _, probs = whisper_model.detect_language(mel)
    print(f"Detected language: {max(probs, key=probs.get)}")
    return max(probs, key=probs.get)
# if submit_button and uploaded_files is not None:
#     st.write("Files uploaded successfully!")
#     for uploaded_file in uploaded_files:
#         # Display file name and audio player
#         st.write(f"**File name**: {uploaded_file.name}")
#         st.audio(uploaded_file, format=uploaded_file.type)
#         # Transcription section
#         st.write("**Transcription**:")
#         # Read the uploaded file data
#         waveform, sampling_rate = ta.load(uploaded_file.getvalue())
#         resampled_inp = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
#         input_features = processor(resampled_inp[0], sampling_rate=16000, return_tensors='pt').input_features
#         if task == "translate":
#             # Detect Language
#             lang = detect_language(input_features)
#             with open('languages.pkl', 'rb') as f:
#                 lang_dict = pickle.load(f)
#             detected_language = lang_dict[lang]
#             # Set decoder & Predict translation
#             forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language, task="translate")
#             predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
#         else:
#             predicted_ids = model.generate(input_features)
#         # decode token ids to text
#         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
#         for i in range(len(transcription)):
#             st.write(transcription[i])
#         # print(waveform, sampling_rate)
#         # Run transcription function and display
#         # import pdb; pdb.set_trace()
#         # st.write(audio_data.getvalue())
if submit_button and uploaded_files:
    # Decode every upload once and keep it in session_state so the data
    # survives the rerun that each button press below triggers.
    files = []
    for uploaded_file in uploaded_files:
        # Read the uploaded file data
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        # Resample to the 16 kHz rate Whisper expects, if necessary
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
        files.append({
            "name": uploaded_file.name,
            "type": uploaded_file.type,
            "bytes": uploaded_file.getvalue(),
            "waveform": waveform,
            "language": detect_language(waveform),  # detect language once per file
        })
    st.session_state["files"] = files
# Display each uploaded file with its detected language and an audio player
for i, item in enumerate(st.session_state.get("files", [])):
    col1, col2 = st.columns([1, 3])  # Two columns, one for the player, one for the buttons
    with col1:
        st.write(f"**File name**: {item['name']}")
        st.audio(BytesIO(item["bytes"]), format=item["type"])
        st.write(f"**Detected Language**: {item['language']}")
    with col2:
        # Add Transcription and Translation buttons; keys keep them unique per file
        if st.button(f"Transcribe {item['name']}", key=f"transcribe_{i}"):
            # Transcription process
            input_features = processor(item["waveform"][0], sampling_rate=SAMPLING_RATE, return_tensors='pt').input_features
            predicted_ids = model.generate(input_features.to(device, torch_dtype))
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in transcription:
                st.write(line)
        if st.button(f"Translate {item['name']}", key=f"translate_{i}"):
            # Translation process: map the detected language code to the full
            # name expected by get_decoder_prompt_ids (languages.pkl is assumed
            # to hold that code -> name mapping)
            with open('languages.pkl', 'rb') as f:
                lang_dict = pickle.load(f)
            detected_language_name = lang_dict[item["language"]]
            forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language_name, task="translate")
            input_features = processor(item["waveform"][0], sampling_rate=SAMPLING_RATE, return_tensors='pt').input_features
            predicted_ids = model.generate(input_features.to(device, torch_dtype), forced_decoder_ids=forced_decoder_ids)
            translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in translation:
                st.write(line)
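
# To try the app locally (assuming this file is saved as app.py, the usual
# entry point for a Streamlit Space): streamlit run app.py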