raphaelbiojout committed on
Commit
f331362
·
1 Parent(s): 30b7244

Update handler

Browse files
Files changed (1) hide show
  1. handler.py +23 -1
handler.py CHANGED
@@ -23,7 +23,7 @@ SAMPLE_RATE = 16000
23
  def whisper_config():
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  whisper_model = "large-v2"
26
- batch_size = 16 # reduce if low on GPU mem
27
  # change to "int8" if low on GPU mem (may reduce accuracy)
28
  compute_type = "float16" if device == "cuda" else "int8"
29
  return device, batch_size, compute_type, whisper_model
@@ -158,6 +158,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
158
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
159
 
160
 
 
 
 
 
 
 
 
 
 
161
 
162
  class EndpointHandler():
163
  def __init__(self, path=""):
@@ -187,6 +196,9 @@ class EndpointHandler():
187
  logger.info(f"key: {x}, value: {data[x]} ")
188
  print(f"key: {x}, value: {data[x]} ")
189
 
 
 
 
190
  # 1. process input
191
  inputs_encoded = data.pop("inputs", data)
192
  parameters = data.pop("parameters", None)
@@ -212,11 +224,13 @@ class EndpointHandler():
212
 
213
  # 2. transcribe
214
  device, batch_size, compute_type, whisper_model = whisper_config()
 
215
  transcription = self.model.transcribe(audio_nparray, batch_size=batch_size,language=language)
216
 
217
  logger.info(transcription["segments"])
218
 
219
  # 3. align
 
220
  # model_a, metadata = whisperx.load_align_model(
221
  # language_code=result["language"], device=device)
222
  # transcription = whisperx.align(
@@ -225,6 +239,7 @@ class EndpointHandler():
225
  # print(transcription["segments"])
226
 
227
  # 4. Assign speaker labels
 
228
  # add min/max number of speakers if known
229
  diarize_segments = self.diarize_model(audio_nparray)
230
  logger.info(diarize_segments)
@@ -234,6 +249,13 @@ class EndpointHandler():
234
  logger.info(diarized_transcription["segments"]) # segments are now assigned speaker IDs
235
  results.append({"diarized_transcription": diarized_transcription["segments"]})
236
 
 
 
 
 
 
 
 
237
  # results_json = json.dumps(results)
238
  # return {"results": results_json}
239
  return results
 
23
def whisper_config(batch_size: int = 8):
    """Return the WhisperX runtime configuration.

    Args:
        batch_size: Transcription batch size. Defaults to 8 (reduced from
            the initial 16 to fit GPU memory); pass a larger value if
            memory allows, or a smaller one if still low on GPU memory.

    Returns:
        Tuple of (device, batch_size, compute_type, whisper_model) where
        device is "cuda" when a GPU is visible to torch, else "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    whisper_model = "large-v2"
    # float16 on GPU; change to "int8" if low on GPU mem (may reduce accuracy)
    compute_type = "float16" if device == "cuda" else "int8"
    return device, batch_size, compute_type, whisper_model
 
158
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
159
 
160
 
161
def display_gpu_infos():
    """Return a human-readable summary of the visible CUDA device setup.

    Returns:
        The string "NO CUDA" when no GPU is visible to torch; otherwise a
        newline-separated listing of current device, device handle, device
        count, and device name.
    """
    if not torch.cuda.is_available():
        return "NO CUDA"
    # str() conversions are required: current_device() and device_count()
    # return ints and device(0) returns a torch.cuda.device object — the
    # previous code raised TypeError by concatenating them to str with "+".
    infos = "torch.cuda.current_device(): " + str(torch.cuda.current_device()) + "\n"
    infos += "torch.cuda.device(0): " + str(torch.cuda.device(0)) + "\n"
    infos += "torch.cuda.device_count(): " + str(torch.cuda.device_count()) + "\n"
    infos += "torch.cuda.get_device_name(0): " + str(torch.cuda.get_device_name(0)) + "\n"
    return infos
170
 
171
  class EndpointHandler():
172
  def __init__(self, path=""):
 
196
  logger.info(f"key: {x}, value: {data[x]} ")
197
  print(f"key: {x}, value: {data[x]} ")
198
 
199
+ logger.info("--------------- CUDA ------------------------")
200
+ logger.info(display_gpu_infos())
201
+
202
  # 1. process input
203
  inputs_encoded = data.pop("inputs", data)
204
  parameters = data.pop("parameters", None)
 
224
 
225
  # 2. transcribe
226
  device, batch_size, compute_type, whisper_model = whisper_config()
227
+ logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
228
  transcription = self.model.transcribe(audio_nparray, batch_size=batch_size,language=language)
229
 
230
  logger.info(transcription["segments"])
231
 
232
  # 3. align
233
+ logger.info("--------------- STARTING ALIGNMENT ------------------------")
234
  # model_a, metadata = whisperx.load_align_model(
235
  # language_code=result["language"], device=device)
236
  # transcription = whisperx.align(
 
239
  # print(transcription["segments"])
240
 
241
  # 4. Assign speaker labels
242
+ logger.info("--------------- STARTING DIARIZATION ------------------------")
243
  # add min/max number of speakers if known
244
  diarize_segments = self.diarize_model(audio_nparray)
245
  logger.info(diarize_segments)
 
249
  logger.info(diarized_transcription["segments"]) # segments are now assigned speaker IDs
250
  results.append({"diarized_transcription": diarized_transcription["segments"]})
251
 
252
+ if torch.cuda.is_available():
253
+ logger.info("--------------- GPU ------------------------")
254
+ logger.info(display_gpu_infos())
255
+ torch.cuda.empty_cache()
256
+ logger.info("--------------- GPU AFTER empty_cache ------------------------")
257
+ logger.info(display_gpu_infos())
258
+
259
  # results_json = json.dumps(results)
260
  # return {"results": results_json}
261
  return results