raphaelbiojout committed · Commit f331362 · Parent(s): 30b7244

Update handler

Browse files: handler.py +23 -1
handler.py CHANGED
@@ -23,7 +23,7 @@ SAMPLE_RATE = 16000
 def whisper_config():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     whisper_model = "large-v2"
-    batch_size = 16
+    batch_size = 8  # reduce if low on GPU mem, 16 initially
     # change to "int8" if low on GPU mem (may reduce accuracy)
     compute_type = "float16" if device == "cuda" else "int8"
     return device, batch_size, compute_type, whisper_model
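These three values are consumed when the WhisperX model is loaded and when transcribe() is called further down. A minimal sketch of that wiring, assuming the handler builds its model with whisperx.load_model (the loading code in __init__ is not part of this diff):

import whisperx

device, batch_size, compute_type, whisper_model = whisper_config()
# compute_type trades accuracy for GPU memory ("float16" vs "int8");
# batch_size is applied at transcribe time, not at load time
model = whisperx.load_model(whisper_model, device, compute_type=compute_type)
audio = whisperx.load_audio("sample.wav")  # illustrative input file
result = model.transcribe(audio, batch_size=batch_size)

Halving batch_size from 16 to 8 roughly halves the peak activation memory per decoding pass, at some throughput cost.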
@@ -158,6 +158,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
 
 
+def display_gpu_infos():
+    if not torch.cuda.is_available():
+        return "NO CUDA"
+
+    infos = "torch.cuda.current_device(): " + str(torch.cuda.current_device()) + "\n"
+    infos = infos + "torch.cuda.device(0): " + str(torch.cuda.device(0)) + "\n"
+    infos = infos + "torch.cuda.device_count(): " + str(torch.cuda.device_count()) + "\n"
+    infos = infos + "torch.cuda.get_device_name(0): " + torch.cuda.get_device_name(0) + "\n"
+    return infos
 
 class EndpointHandler():
     def __init__(self, path=""):
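A caveat on this helper: every call in it reports device identity (index, count, name), not utilization, so the before/after logging around empty_cache() later in this diff will print identical text. A variant that would surface the effect, sketched with torch's standard memory counters (the helper name is illustrative, not from the commit):

def display_gpu_memory():
    # memory_allocated: bytes held by live tensors;
    # memory_reserved: bytes kept by the caching allocator
    if not torch.cuda.is_available():
        return "NO CUDA"
    allocated = torch.cuda.memory_allocated(0) / 1024**2
    reserved = torch.cuda.memory_reserved(0) / 1024**2
    return f"allocated: {allocated:.0f} MiB, reserved: {reserved:.0f} MiB"

Note also that torch.cuda.device(0) builds a context-manager object, so str() of it only yields the object repr; torch.cuda.get_device_name(0) is the call that returns a readable name.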
@@ -187,6 +196,9 @@ class EndpointHandler():
             logger.info(f"key: {x}, value: {data[x]} ")
             print(f"key: {x}, value: {data[x]} ")
 
+        logger.info("--------------- CUDA ------------------------")
+        logger.info(display_gpu_infos())
+
         # 1. process input
         inputs_encoded = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
@@ -212,11 +224,13 @@
 
         # 2. transcribe
         device, batch_size, compute_type, whisper_model = whisper_config()
+        logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
         transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
 
         logger.info(transcription["segments"])
 
         # 3. align
+        logger.info("--------------- STARTING ALIGNMENT ------------------------")
         # model_a, metadata = whisperx.load_align_model(
         #     language_code=result["language"], device=device)
         # transcription = whisperx.align(
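The alignment step stays commented out here. If it is re-enabled, note that the comments read result["language"] while the live code stores the output in transcription, so the variable would need renaming; the usual WhisperX alignment flow looks roughly like this (a sketch, not the handler's actual code):

# sketch of WhisperX word-level alignment, reusing the transcription
# dict returned by model.transcribe() above
model_a, metadata = whisperx.load_align_model(
    language_code=transcription["language"], device=device)
transcription = whisperx.align(
    transcription["segments"], model_a, metadata,
    audio_nparray, device, return_char_alignments=False)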
@@ -225,6 +239,7 @@
         # print(transcription["segments"])
 
         # 4. Assign speaker labels
+        logger.info("--------------- STARTING DIARIZATION ------------------------")
         # add min/max number of speakers if known
         diarize_segments = self.diarize_model(audio_nparray)
         logger.info(diarize_segments)
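self.diarize_model is presumably a whisperx.DiarizationPipeline created in __init__, which this diff does not show. The typical wiring, including the optional speaker bounds the comment above mentions (HF_TOKEN is an assumed placeholder, not from the commit):

# sketch: typical WhisperX diarization setup; HF_TOKEN is assumed
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
diarize_segments = diarize_model(audio_nparray, min_speakers=1, max_speakers=5)
diarized_transcription = whisperx.assign_word_speakers(diarize_segments, transcription)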
@@ -234,6 +249,13 @@
         logger.info(diarized_transcription["segments"])  # segments are now assigned speaker IDs
         results.append({"diarized_transcription": diarized_transcription["segments"]})
 
+        if torch.cuda.is_available():
+            logger.info("--------------- GPU ------------------------")
+            logger.info(display_gpu_infos())
+            torch.cuda.empty_cache()
+            logger.info("--------------- GPU AFTER empty_cache ------------------------")
+            logger.info(display_gpu_infos())
+
         # results_json = json.dumps(results)
         # return {"results": results_json}
         return results
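One caveat on the new cleanup block: torch.cuda.empty_cache() only hands unused blocks from the caching allocator back to the driver; memory held by tensors that are still referenced stays allocated. Releasing the intermediates themselves would mean dropping the local references first, for example (a sketch, not part of the commit):

import gc

del diarize_segments      # drop local references to large intermediates
gc.collect()              # let Python reclaim the objects
torch.cuda.empty_cache()  # then return cached blocks to the driver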