KIFF committed on
Commit
95b2ec3
·
verified ·
1 Parent(s): fd94d01

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -15
handler.py CHANGED
@@ -1,15 +1,16 @@
1
  from typing import Dict
2
  from pyannote.audio import Pipeline
3
- import torch
4
- import base64
5
- import numpy as np
6
 
7
  SAMPLE_RATE = 16000
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
- # load the model
12
- self.pipeline = Pipeline.from_pretrained("KIFF/pyannote-speaker-diarization-endpoint")
 
13
 
14
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
15
  """
@@ -21,16 +22,14 @@ class EndpointHandler():
21
  """
22
  # process input
23
  inputs = data.pop("inputs", data)
24
- parameters = data.pop("parameters", None) # min_speakers=2, max_speakers=5
25
 
26
- # decode the base64 audio data
27
- audio_data = base64.b64decode(inputs)
28
- audio_nparray = np.frombuffer(audio_data, dtype=np.int16)
 
 
29
 
30
- # prepare pynannote input
31
- audio_tensor= torch.from_numpy(audio_nparray).float().unsqueeze(0)
32
- pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}
33
-
34
  # apply pretrained pipeline
35
  # pass inputs with all kwargs in data
36
  if parameters is not None:
@@ -43,5 +42,5 @@ class EndpointHandler():
43
  {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
44
  for segment, _, label in diarization.itertracks(yield_label=True)
45
  ]
46
-
47
- return {"diarization": processed_diarization}
 
1
  from typing import Dict
2
  from pyannote.audio import Pipeline
3
+ import torch
4
+ import io
5
+ from pyannote.audio import Audio
6
 
7
  SAMPLE_RATE = 16000
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
+ # Load the pipeline from the model repository (using config.yaml)
12
+ self.pipeline = Pipeline.from_pretrained(path)
13
+ self.audio = Audio(sample_rate=SAMPLE_RATE, mono="downmix")
14
 
15
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
16
  """
 
22
  """
23
  # process input
24
  inputs = data.pop("inputs", data)
25
+ parameters = data.pop("parameters", None)
26
 
27
+ # Load the audio using pyannote.audio
28
+ waveform, sample_rate = self.audio(io.BytesIO(inputs))
29
+
30
+ # prepare pyannote input
31
+ pyannote_input = {"waveform": waveform, "sample_rate": sample_rate}
32
 
 
 
 
 
33
  # apply pretrained pipeline
34
  # pass inputs with all kwargs in data
35
  if parameters is not None:
 
42
  {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
43
  for segment, _, label in diarization.itertracks(yield_label=True)
44
  ]
45
+
46
+ return {"diarization": processed_diarization}