import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta

from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Set up device and dtype
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32

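# Whisper's encoder consumes fixed 30-second windows, so longer recordings are
# split into CHUNK_LENGTH_S-second pieces and transcribed piece by piece
# (see process_long_audio below).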
SAMPLING_RATE = 16000
CHUNK_LENGTH_S = 20  # seconds of audio per chunk

# Load the Hugging Face Whisper model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", torch_dtype=torch_dtype).to(device)
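# Language detection below uses the separate openai-whisper package; transcription
# and translation run on the Hugging Face checkpoint loaded here.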

# Title of the app
st.title("Audio Player with Live Transcription")

# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")

# Streamlit reruns this script on every interaction; session_state persists
# uploads and results across reruns
if 'audio_files' not in st.session_state:
    st.session_state.audio_files = []
    st.session_state.transcriptions = {}
    st.session_state.translations = {}
    st.session_state.detected_languages = []
    st.session_state.waveforms = []


def detect_language(waveform):
    whisper_model = whisper.load_model("small")
    # The detector expects a mono clip padded or trimmed to 30 seconds
    trimmed_audio = whisper.pad_or_trim(waveform.squeeze())
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    # detect_language returns a single dict of probabilities for mono input and
    # a list of dicts (one per channel) for multi-channel input
    if isinstance(probs, list):
        probs = probs[0]
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang}")
    return detected_lang


def process_long_audio(waveform, sampling_rate, task="transcribe", language=None):
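    # Whisper handles at most 30 s at a time, so transcribe fixed-size chunks
    # and concatenate the per-chunk text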
    input_length = waveform.shape[1]
    chunk_length = int(CHUNK_LENGTH_S * sampling_rate)
    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]

    results = []
    for chunk in chunks:
        # Feed the first channel; the feature extractor pads the last, shorter chunk itself
        input_features = processor(chunk[0], sampling_rate=sampling_rate,
                                   return_tensors="pt").input_features.to(device, torch_dtype)

        with torch.no_grad():
            if task == "translate":
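                # Force the source language and the translate task; Whisper
                # translation always targets English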
                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            else:
                generated_ids = model.generate(input_features)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results.extend(transcription)

    return " ".join(results)


# Process uploaded files when the user hits Submit
if submit_button and uploaded_files:
    st.session_state.audio_files = uploaded_files
    st.session_state.detected_languages = []
    st.session_state.waveforms = []

    for uploaded_file in uploaded_files:
        # getvalue() leaves the buffer position untouched so the file can be replayed below
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
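        # Whisper models expect 16 kHz input; resample anything else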
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)

        st.session_state.waveforms.append(waveform)
        detected_language = detect_language(waveform)
        st.session_state.detected_languages.append(detected_language)

# Display uploaded files and options
if 'audio_files' in st.session_state and st.session_state.audio_files:
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        col1, col2 = st.columns([1, 3])

        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            # getvalue() still returns the full bytes even after the earlier read
            st.audio(uploaded_file.getvalue(), format=uploaded_file.type)
            st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")

        with col2:
            if st.button(f"Transcribe {uploaded_file.name}"):
                with st.spinner("Transcribing..."):
                    transcription = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE)
                    st.session_state.transcriptions[i] = transcription

            if st.session_state.transcriptions.get(i):
                st.write("**Transcription**:")
                st.write(st.session_state.transcriptions[i])

            if st.button(f"Translate {uploaded_file.name}"):
                with st.spinner("Translating..."):
                    with open('languages.pkl', 'rb') as f:
                        lang_dict = pickle.load(f)
                    detected_language_name = lang_dict[st.session_state.detected_languages[i]]

                    translation = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE, task="translate",
                                                     language=detected_language_name)
                    st.session_state.translations[i] = translation

            if st.session_state.translations.get(i):
                st.write("**Translation**:")
                st.write(st.session_state.translations[i])