import gradio as gr
import whisper
import torch
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import subprocess
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import datetime
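
# Suggested dependencies (package names assumed, versions not pinned here):
#   pip install gradio openai-whisper torch pyannote.audio speechbrain scikit-learn numpy
# ffmpeg must also be on the PATH for the format conversion below.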

# Load the speaker-embedding model (ECAPA-TDNN trained on VoxCeleb)
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available, else CPU
)
audio_processor = Audio()
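# Audio() is only used to crop waveform snippets for the embedding model;
# Whisper loads and resamples the input file on its own.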

# Process the audio file and return a diarized transcript
def process_audio(audio_file, num_speakers, model_size="medium", language="English"):
    # gr.Audio(type="filepath") hands the upload over as a path string, not a file object
    path = audio_file

    # Convert to 16 kHz mono WAV if needed; the wave module below expects WAV input
    if not path.lower().endswith('.wav'):
        wav_path = path.rsplit('.', 1)[0] + '.wav'
        subprocess.call(['ffmpeg', '-y', '-i', path, '-ar', '16000', '-ac', '1', wav_path])
        path = wav_path

    # Load the chosen Whisper model and transcribe; the language hint (e.g. "English")
    # skips Whisper's automatic language detection
    model = whisper.load_model(model_size)
    result = model.transcribe(path, language=language)
    segments = result["segments"]
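    # Each segment is a dict carrying (among other fields) "start", "end", and "text",
    # which is all the diarization step below relies on.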

    # Get audio duration
    with contextlib.closing(wave.open(path, 'r')) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    # Compute a speaker embedding for one Whisper segment
    def segment_embedding(segment):
        start = segment["start"]
        # Whisper's final segment can overshoot the real end of the file, so clamp it
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, sample_rate = audio_processor.crop(path, clip)
        return embedding_model(waveform[None])  # add a batch dimension for the model

    # One 192-dimensional ECAPA embedding per segment; this matrix is what gets clustered
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(segment)

    embeddings = np.nan_to_num(embeddings)

    # Group the segment embeddings into the requested number of speakers
    # (cast to int because the Gradio number input can deliver a float)
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

    # Format the transcript, inserting a speaker/timestamp header whenever the speaker changes
    def format_time(secs):
        return str(datetime.timedelta(seconds=round(secs)))

    transcript = []
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            transcript.append(f"\n{segment['speaker']} {format_time(segment['start'])}")
        transcript.append(segment["text"].strip())  # drop the whitespace Whisper adds around segments

    # Return the final transcript as a single string
    return "\n".join(transcript)

# Wrapper exposed to Gradio (the transcription language stays at its default)
def diarize(audio_file, num_speakers, model_size="medium"):
    return process_audio(audio_file, num_speakers, model_size)

# Gradio UI
interface = gr.Interface(
    fn=diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),  # "filepath" passes the upload as a path string
        gr.Number(label="Number of Speakers", value=2, precision=0),
        gr.Radio(["tiny", "base", "small", "medium", "large"], label="Model Size", value="medium")
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
)

# Run the Gradio app
if __name__ == "__main__":
    interface.launch()
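
# Usage sketch (assumed filenames/options, adjust as needed):
#   interface.launch(share=True)                                  # also creates a temporary public URL
#   print(process_audio("meeting.wav", 2, model_size="tiny"))     # run the pipeline without the UI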