Ngoufack committed
Commit 5fd6be5 · 1 Parent(s): 06a02ce
Files changed (1)
  1. app.py +62 -43
app.py CHANGED
@@ -5,76 +5,87 @@ import yt_dlp as youtube_dl
 import tempfile
 import os
 import locale
+import whisper
 import datetime
 import subprocess
-import wave
-import contextlib
-import numpy as np
-from sklearn.cluster import AgglomerativeClustering
-from faster_whisper import WhisperModel
+import pyannote.audio
 from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
 from pyannote.audio import Audio
 from pyannote.core import Segment
+import wave
+import contextlib
+from sklearn.cluster import AgglomerativeClustering
+import numpy as np
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-BATCH_SIZE = 8
+BATCH_SIZE = 16
 FILE_LIMIT_MB = 1000
-YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
+COMPUTE_TYPE = "float32"
+YT_LENGTH_LIMIT_S = 600 # limit to 1 hour YouTube files
 
 num_speakers = 2
 language = None
-model_size = 'tiny'
-model = WhisperModel(model_size, device=device, compute_type="float32")
-embedding_model = PretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb", device=torch.device("cpu"))
-audio = Audio()
+model_size = 'large'
+model_name = model_size
 
-def getpreferredencoding(do_setlocale=True):
+def getpreferredencoding(do_setlocale = True):
     return "UTF-8"
+
 locale.getpreferredencoding = getpreferredencoding
+embedding_model = PretrainedSpeakerEmbedding(
+    "speechbrain/spkrec-ecapa-voxceleb",
+    device=torch.device("cpu"))
+model = whisper.load_model(model_size).to(device)
+audio = Audio()
 
-def segment_embedding(segment, duration, path):
-    start = segment.start
-    end = min(duration, segment.end)
+def segment_embedding(segment,duration,path):
+    start = segment["start"]
+    # Whisper overshoots the end timestamp in the last segment
+    end = min(duration, segment["end"])
     clip = Segment(start, end)
     waveform, sample_rate = audio.crop(path, clip)
+
+    # Convert waveform to single channel
     waveform = waveform.mean(dim=0, keepdim=True)
+
     return embedding_model(waveform.unsqueeze(0))
 
 def time(secs):
-    return datetime.timedelta(seconds=round(secs))
+    return datetime.timedelta(seconds=round(secs))
 
+@spaces.GPU
 def transcribe(path, task):
     if path is None:
         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
 
-    if not path.endswith('.wav'):
+    if path[-3:] != 'wav':
         subprocess.call(['ffmpeg', '-i', path, "audio.wav", '-y'])
         path = "audio.wav"
-
-    segments, _ = model.transcribe(path)
-
-    with contextlib.closing(wave.open(path, 'r')) as f:
+    result = model.transcribe(path,fp16=False)
+    segments = result["segments"]
+    print(segments)
+    with contextlib.closing(wave.open(path,'r')) as f:
         frames = f.getnframes()
         rate = f.getframerate()
         duration = frames / float(rate)
 
     embeddings = np.zeros(shape=(len(segments), 192))
     for i, segment in enumerate(segments):
-        embeddings[i] = segment_embedding(segment, duration=duration, path=path)
+        embeddings[i] = segment_embedding(segment,duration=duration,path=path)
     embeddings = np.nan_to_num(embeddings)
     clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
     labels = clustering.labels_
-
-    output_text = ""
-    for i, segment in enumerate(segments):
-        segment.speaker = 'SPEAKER ' + str(labels[i] + 1)
-    for i, segment in enumerate(segments):
-        if i == 0 or segments[i - 1].speaker != segment.speaker:
-            output_text += "\n" + segment.speaker + ' ' + str(time(segment.start)) + '\n'
-        output_text += segment.text + ' '
-
+    output_text=""
+    for i in range(len(segments)):
+        segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
+    for (i, segment) in enumerate(segments):
+        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
+            output_text += "\n" + segment["speaker"] + ' ' + str(time(segment["start"])) + '\n'
+        output_text += segment["text"][1:] + ' '
     return output_text
 
+
+
 def _return_yt_html_embed(yt_url):
     video_id = yt_url.split("?v=")[-1]
     return f'<center><iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"></iframe></center>'
@@ -89,9 +100,11 @@ def download_yt_audio(yt_url, filename):
             "preferredquality": "192",
         }],
     }
+
     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
         ydl.download([yt_url])
 
+@spaces.GPU
 def yt_transcribe(yt_url, task):
     html_embed_str = _return_yt_html_embed(yt_url)
 
@@ -99,39 +112,45 @@ def yt_transcribe(yt_url, task):
         filepath = os.path.join(tmpdirname, "audio.wav")
         download_yt_audio(yt_url, filepath)
 
-        segments, _ = model.transcribe(filepath, batch_size=BATCH_SIZE)
+        result = model.transcribe(audio, batch_size=BATCH_SIZE)
 
-    return html_embed_str, " ".join(segment.text for segment in segments)
+    return html_embed_str, result["text"]
 
-demo = gr.Blocks(theme=gr.themes.Soft())
+demo = gr.Blocks(theme=gr.themes.Ocean())
 
 mf_transcribe = gr.Interface(
     fn=transcribe,
-    inputs=[gr.Audio(sources="microphone", type="filepath"),
-            gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
+    inputs=[
+        gr.Audio(sources="microphone", type="filepath"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
     outputs="text",
     title="VerbaLens Demo 1 : Prototype",
-    description="Transcribe long-form microphone or audio inputs using Faster-Whisper.",
+    description="Transcribe long-form microphone or audio inputs using WhisperX.",
     allow_flagging="never",
 )
 
 file_transcribe = gr.Interface(
     fn=transcribe,
-    inputs=[gr.Audio(sources="upload", type="filepath", label="Audio file"),
-            gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
+    inputs=[
+        gr.Audio(sources="upload", type="filepath", label="Audio file"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
     outputs="text",
     title="VerbaLens Demo 1 : Prototype",
-    description="Transcribe uploaded audio files using Faster-Whisper.",
+    description="Transcribe uploaded audio files using WhisperX.",
     allow_flagging="never",
 )
 
 yt_transcribe = gr.Interface(
     fn=yt_transcribe,
-    inputs=[gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
-            gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")],
+    inputs=[
+        gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
     outputs=["html", "text"],
     title="VerbaLens Demo 1 : Prototyping",
-    description="Transcribe YouTube videos using Faster-Whisper.",
+    description="Transcribe YouTube videos using WhisperX.",
     allow_flagging="never",
 )
 
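
Taken together, the commit swaps the faster-whisper transcription path for openai-whisper and keeps a simple speaker-diarization step: one ECAPA speaker embedding per Whisper segment via pyannote.audio, clustered with scikit-learn's AgglomerativeClustering, then "SPEAKER n" headers emitted at speaker changes. The following is a minimal, self-contained sketch of that pipeline, based only on the code visible in the diff and assuming the packages it imports (openai-whisper, pyannote.audio, scikit-learn). The "tiny" model size and "example.wav" path are illustrative placeholders, not values from the commit, and the input is assumed to already be a WAV file (the app converts other formats with ffmpeg first).

import contextlib
import datetime
import wave

import numpy as np
import torch
import whisper
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering

# Placeholder model size; the commit itself loads 'large'.
model = whisper.load_model("tiny")
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", device=torch.device("cpu"))
audio = Audio()

def diarized_transcript(path, num_speakers=2):
    # 1. Transcribe with openai-whisper and keep the per-segment timestamps.
    segments = model.transcribe(path, fp16=False)["segments"]

    # 2. Read the clip duration from the WAV header; it is used to clamp the
    #    last segment, since Whisper can overshoot its end timestamp.
    with contextlib.closing(wave.open(path, "r")) as f:
        duration = f.getnframes() / float(f.getframerate())

    # 3. Compute one 192-dim ECAPA speaker embedding per segment.
    embeddings = np.zeros((len(segments), 192))
    for i, seg in enumerate(segments):
        clip = Segment(seg["start"], min(duration, seg["end"]))
        waveform, _ = audio.crop(path, clip)
        waveform = waveform.mean(dim=0, keepdim=True)  # down-mix to mono
        embeddings[i] = embedding_model(waveform.unsqueeze(0))

    # 4. Cluster the embeddings into the requested number of speakers.
    labels = AgglomerativeClustering(num_speakers).fit(np.nan_to_num(embeddings)).labels_

    # 5. Emit a "SPEAKER n <timestamp>" header whenever the speaker changes.
    out = ""
    for i, seg in enumerate(segments):
        if i == 0 or labels[i] != labels[i - 1]:
            out += "\nSPEAKER %d %s\n" % (labels[i] + 1, datetime.timedelta(seconds=round(seg["start"])))
        out += seg["text"].strip() + " "
    return out

print(diarized_transcript("example.wav"))  # placeholder file name

The min(duration, segment["end"]) clamp mirrors the comment added in the diff: Whisper can overshoot the end timestamp of the final segment, and pyannote's Audio.crop would otherwise be asked for samples past the end of the file.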