import gradio as gr
import whisper
import torch
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import subprocess
import os
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import datetime

# Load models
device = torch.device("cpu")  # Explicitly set device to CPU
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", 
    device=device  # Ensure it uses CPU
)
audio_processor = Audio()
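
# The ECAPA-TDNN model above emits 192-dimensional speaker embeddings,
# which is why the embeddings array below is allocated as (num_segments, 192).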

# Process an audio file: transcribe with Whisper, then assign speakers by
# clustering per-segment speaker embeddings.
def process_audio(audio_file, num_speakers, model_size="medium", language="English"):
    # Gradio's type="filepath" passes a path string; also accept a file-like
    # object for direct (non-Gradio) use.
    if isinstance(audio_file, str):
        path = audio_file
    else:
        path = "/tmp/uploaded_audio.wav"
        with open(path, "wb") as f:
            f.write(audio_file.read())

    print(f"Audio file saved to: {path}")

    # Convert to WAV so wave.open() and pyannote can read the file.
    # ffmpeg flags must precede the output filename; -y overwrites and
    # -ac 1 downmixes to mono.
    if not path.lower().endswith(".wav"):
        wav_path = os.path.splitext(path)[0] + ".wav"
        subprocess.call(["ffmpeg", "-y", "-i", path, "-ac", "1", wav_path])
        path = wav_path
        print(f"Audio converted to: {path}")

    # Load Whisper model
    try:
        model = whisper.load_model(model_size)
        print("Whisper model loaded successfully.")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        return f"Error loading Whisper model: {e}"

    try:
        # Whisper accepts a language name such as "English"; pass the user's
        # choice through instead of silently ignoring the parameter.
        result = model.transcribe(path, language=language)
        print(f"Transcription finished: {len(result['segments'])} segments")
    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Error during transcription: {e}"

    segments = result["segments"]

    # Get audio duration
    with contextlib.closing(wave.open(path, 'r')) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    # Compute a speaker embedding for a single transcript segment.
    def segment_embedding(segment):
        start = segment["start"]
        # Whisper can overshoot the end of the file; clamp to the real duration.
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, sample_rate = audio_processor.crop(path, clip)
        # Downmix to mono in case the input was already a multichannel WAV
        # (the ffmpeg step above only runs for non-WAV inputs); the embedding
        # model expects a (batch, channel, sample) tensor.
        waveform = waveform.mean(dim=0, keepdim=True)
        return embedding_model(waveform[None])

    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(segment)

    # Replace any NaNs (e.g. from zero-length clips) so clustering doesn't fail.
    embeddings = np.nan_to_num(embeddings)

    # Cluster the embeddings into the requested number of speakers.
    # int() guards against Gradio delivering the count as a float.
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = f"SPEAKER {labels[i] + 1}"

    # Format the transcript, adding a header whenever the speaker changes.
    def format_time(secs):
        return str(datetime.timedelta(seconds=round(secs)))

    transcript = []
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            transcript.append(f"\n{segment['speaker']} {format_time(segment['start'])}")
        # Whisper segment text usually starts with a space; strip it safely
        # instead of blindly dropping the first character.
        transcript.append(segment["text"].lstrip())

    # Return the final transcript as a string
    return "\n".join(transcript)

# Gradio interface
def diarize(audio_file, num_speakers, model_size="medium"):
    return process_audio(audio_file, num_speakers, model_size)

# Gradio UI
interface = gr.Interface(
    fn=diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),  # Use 'filepath' here
        gr.Number(label="Number of Speakers", value=2, precision=0),
        gr.Radio(["tiny", "base", "small", "medium", "large"], label="Model Size", value="medium")
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
)

# Run the Gradio app
if __name__ == "__main__":
    interface.launch()
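    # Gradio can also expose a temporary public URL for quick sharing:
    # interface.launch(share=True)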