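"""Speaker diarization & transcription Gradio app.

Transcribes an uploaded audio file with Whisper, embeds each transcript
segment with a SpeechBrain ECAPA speaker-embedding model, clusters the
embeddings into the requested number of speakers, and returns a
speaker-labelled transcript.
"""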
import gradio as gr
import whisper
import torch
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import subprocess
import wave
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import os
import datetime
# Load models
model_size = "medium"
whisper_model = whisper.load_model(model_size)
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
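# pyannote helper used to crop waveform excerpts for each transcript segment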
audio_processor = Audio()
def process_audio(file_path, num_speakers):
    # Convert to WAV if necessary
    if not file_path.endswith(".wav"):
        wav_path = os.path.splitext(file_path)[0] + ".wav"
        subprocess.call(['ffmpeg', '-y', '-i', file_path, wav_path])
        file_path = wav_path

    # Get audio duration
    with wave.open(file_path, 'r') as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)
    # Transcribe audio
    result = whisper_model.transcribe(file_path)
    segments = result["segments"]

    # Generate one speaker embedding per transcript segment
    # (the ECAPA-TDNN model produces 192-dimensional embeddings)
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        start = segment["start"]
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, _ = audio_processor.crop(file_path, clip)
        embeddings[i] = embedding_model(waveform[None])
    embeddings = np.nan_to_num(embeddings)  # guard against NaNs from degenerate segments
    # Cluster the embeddings into the requested number of speakers
    # (cast to int in case the UI delivers the count as a float)
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i, segment in enumerate(segments):
        segment["speaker"] = f"SPEAKER {labels[i] + 1}"
    # Generate transcript
    transcript = []
    for segment in segments:
        speaker = segment["speaker"]
        start_time = str(datetime.timedelta(seconds=round(segment["start"])))
        text = segment["text"].strip()
        transcript.append(f"{speaker} ({start_time}): {text}")

    # Clean up the temporary WAV file
    os.remove(file_path)
    return "\n".join(transcript)
# Gradio callback
def diarize(audio_file, num_speakers):
    # With type="filepath", Gradio passes the path of the uploaded file
    if audio_file is None:
        return "Please upload an audio file."
    return process_audio(audio_file, num_speakers)
# UI
interface = gr.Interface(
    fn=diarize,
    inputs=[
        # type="filepath" hands the callback a path string
        # (source="upload" is Gradio 3.x style; newer versions use sources=["upload"])
        gr.Audio(source="upload", type="filepath", label="Upload Audio File"),
        gr.Number(label="Number of Speakers", value=2, precision=0),
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript.",
)
if __name__ == "__main__":
    interface.launch()