# Hugging Face Spaces app (Space status at capture time: "Sleeping").
import torch | |
import whisper | |
import torchaudio as ta | |
import gradio as gr | |
from model_utils import get_processor, get_model, get_whisper_model_small, get_device | |
from config import SAMPLING_RATE, CHUNK_LENGTH_S | |
import spaces | |
def load_and_resample_audio(audio):
    """Load ``audio`` as a mono float waveform resampled to ``SAMPLING_RATE``.

    Args:
        audio: Either a path to an audio file, read with ``torchaudio.load``
            (channels-first float samples), or a ``(sample_rate, samples)``
            tuple such as the value ``gr.Audio()`` hands to callbacks. Array
            samples may be 1-D or 2-D in either ``(samples, channels)`` or
            ``(channels, samples)`` layout; integer PCM is rescaled to floats
            in [-1, 1].

    Returns:
        ``(waveform, SAMPLING_RATE)`` where ``waveform`` is a float tensor of
        shape ``(1, num_samples)`` for 1-D or 2-D input.
    """
    if isinstance(audio, str):  # file path
        waveform, sample_rate = ta.load(audio)
    else:  # already-loaded (sample_rate, samples) pair
        sample_rate, samples = audio
        waveform = torch.tensor(samples)
        # Integer PCM (Gradio's numpy audio is typically int16) has to be
        # scaled to [-1, 1]; Whisper feature extraction assumes float audio.
        pcm_full_scale = {
            torch.int8: 128.0,
            torch.int16: 32768.0,
            torch.int32: 2147483648.0,
        }
        if waveform.dtype == torch.uint8:
            waveform = (waveform.float() - 128.0) / 128.0
        elif waveform.dtype in pcm_full_scale:
            waveform = waveform.float() / pcm_full_scale[waveform.dtype]
        else:
            waveform = waveform.float()
        # Array input is usually (samples, channels). Resampling works on the
        # last dimension and the mix-down below averages dim 0, so both need
        # the (channels, samples) layout that torchaudio.load produces.
        if waveform.dim() == 2 and waveform.shape[0] > waveform.shape[1]:
            waveform = waveform.t().contiguous()

    # Mix down to a single channel first: resampling one channel is cheaper.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != SAMPLING_RATE:
        waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)
    return waveform, SAMPLING_RATE
def detect_language(waveform):
    """Return the language that the small Whisper model assigns to ``waveform``.

    The waveform is squeezed to drop size-1 dimensions, then padded or trimmed
    with ``whisper.pad_or_trim``. It is converted to a log-mel spectrogram on
    the model's device and passed to the model's ``detect_language``. Shapes,
    the chosen language and the full probability table are printed for
    debugging.

    Args:
        waveform: Audio tensor; presumably ``(1, num_samples)`` at Whisper's
            sampling rate, as produced by ``load_and_resample_audio`` —
            TODO confirm for other callers.

    Returns:
        The key with the highest probability in the table returned by the
        model.
    """
    model = get_whisper_model_small()

    audio_tensor = whisper.pad_or_trim(waveform.squeeze())
    mel = whisper.log_mel_spectrogram(audio_tensor).to(model.device)

    lang_probs = model.detect_language(mel)[1]
    detected_lang = max(lang_probs, key=lang_probs.__getitem__)

    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", lang_probs)
    return detected_lang
def process_long_audio(waveform, sample_rate, task="transcribe", language=None):
    """Decode ``waveform`` chunk by chunk with the processor/model pair.

    The waveform is split into consecutive, non-overlapping windows of
    ``CHUNK_LENGTH_S`` seconds. The window is capped at the feature
    extractor's own ``chunk_length`` when that attribute exists, so longer
    chunks are not silently truncated during feature extraction. Each window
    is decoded independently and the texts are joined with single spaces.

    Args:
        waveform: Tensor sliced as ``waveform[:, start:end]``; presumably
            ``(1, num_samples)`` from ``load_and_resample_audio``.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        task: ``"translate"`` forces the translate decoder prompt; any other
            value runs plain ``model.generate``.
        language: Language passed to ``get_decoder_prompt_ids``; only used
            when ``task == "translate"``.

    Returns:
        The space-joined decoded text of all chunks (``""`` for empty audio).
    """
    processor = get_processor()
    model = get_model()
    device = get_device()

    chunk_length = int(CHUNK_LENGTH_S * sample_rate)
    # The feature extractor pads/trims every input to its fixed window;
    # anything beyond it would be dropped, so never build larger chunks.
    extractor = getattr(processor, "feature_extractor", None)
    window_s = getattr(extractor, "chunk_length", None)
    if window_s:
        chunk_length = min(chunk_length, int(window_s * sample_rate))
    chunk_length = max(chunk_length, 1)  # range() rejects a zero step

    # The decoder prompt depends only on task/language: build it once.
    forced_decoder_ids = None
    if task == "translate":
        forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")

    results = []
    input_length = waveform.shape[1]
    for start in range(0, input_length, chunk_length):
        chunk = waveform[:, start:start + chunk_length]
        # squeeze(0) (not squeeze()) keeps a 1-sample tail chunk 1-D.
        input_features = processor(
            chunk.squeeze(0), sampling_rate=sample_rate, return_tensors="pt"
        ).input_features.to(device)
        with torch.no_grad():
            if forced_decoder_ids is None:
                generated_ids = model.generate(input_features)
            else:
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
        results.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))

    # Clear GPU cache once decoding is finished.
    torch.cuda.empty_cache()
    return " ".join(results)
def process_audio(audio):
    """Gradio callback: detect the language of ``audio``, then transcribe and translate it.

    Args:
        audio: Value from the ``gr.Audio`` input (a file path or a
            ``(sample_rate, samples)`` tuple), or ``None`` when nothing was
            provided.

    Returns:
        ``(detected_language, transcription, translation)``; when ``audio``
        is ``None`` the tuple is ``("No file uploaded", "", "")``.
    """
    if audio is not None:
        waveform, sample_rate = load_and_resample_audio(audio)
        lang = detect_language(waveform)
        # Tuple items evaluate left to right: transcription runs before translation.
        return (
            lang,
            process_long_audio(waveform, sample_rate, task="transcribe"),
            process_long_audio(waveform, sample_rate, task="translate", language=lang),
        )
    return "No file uploaded", "", ""
# Gradio interface
# The custom CSS below is keyed on the "output-textbox" class, so each output
# box must carry that class explicitly via elem_classes. The rule targets the
# inner textarea/input elements, where the text is rendered.
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(),
    outputs=[
        gr.Textbox(label="Detected Language", elem_classes=["output-textbox"]),
        gr.Textbox(label="Transcription", lines=5, elem_classes=["output-textbox"]),
        gr.Textbox(label="Translation", lines=5, elem_classes=["output-textbox"]),
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe, and translate it.",
    allow_flagging="never",
    css=(
        ".output-textbox textarea, .output-textbox input "
        "{ font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
    ),
)

if __name__ == "__main__":
    iface.launch()