|
import base64
import json
import logging
import os
import subprocess
import tempfile
import time
from typing import Any, Dict

import numpy as np
import torch
import whisperx

logger = logging.getLogger(__name__)

# Whisper models operate on 16 kHz mono audio.
SAMPLE_RATE = 16000


def whisper_config():
    """Choose device, batch size, compute type and model name based on CUDA availability."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    whisper_model = "large-v2"
    batch_size = 16
    # float16 on GPU; int8 quantization keeps CPU inference tractable.
    compute_type = "float16" if device == "cuda" else "int8"
    return device, batch_size, compute_type, whisper_model


def ffmpeg_load_audio(filename, sr=44100, mono=False, normalize=True, in_type=np.int16, out_type=np.float32):
    """Decode an audio file with ffmpeg and return it as a NumPy array."""
    channels = 1 if mono else 2
    format_strings = {
        np.float64: 'f64le',
        np.float32: 'f32le',
        np.int16: 's16le',
        np.int32: 's32le',
        np.uint32: 'u32le'
    }
    format_string = format_strings[in_type]
    command = [
        'ffmpeg',
        '-i', filename,
        '-f', format_string,
        '-acodec', 'pcm_' + format_string,
        '-ar', str(sr),
        '-ac', str(channels),
        '-']
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=4096)
    bytes_per_sample = np.dtype(in_type).itemsize
    frame_size = bytes_per_sample * channels
    chunk_size = frame_size * sr  # read roughly one second of audio per iteration
    raw = b''
    with p.stdout as stdout:
        while True:
            data = stdout.read(chunk_size)
            if data:
                raw += data
            else:
                break
    # np.fromstring was removed from NumPy; np.frombuffer is the supported replacement.
    audio = np.frombuffer(raw, dtype=in_type).astype(out_type)
    if channels > 1:
        audio = audio.reshape((-1, channels)).transpose()
    if audio.size == 0:
        return audio
    if issubclass(out_type, np.floating):
        if normalize:
            peak = np.abs(audio).max()
            if peak > 0:
                audio /= peak
        elif issubclass(in_type, np.integer):
            # Without peak normalization, scale integer PCM into [-1.0, 1.0].
            audio /= np.iinfo(in_type).max
    return audio
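
# A hypothetical usage sketch ("example.wav" is illustrative, not part of this repo):
#   stereo = ffmpeg_load_audio("example.wav", sr=44100)           # shape (2, n_samples)
#   mono = ffmpeg_load_audio("example.wav", sr=16000, mono=True)  # shape (n_samples,)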


def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.ndarray:
    """
    Helper function to read an audio file through ffmpeg.
    """
    ar = f"{sampling_rate}"
    ac = "1"
    format_for_conversion = "f32le"
    ffmpeg_command = [
        "ffmpeg",
        "-i",
        "pipe:0",
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    try:
        with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
            output_stream = ffmpeg_process.communicate(bpayload)
    except FileNotFoundError as error:
        raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
    out_bytes = output_stream[0]
    audio = np.frombuffer(out_bytes, np.float32)
    if audio.shape[0] == 0:
        raise ValueError(
            "Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
            "a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
            "URL, ensure that the URL is the full address to **download** the audio file."
        )
    return audio
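
# A hypothetical usage sketch: decode audio bytes already held in memory.
#   with open("example.mp3", "rb") as f:
#       waveform = ffmpeg_read(f.read(), sampling_rate=SAMPLE_RATE)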


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    try:
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
            file,
            "-f",
            "s16le",
            "-ac",
            "1",
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(sr),
            "-",
        ]
        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Scale 16-bit signed PCM into float32 in [-1.0, 1.0).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
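
# A hypothetical usage sketch ("example.wav" is illustrative):
#   waveform = load_audio("example.wav")  # float32 mono waveform at SAMPLE_RATE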


def display_gpu_infos():
    """Return a one-line summary of the visible CUDA devices, or "NO CUDA"."""
    if not torch.cuda.is_available():
        return "NO CUDA"

    return (
        f"torch.cuda.current_device(): {torch.cuda.current_device()}, "
        f"torch.cuda.device(0): {torch.cuda.device(0)}, "
        f"torch.cuda.device_count(): {torch.cuda.device_count()}, "
        f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}"
    )


class EndpointHandler():
    def __init__(self, path=""):
        device, batch_size, compute_type, whisper_model = whisper_config()
        self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type)
        logger.info(f"Model {whisper_model} initialized")

        # The Hugging Face access token is read from the environment instead of being
        # hard-coded: a token committed to the repository is a leaked credential and
        # should be revoked. HF_TOKEN is the conventional variable name.
        self.diarize_model = whisperx.DiarizationPipeline(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN"), device=device)
        logger.info("Model for diarization initialized")
|
    def __call__(self, data: Any) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                "inputs": the audio file as base64-encoded bytes,
                "parameters" (optional): e.g. {"language": "fr"},
                "options" (optional): boolean flags "info", "alignment" and "diarization".
        Return:
            A :obj:`dict` whose "transcription" key holds the list of segments.
        """
        st = time.time()

        logger.info("--------------- CONFIGURATION ------------------------")
        device, batch_size, compute_type, whisper_model = whisper_config()
        logger.info(f"device: {device}, batch_size: {batch_size}, compute_type: {compute_type}, whisper_model: {whisper_model}")
        logger.info(display_gpu_infos())

        inputs_encoded = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        options = data.pop("options", None)

        # "info" and "alignment" default to off; "diarization" defaults to on.
        info = bool(options and options.get("info"))

        alignment = bool(options and options.get("alignment"))

        diarization = True
        if options and "diarization" in options and not options["diarization"]:
            diarization = False

        language = "fr"
        if parameters and "language" in parameters:
            language = parameters["language"]

        inputs = base64.b64decode(inputs_encoded)

        # Write the payload to a unique temporary file so that concurrent requests
        # cannot clobber each other, then let ffmpeg decode it.
        with tempfile.NamedTemporaryFile(suffix=".tmp", delete=False) as tmp:
            tmp.write(inputs)
            tmp_path = tmp.name

        audio_nparray = load_audio(tmp_path, sr=SAMPLE_RATE)
        os.remove(tmp_path)

        et = time.time()
        elapsed_time = et - st
        st = time.time()
        logger.info(f"TIME for audio processing : {elapsed_time:.2f} seconds")
        if info:
            print(f"TIME for audio processing : {elapsed_time:.2f} seconds")
logger.info("--------------- STARTING TRANSCRIPTION ------------------------") |
|
transcription = self.model.transcribe(audio_nparray, batch_size=batch_size,language=language) |
|
if info: |
|
print(transcription["segments"][0:10000]) |
|
logger.info(transcription["segments"][0:10000]) |
|
|
|
try: |
|
first_text = transcription["segments"][0]["text"] |
|
except: |
|
logger.warning("No transcription") |
|
return {"transcription": transcription["segments"]} |
|
|
|
|
|
et = time.time() |
|
elapsed_time = et - st |
|
st = time.time() |
|
logger.info(f"TIME for audio transcription : {elapsed_time:.2f} seconds") |
|
if info: |
|
print(f"TIME for audio transcription : {elapsed_time:.2f} seconds") |
|
|
|
|
|

        if alignment:
            logger.info("--------------- STARTING ALIGNMENT ------------------------")
            model_a, metadata = whisperx.load_align_model(
                language_code=transcription["language"], device=device)
            transcription = whisperx.align(
                transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False)
            if info:
                print(transcription["segments"][0:10000])
                logger.info(transcription["segments"][0:10000])

            et = time.time()
            elapsed_time = et - st
            st = time.time()
            logger.info(f"TIME for alignment : {elapsed_time:.2f} seconds")
            if info:
                print(f"TIME for alignment : {elapsed_time:.2f} seconds")

        if diarization:
            logger.info("--------------- STARTING DIARIZATION ------------------------")
            diarize_segments = self.diarize_model(audio_nparray)
            if info:
                print(diarize_segments)
                logger.info(diarize_segments)

            transcription = whisperx.assign_word_speakers(diarize_segments, transcription)
            if info:
                print(transcription["segments"][0:10000])
                logger.info(transcription["segments"][0:10000])

            et = time.time()
            elapsed_time = et - st
            logger.info(f"TIME for audio diarization : {elapsed_time:.2f} seconds")
            if info:
                print(f"TIME for audio diarization : {elapsed_time:.2f} seconds")

        return {"transcription": transcription["segments"]}
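

# A minimal local smoke-test sketch, not part of the deployed endpoint. The file
# name "example.wav" and the payload shape are assumptions based on how __call__
# reads "inputs", "parameters" and "options" above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler()
    with open("example.wav", "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("ascii"),
            "parameters": {"language": "fr"},
            "options": {"info": True, "alignment": False, "diarization": True},
        }
    result = handler(payload)
    # Pretty-print the first few segments; default=str guards against non-JSON types.
    print(json.dumps(result["transcription"][:3], ensure_ascii=False, indent=2, default=str))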