# whisperX-endpoint / handler.py
import subprocess
import torch
# if torch.cuda.is_available():
# process = subprocess.Popen(['pip', 'uninstall', 'onnxruntime'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# stdout, stderr = process.communicate()
# process = subprocess.Popen(['pip', 'install', '--force-reinstall', 'onnxruntime-gpu'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# stdout, stderr = process.communicate()
import whisperx
import os
import tempfile
import time
import json
import base64
import numpy as np
DEVNULL = subprocess.DEVNULL  # constant from subprocess; avoids leaking an open handle to os.devnull
# from transformers.pipelines.audio_utils import ffmpeg_read
from typing import Dict, List, Any
import logging
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
def whisper_config():
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = "large-v2"
    batch_size = 16  # reduce if low on GPU memory; 16 initially
# change to "int8" if low on GPU mem (may reduce accuracy)
compute_type = "float16" if device == "cuda" else "int8"
return device, batch_size, compute_type, whisper_model
# From https://gist.github.com/kylemcdonald/85d70bf53e207bab3775
# Alternative loader kept for reference: load_audio cannot detect the input type.
def ffmpeg_load_audio(filename, sr=44100, mono=False, normalize=True, in_type=np.int16, out_type=np.float32):
channels = 1 if mono else 2
format_strings = {
np.float64: 'f64le',
np.float32: 'f32le',
np.int16: 's16le',
np.int32: 's32le',
np.uint32: 'u32le'
}
format_string = format_strings[in_type]
command = [
'ffmpeg',
'-i', filename,
'-f', format_string,
'-acodec', 'pcm_' + format_string,
'-ar', str(sr),
'-ac', str(channels),
'-']
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=DEVNULL, bufsize=4096)
bytes_per_sample = np.dtype(in_type).itemsize
frame_size = bytes_per_sample * channels
chunk_size = frame_size * sr # read in 1-second chunks
raw = b''
with p.stdout as stdout:
while True:
data = stdout.read(chunk_size)
if data:
raw += data
else:
break
    audio = np.frombuffer(raw, dtype=in_type).astype(out_type)  # np.fromstring is removed in modern NumPy
if channels > 1:
audio = audio.reshape((-1, channels)).transpose()
    if audio.size == 0:
        return audio  # keep the return type consistent with the final return below
if issubclass(out_type, np.floating):
if normalize:
peak = np.abs(audio).max()
if peak > 0:
audio /= peak
elif issubclass(in_type, np.integer):
audio /= np.iinfo(in_type).max
return audio
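# Illustrative use of the loader above (a sketch; "clip.wav" is a hypothetical file):
# request mono float32 samples at the endpoint's sample rate.
#   audio = ffmpeg_load_audio("clip.wav", sr=SAMPLE_RATE, mono=True, out_type=np.float32)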
# FROM HuggingFace
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.ndarray:
"""
Helper function to read an audio file through ffmpeg.
"""
ar = f"{sampling_rate}"
ac = "1"
format_for_conversion = "f32le"
ffmpeg_command = [
"ffmpeg",
"-i",
"pipe:0",
"-ac",
ac,
"-ar",
ar,
"-f",
format_for_conversion,
"-hide_banner",
"-loglevel",
"quiet",
"pipe:1",
]
try:
with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
output_stream = ffmpeg_process.communicate(bpayload)
except FileNotFoundError as error:
raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
out_bytes = output_stream[0]
audio = np.frombuffer(out_bytes, np.float32)
if audio.shape[0] == 0:
raise ValueError(
"Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
"a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
"URL, ensure that the URL is the full address to **download** the audio file."
)
return audio
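# Illustrative use (a sketch; "clip.mp3" is a hypothetical file): decode raw container
# bytes straight from memory, without writing a temporary file first.
#   with open("clip.mp3", "rb") as f:
#       audio = ffmpeg_read(f.read(), SAMPLE_RATE)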
# FROM whisperX
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# Launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI to be installed.
cmd = [
"ffmpeg",
"-nostdin",
"-threads",
"0",
"-i",
file,
"-f",
"s16le",
"-ac",
"1",
"-acodec",
"pcm_s16le",
"-ar",
str(sr),
"-",
]
out = subprocess.run(cmd, capture_output=True, check=True).stdout
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
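# Illustrative use (a sketch; "clip.wav" is a hypothetical file): this is the loader
# the endpoint below actually calls on the temp file written from the request payload.
#   audio = load_audio("clip.wav", sr=SAMPLE_RATE)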
def display_gpu_infos():
    if not torch.cuda.is_available():
        return "NO CUDA"
    return (
        f"torch.cuda.current_device(): {torch.cuda.current_device()}, "
        f"torch.cuda.device(0): {torch.cuda.device(0)}, "
        f"torch.cuda.device_count(): {torch.cuda.device_count()}, "
        f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}"
    )
class EndpointHandler:
def __init__(self, path=""):
# load the model
device, batch_size, compute_type, whisper_model = whisper_config()
self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type)
logger.info(f"Model {whisper_model} initialized")
        self.diarize_model = whisperx.DiarizationPipeline(
            "pyannote/speaker-diarization-3.1",
            # Read the Hugging Face token from the environment instead of hardcoding a secret.
            use_auth_token=os.environ.get("HF_TOKEN"),
            device=device)
        logger.info("Model for diarization initialized")
    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                includes the base64-encoded audio file as bytes under "inputs",
                plus optional "parameters" (e.g. language) and "options"
                (info, alignment, diarization)
        Return:
            A :obj:`dict` with the transcription segments
        """
# get the start time
st = time.time()
logger.info("--------------- CONFIGURATION ------------------------")
device, batch_size, compute_type, whisper_model = whisper_config()
logger.info(f"device: {device}, batch_size: {batch_size}, compute_type:{compute_type}, whisper_model: {whisper_model}")
logger.info(display_gpu_infos())
# 1. process input
inputs_encoded = data.pop("inputs", data)
parameters = data.pop("parameters", None)
options = data.pop("options", None)
        # OPTIONS are given alongside the payload
        options = options or {}
        info = bool(options.get("info", False))
        alignment = bool(options.get("alignment", False))
        diarization = bool(options.get("diarization", True))
        language = "fr"
        if parameters and "language" in parameters:
            language = parameters["language"]
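        # Expected request shape, inferred from the parsing above (values illustrative):
        # {
        #   "inputs": "<base64-encoded audio bytes>",
        #   "parameters": {"language": "fr"},
        #   "options": {"info": true, "alignment": false, "diarization": true}
        # }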
inputs = base64.b64decode(inputs_encoded)
        # write the payload to a unique temporary file so concurrent requests don't collide
        with tempfile.NamedTemporaryFile(delete=False) as w:
            tmp_path = w.name
            w.write(inputs)
        # audio_nparray = ffmpeg_load_audio(tmp_path, sr=SAMPLE_RATE, mono=True, out_type=np.float32)
        audio_nparray = load_audio(tmp_path, sr=SAMPLE_RATE)
        # clean up
        os.remove(tmp_path)
# audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)
# audio_tensor= torch.from_numpy(audio_nparray)
# get the end time
et = time.time()
# get the execution time
elapsed_time = et - st
logger.info(f"TIME for audio processing : {elapsed_time:.2f} seconds")
if info:
print(f"TIME for audio processing : {elapsed_time:.2f} seconds")
# 2. transcribe
logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
        transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
if info:
print(transcription["segments"][0:10000]) # before alignment
logger.info(transcription["segments"][0:10000])
        try:
            first_text = transcription["segments"][0]["text"]
        except (IndexError, KeyError):
            logger.warning("No transcription")
            return {"transcription": transcription["segments"]}
# get the execution time
et = time.time()
elapsed_time = et - st
st = time.time()
logger.info(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
if info:
print(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
# 3. align
if alignment:
logger.info("--------------- STARTING ALIGNMENT ------------------------")
model_a, metadata = whisperx.load_align_model(
language_code=transcription["language"], device=device)
transcription = whisperx.align(
transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False)
if info:
print(transcription["segments"][0:10000])
logger.info(transcription["segments"][0:10000])
# get the execution time
et = time.time()
elapsed_time = et - st
st = time.time()
logger.info(f"TIME for alignment : {elapsed_time:.2f} seconds")
if info:
print(f"TIME for alignment : {elapsed_time:.2f} seconds")
# 4. Assign speaker labels
if diarization:
logger.info("--------------- STARTING DIARIZATION ------------------------")
# add min/max number of speakers if known
diarize_segments = self.diarize_model(audio_nparray)
if info:
print(diarize_segments)
logger.info(diarize_segments)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
transcription = whisperx.assign_word_speakers(diarize_segments, transcription)
if info:
print(transcription["segments"][0:10000])
logger.info(transcription["segments"][0:10000]) # segments are now assigned speaker IDs
# get the execution time
et = time.time()
elapsed_time = et - st
st = time.time()
logger.info(f"TIME for audio diarization : {elapsed_time:.2f} seconds")
if info:
print(f"TIME for audio diarization : {elapsed_time:.2f} seconds")
# results_json = json.dumps(results)
# return {"results": results_json}
return {"transcription": transcription["segments"]}