import datetime
import os
import shutil
import subprocess
import wave

import gradio as gr
import numpy as np
import torch
import whisper
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering
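# Expected dependencies (assumed, not pinned anywhere in this script):
#   pip install gradio openai-whisper torch pyannote.audio speechbrain scikit-learn numpy
#   plus an ffmpeg binary on PATH for the format conversion below.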
model_size = "medium" |
|
whisper_model = whisper.load_model(model_size) |
|
embedding_model = PretrainedSpeakerEmbedding( |
|
"speechbrain/spkrec-ecapa-voxceleb", |
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
) |
|
audio_processor = Audio() |
|
|
|
def process_audio(file_path, num_speakers):
    # Convert anything that is not already WAV so wave/pyannote can read it.
    if not file_path.endswith(".wav"):
        wav_path = os.path.splitext(file_path)[0] + ".wav"
        subprocess.call(["ffmpeg", "-i", file_path, "-y", wav_path])
        file_path = wav_path

    # Audio duration, used to clamp segment end times to the file length.
    with wave.open(file_path, "r") as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    # Transcribe with Whisper; each segment carries start/end times and text.
    result = whisper_model.transcribe(file_path)
    segments = result["segments"]

    # Compute one 192-dimensional ECAPA speaker embedding per segment.
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        start = segment["start"]
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, _ = audio_processor.crop(file_path, clip)
        embeddings[i] = embedding_model(waveform[None])
    embeddings = np.nan_to_num(embeddings)

    # Cluster the embeddings into the requested number of speakers.
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i, segment in enumerate(segments):
        segment["speaker"] = f"SPEAKER {labels[i] + 1}"

    # Build the transcript: one line per segment with speaker label and timestamp.
    transcript = []
    for segment in segments:
        speaker = segment["speaker"]
        start_time = str(datetime.timedelta(seconds=round(segment["start"])))
        text = segment["text"].strip()
        transcript.append(f"{speaker} ({start_time}): {text}")

    # Remove the (possibly converted) working copy before returning.
    os.remove(file_path)
    return "\n".join(transcript)
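# Example transcript line produced above (the spoken text is illustrative only):
#   SPEAKER 1 (0:00:12): thanks everyone for joining today.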
def diarize(audio_file, num_speakers):
    # With type="filepath" below, Gradio passes the uploaded file as a path string.
    # Copy it to a local scratch file so process_audio can convert and delete it freely.
    ext = os.path.splitext(audio_file)[1] or ".wav"
    file_path = "temp_audio" + ext
    shutil.copy(audio_file, file_path)
    return process_audio(file_path, num_speakers)
interface = gr.Interface(
    fn=diarize,
    inputs=[
        # type="filepath" hands the upload to diarize() as a path; the deprecated
        # type="file" does not. On Gradio 4+, use sources=["upload"] instead of source=.
        gr.Audio(source="upload", type="filepath", label="Upload Audio File"),
        gr.Number(label="Number of Speakers", value=2, precision=0),
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript.",
)
if __name__ == "__main__":
    interface.launch()