KIFF's picture
Update handler.py
6c0c215 verified
raw
history blame
1.56 kB
from typing import Dict
from pyannote.audio import Pipeline
from pyannote.audio import Audio
import io
import torch
# Sampling rate (Hz) that every request's audio is resampled to before it is
# handed to the diarization pipeline.
SAMPLE_RATE = 16_000
class EndpointHandler():
    """Inference Endpoints handler that runs a pretrained pyannote.audio pipeline.

    The pipeline and the audio loader are created once at start-up; each call
    decodes the posted audio file, runs the pipeline on it and returns the
    resulting labelled segments as JSON-serialisable records.
    """

    def __init__(self, path=""):
        """Load the pipeline and configure the audio loader.

        Args:
            path: Model repository path (or identifier) passed to
                ``Pipeline.from_pretrained``.
        """
        # Load the pipeline from the model repository using the path
        self.pipeline = Pipeline.from_pretrained(path)
        # Every request is resampled to SAMPLE_RATE and downmixed to a single
        # channel before it reaches the pipeline.
        self.audio = Audio(sample_rate=SAMPLE_RATE, mono="downmix")

    def __call__(self, data: Dict[str, object]) -> Dict[str, list]:
        """Run the pipeline on one audio request.

        Args:
            data: Deserialized request payload. ``data["inputs"]`` must hold
                the raw bytes of an audio file; the optional
                ``data["parameters"]`` dict is forwarded to the pipeline as
                keyword arguments. Both keys are popped from ``data``.

        Returns:
            ``{"diarization": [{"label": ..., "start": ..., "stop": ...}, ...]}``
            with one record per track of the pipeline output. All values are
            strings; ``start``/``stop`` are the segment boundaries.

        Raises:
            TypeError: if ``data`` has no ``"inputs"`` entry or it is not
                bytes-like.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # A missing "inputs" key makes the pop above fall back to the payload
        # dict itself; reject it here with a clear message instead of letting
        # io.BytesIO fail with an opaque one (same exception type as before).
        if not isinstance(inputs, (bytes, bytearray, memoryview)):
            raise TypeError(
                "expected raw audio bytes under the 'inputs' key, got "
                f"{type(inputs).__name__}"
            )

        # Decode the file, resampling to SAMPLE_RATE and downmixing to mono.
        waveform, sample_rate = self.audio(io.BytesIO(inputs))
        pyannote_input = {"waveform": waveform, "sample_rate": sample_rate}

        # An absent (or empty) "parameters" entry means "use pipeline defaults".
        diarization = self.pipeline(pyannote_input, **(parameters or {}))
        return {"diarization": self._to_records(diarization)}

    @staticmethod
    def _to_records(diarization):
        """Flatten a pipeline output into a list of string-valued dicts."""
        # itertracks(yield_label=True) yields (segment, track, label); the
        # track id is dropped. The "stop" key name (rather than "end") is the
        # existing response format, kept as-is since clients may rely on it.
        return [
            {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]