KIFF commited on
Commit
f9ef35e
1 Parent(s): 5dbbf5e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -2
handler.py CHANGED
@@ -1,2 +1,50 @@
1
- torch==1.11.0
2
- git+https://github.com/philschmid/pyannote-audio[email protected]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from pyannote.audio import Pipeline
3
+ from pyannote.audio import Audio
4
+ import io
5
+ import torch
6
+ import os
7
+
8
+ SAMPLE_RATE = 16000
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ # Construct the full path to the model directory
13
+ model_path = os.path.join(".", "")
14
+
15
+ # Load the pipeline from the model repository using the full path
16
+ self.pipeline = Pipeline.from_pretrained(model_path)
17
+ self.audio = Audio(sample_rate=SAMPLE_RATE, mono="downmix")
18
+
19
+ def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
20
+ """
21
+ Args:
22
+ data (:obj:):
23
+ includes the deserialized audio file as bytes
24
+ Return:
25
+ A :obj:`dict`:. base64 encoded image
26
+ """
27
+ # process input
28
+ inputs = data.pop("inputs", data)
29
+ parameters = data.pop("parameters", None)
30
+
31
+ # Load the audio using pyannote.audio (downmixing to mono)
32
+ waveform, sample_rate = self.audio(io.BytesIO(inputs))
33
+
34
+ # prepare pyannote input
35
+ pyannote_input = {"waveform": waveform, "sample_rate": sample_rate}
36
+
37
+ # apply pretrained pipeline
38
+ # pass inputs with all kwargs in data
39
+ if parameters is not None:
40
+ diarization = self.pipeline(pyannote_input, **parameters)
41
+ else:
42
+ diarization = self.pipeline(pyannote_input)
43
+
44
+ # postprocess the prediction
45
+ processed_diarization = [
46
+ {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
47
+ for segment, _, label in diarization.itertracks(yield_label=True)
48
+ ]
49
+
50
+ return {"diarization": processed_diarization}