sachinsen1295 committed on
Commit 7f4f456 · verified · 1 Parent(s): 52f033a

Update app.py

Files changed (1)
  1. app.py +48 -38
app.py CHANGED
@@ -1,88 +1,98 @@
  import gradio as gr
  import whisper
  import torch
  from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
  from pyannote.audio import Audio
  from pyannote.core import Segment
  import subprocess
  import wave
- import numpy as np
  from sklearn.cluster import AgglomerativeClustering
- import os
  import datetime

  # Load models
- model_size = "medium"
- whisper_model = whisper.load_model(model_size)
  embedding_model = PretrainedSpeakerEmbedding(
      "speechbrain/spkrec-ecapa-voxceleb",
-     device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
  )
  audio_processor = Audio()

- def process_audio(file_path, num_speakers):
-     # Convert to WAV if necessary
-     if not file_path.endswith(".wav"):
-         wav_path = file_path.replace(file_path.split('.')[-1], 'wav')
-         subprocess.call(['ffmpeg', '-i', file_path, wav_path, '-y'])
-         file_path = wav_path

      # Get audio duration
-     with wave.open(file_path, 'r') as f:
          frames = f.getnframes()
          rate = f.getframerate()
          duration = frames / float(rate)

-     # Transcribe audio
-     result = whisper_model.transcribe(file_path)
-     segments = result["segments"]
-
-     # Generate speaker embeddings
-     embeddings = np.zeros(shape=(len(segments), 192))
-     for i, segment in enumerate(segments):
          start = segment["start"]
          end = min(duration, segment["end"])
          clip = Segment(start, end)
-         waveform, _ = audio_processor.crop(file_path, clip)
-         embeddings[i] = embedding_model(waveform[None])
      embeddings = np.nan_to_num(embeddings)

      # Perform clustering
      clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
      labels = clustering.labels_
-     for i, segment in enumerate(segments):
-         segment["speaker"] = f"SPEAKER {labels[i] + 1}"

-     # Generate transcript
      transcript = []
-     for segment in segments:
-         speaker = segment["speaker"]
-         start_time = str(datetime.timedelta(seconds=round(segment["start"])))
-         text = segment["text"]
-         transcript.append(f"{speaker} ({start_time}): {text}")

-     # Clean up
-     os.remove(file_path)
      return "\n".join(transcript)

  # Gradio interface
- def diarize(audio_file, num_speakers):
-     file_path = "temp_audio.wav"
-     with open(file_path, "wb") as f:
-         f.write(audio_file.read())
-     return process_audio(file_path, num_speakers)

- # UI
  interface = gr.Interface(
      fn=diarize,
      inputs=[
-         gr.Audio(source="upload", type="file", label="Upload Audio File"),
          gr.Number(label="Number of Speakers", value=2, precision=0),
      ],
      outputs=gr.Textbox(label="Transcript"),
      title="Speaker Diarization & Transcription",
      description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
  )

  if __name__ == "__main__":
      interface.launch()
 
  import gradio as gr
  import whisper
  import torch
+ import pyannote.audio
  from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
  from pyannote.audio import Audio
  from pyannote.core import Segment
  import subprocess
  import wave
+ import contextlib
  from sklearn.cluster import AgglomerativeClustering
+ import numpy as np
  import datetime

  # Load models
  embedding_model = PretrainedSpeakerEmbedding(
      "speechbrain/spkrec-ecapa-voxceleb",
+     device=torch.device("cpu")  # Use "cuda" if a GPU is available
  )
  audio_processor = Audio()

+ # Function to process the audio file and extract transcript and diarization
+ def process_audio(audio_file, num_speakers, model_size="medium", language="English"):
+     # Save the uploaded file to a path
+     path = "/tmp/uploaded_audio.wav"
+     with open(path, "wb") as f:
+         f.write(audio_file.read())
+
+     # Convert audio to WAV if it's not already
+     if path[-3:] != 'wav':
+         wav_path = path.replace(path.split('.')[-1], 'wav')
+         subprocess.call(['ffmpeg', '-i', path, wav_path, '-y'])
+         path = wav_path
+
+     # Load Whisper model
+     model = whisper.load_model(model_size)
+     result = model.transcribe(path)
+     segments = result["segments"]

      # Get audio duration
+     with contextlib.closing(wave.open(path, 'r')) as f:
          frames = f.getnframes()
          rate = f.getframerate()
          duration = frames / float(rate)

+     # Function to generate segment embeddings
+     def segment_embedding(segment):
          start = segment["start"]
          end = min(duration, segment["end"])
          clip = Segment(start, end)
+         waveform, sample_rate = audio_processor.crop(path, clip)
+         return embedding_model(waveform[None])
+
+     embeddings = np.zeros(shape=(len(segments), 192))
+     for i, segment in enumerate(segments):
+         embeddings[i] = segment_embedding(segment)
+
      embeddings = np.nan_to_num(embeddings)

      # Perform clustering
      clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
      labels = clustering.labels_
+     for i in range(len(segments)):
+         segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
+
+     # Format the transcript
+     def time(secs):
+         return str(datetime.timedelta(seconds=round(secs)))

      transcript = []
+     for i, segment in enumerate(segments):
+         if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
+             transcript.append(f"\n{segment['speaker']} {time(segment['start'])}")
+         transcript.append(segment["text"][1:])  # Remove leading whitespace

+     # Return the final transcript as a string
      return "\n".join(transcript)

  # Gradio interface
+ def diarize(audio_file, num_speakers, model_size="medium"):
+     return process_audio(audio_file, num_speakers, model_size)

+ # Gradio UI
  interface = gr.Interface(
      fn=diarize,
      inputs=[
+         gr.Audio(type="file", label="Upload Audio File"),  # Removed 'source' argument
          gr.Number(label="Number of Speakers", value=2, precision=0),
+         gr.Radio(["tiny", "base", "small", "medium", "large"], label="Model Size", value="medium")
      ],
      outputs=gr.Textbox(label="Transcript"),
      title="Speaker Diarization & Transcription",
      description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
  )

+ # Run the Gradio app
  if __name__ == "__main__":
      interface.launch()
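
For a quick local check outside the Gradio UI, a minimal sketch of how the updated process_audio could be exercised; the file name sample.wav, the speaker count of 2, and the "tiny" model size are assumptions for illustration, not part of the commit.

# Hypothetical local check: assumes app.py above is importable, sample.wav exists,
# and the dependencies imported at the top of app.py are installed.
from app import process_audio

if __name__ == "__main__":
    with open("sample.wav", "rb") as audio_file:
        # process_audio expects a file-like object with .read(), as Gradio supplies
        transcript = process_audio(audio_file, num_speakers=2, model_size="tiny")
        print(transcript)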