sachinsen1295 committed
Commit 54d6ffa · verified · 1 Parent(s): 7dd909c

Create app.py

Files changed (1):
  app.py  +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
import gradio as gr
import whisper
import torch
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import subprocess
import wave
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import os
import datetime

# Load models once at startup: Whisper for transcription, a speaker-embedding
# model for diarization, and a pyannote Audio helper for cropping segments.
model_size = "medium"
whisper_model = whisper.load_model(model_size)
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
audio_processor = Audio()


def process_audio(file_path, num_speakers):
    # gr.Number may hand over a float; clustering needs an integer count
    num_speakers = int(num_speakers)

    # Convert to WAV if necessary (splitext replaces only the extension,
    # even if the same string appears elsewhere in the path)
    if not file_path.endswith(".wav"):
        wav_path = os.path.splitext(file_path)[0] + ".wav"
        subprocess.call(["ffmpeg", "-y", "-i", file_path, wav_path])
        file_path = wav_path

    # Get audio duration
    with wave.open(file_path, "r") as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    # Transcribe audio; each Whisper segment carries start/end times and text
    result = whisper_model.transcribe(file_path)
    segments = result["segments"]

    # Generate one 192-dimensional speaker embedding per transcribed segment
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        start = segment["start"]
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, _ = audio_processor.crop(file_path, clip)
        embeddings[i] = embedding_model(waveform[None])
    embeddings = np.nan_to_num(embeddings)

    # Cluster the embeddings into the requested number of speakers
    # (capped at the number of segments so short clips don't crash the fit)
    n_clusters = min(num_speakers, len(segments))
    clustering = AgglomerativeClustering(n_clusters=n_clusters).fit(embeddings)
    labels = clustering.labels_
    for i, segment in enumerate(segments):
        segment["speaker"] = f"SPEAKER {labels[i] + 1}"

    # Generate transcript lines of the form "SPEAKER n (H:MM:SS): text"
    transcript = []
    for segment in segments:
        speaker = segment["speaker"]
        start_time = str(datetime.timedelta(seconds=round(segment["start"])))
        text = segment["text"]
        transcript.append(f"{speaker} ({start_time}): {text}")

    # Clean up the temporary audio file
    os.remove(file_path)
    return "\n".join(transcript)


# Gradio callback: with type="filepath" the Audio component passes the path
# of the uploaded file directly, so it can be handed straight to process_audio
def diarize(audio_path, num_speakers):
    return process_audio(audio_path, num_speakers)


# UI
interface = gr.Interface(
    fn=diarize,
    inputs=[
        gr.Audio(source="upload", type="filepath", label="Upload Audio File"),
        gr.Number(label="Number of Speakers", value=2, precision=0),
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript.",
)

if __name__ == "__main__":
    interface.launch()
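
For a quick local sanity check without the web UI, process_audio can be called directly. The snippet below is a minimal sketch and not part of this commit: it assumes the file above is saved as app.py, the imported dependencies (gradio, openai-whisper, torch, pyannote.audio, speechbrain, scikit-learn, numpy, plus ffmpeg on the PATH) are installed, and a short sample.wav with more than one speaker sits in the working directory. Since process_audio deletes the file it is handed, the sketch works on a throwaway copy.

# smoke_test.py -- hypothetical helper, not part of this commit
import shutil

from app import process_audio  # importing app.py also loads the Whisper and embedding models

# process_audio removes its input when done, so operate on a copy
shutil.copy("sample.wav", "sample_copy.wav")
print(process_audio("sample_copy.wav", num_speakers=2))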