File size: 4,120 Bytes
b815c4a
a5753ad
4841807
f427fe9
b815c4a
f427fe9
 
353faef
b815c4a
353faef
 
 
f427fe9
353faef
b815c4a
353faef
b815c4a
 
f427fe9
 
 
 
 
 
 
 
 
353faef
 
 
 
 
 
 
f427fe9
 
a5753ad
353faef
 
a5753ad
353faef
 
 
 
6d2ca12
 
353faef
6d2ca12
353faef
 
6d2ca12
 
 
 
 
 
353faef
 
 
6d2ca12
353faef
 
 
 
6d2ca12
 
 
353faef
 
6d2ca12
 
353faef
 
 
6d2ca12
 
 
353faef
 
 
 
 
6d2ca12
 
 
 
 
353faef
6d2ca12
 
 
 
353faef
 
 
 
 
6d2ca12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta

from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Set up device and dtype.
# NOTE(review): `device` and `torch_dtype` are computed but the Hugging Face model
# below is never moved to `device` nor cast to `torch_dtype`, so as written it stays
# on its default device/dtype. Moving it also requires moving the input features.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32

# Target sample rate (Hz) every upload is resampled to before feature extraction.
SAMPLING_RATE = 16000


@st.cache_resource
def _load_whisper(checkpoint="openai/whisper-small"):
    """Load the Whisper processor and model for ``checkpoint`` once per server process.

    Streamlit re-executes this whole script on every widget interaction; caching
    the pair avoids reloading the checkpoint from disk on each rerun.

    Returns:
        A ``(WhisperProcessor, WhisperForConditionalGeneration)`` tuple.
    """
    return (
        WhisperProcessor.from_pretrained(checkpoint),
        WhisperForConditionalGeneration.from_pretrained(checkpoint),
    )


# Load Whisper model and processor (cached across reruns).
processor, model = _load_whisper()

# Page heading.
st.title("Audio Player with Live Transcription")

# Everything rendered inside this context lands in the sidebar: a section
# header, a multi-file uploader restricted to mp3/wav, and the Submit button.
with st.sidebar:
    st.header("Upload Audio Files")
    uploaded_files = st.file_uploader(
        "Choose audio files",
        type=["mp3", "wav"],
        accept_multiple_files=True,
    )
    submit_button = st.button("Submit")

# Per-session storage, created the first time this session runs the script:
# uploaded files, per-index transcriptions/translations, detected languages,
# and the resampled waveforms (kept parallel to audio_files by index).
if "audio_files" not in st.session_state:
    st.session_state.update(
        audio_files=[],
        transcriptions={},
        translations={},
        detected_languages=[],
        waveforms=[],
    )


@st.cache_resource
def _load_language_id_model(name="small"):
    """Load the openai-whisper model used for language detection, once per process.

    Streamlit reruns the script on every interaction; without caching, the model
    would be reloaded for every uploaded file on every rerun.
    """
    return whisper.load_model(name)


def detect_language(audio_file):
    """Detect the spoken language of a waveform using openai-whisper.

    Args:
        audio_file: Waveform tensor/array sampled at SAMPLING_RATE, either 1-D
            ``(samples,)`` or ``(channels, samples)``; in the latter case only
            the first channel is used.

    Returns:
        The key of whisper's language-probability dict with the highest probability.
    """
    whisper_model = _load_language_id_model("small")
    # Select the first channel *before* feature extraction instead of computing a
    # mel spectrogram for every channel and discarding all but one.
    mono = audio_file[0] if audio_file.ndim > 1 else audio_file
    # Pad/trim to whisper's fixed-length input window.
    trimmed_audio = whisper.pad_or_trim(mono)
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang}")
    return detected_lang


# Process uploaded files: decode each upload, resample it to SAMPLING_RATE, and
# detect its language. file_uploader(accept_multiple_files=True) returns a list
# (empty when nothing is uploaded), never None, so test truthiness.
if submit_button and uploaded_files:
    st.session_state.audio_files = uploaded_files
    # Reset every per-file store so indices line up with the new batch; otherwise
    # waveforms keep accumulating and transcriptions/translations from the previous
    # batch would be shown (and transcribed) under the new files' indices.
    st.session_state.detected_languages = []
    st.session_state.waveforms = []
    st.session_state.transcriptions = {}
    st.session_state.translations = {}

    for uploaded_file in uploaded_files:
        # getvalue() returns the full payload without consuming the read cursor,
        # so the same upload can still be played back later.
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)

        st.session_state.waveforms.append(waveform)
        st.session_state.detected_languages.append(detect_language(waveform))

def _input_features(index):
    """Return Whisper input features for the stored waveform at ``index``.

    Uses only the first channel of the stored waveform (assumed ``(channels, samples)``
    as returned by torchaudio.load — confirm). The features are moved to the model's
    device and dtype so generation keeps working if the model is ever placed on a GPU.
    """
    first_channel = st.session_state.waveforms[index][0]
    features = processor(
        first_channel.numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    return features.to(model.device, dtype=model.dtype)


# Display uploaded files and options
if 'audio_files' in st.session_state and st.session_state.audio_files:
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        col1, col2 = st.columns([1, 3])

        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            # getvalue() returns the whole upload regardless of the stream position;
            # read() yields b"" once the buffer has been consumed (e.g. on a rerun).
            st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
            st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")

        with col2:
            # Features are computed lazily, only when a button is pressed, instead of
            # for every file on every rerun. Explicit keys keep widget IDs unique even
            # when two uploads share a file name.
            if st.button(f"Transcribe {uploaded_file.name}", key=f"transcribe_{i}"):
                predicted_ids = model.generate(_input_features(i))
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
                st.session_state.transcriptions[i] = transcription

            if st.session_state.transcriptions.get(i):
                st.write("**Transcription**:")
                for line in st.session_state.transcriptions[i]:
                    st.write(line)

            if st.button(f"Translate {uploaded_file.name}", key=f"translate_{i}"):
                detected_code = st.session_state.detected_languages[i]
                # NOTE(review): pickle.load can execute arbitrary code; this is only
                # acceptable if languages.pkl is a trusted local file — confirm.
                with open('languages.pkl', 'rb') as f:
                    lang_dict = pickle.load(f)
                # Fall back to the raw code if the mapping lacks it; the Whisper
                # tokenizer accepts language codes as well as names.
                detected_language_name = lang_dict.get(detected_code, detected_code)

                forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language_name, task="translate")
                predicted_ids = model.generate(_input_features(i), forced_decoder_ids=forced_decoder_ids)
                translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
                st.session_state.translations[i] = translation

            if st.session_state.translations.get(i):
                st.write("**Translation**:")
                for line in st.session_state.translations[i]:
                    st.write(line)