KIFF commited on
Commit
38f584e
·
verified ·
1 Parent(s): e82e32f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -13
handler.py CHANGED
@@ -1,23 +1,16 @@
1
  from typing import Dict
2
  from pyannote.audio import Pipeline
3
- from pyannote.audio import Audio
4
- import io
5
  import torch
6
- import os
 
7
 
8
  SAMPLE_RATE = 16000
9
 
10
- # Set the PYANNOTE_CACHE environment variable
11
- os.environ["PYANNOTE_CACHE"] = "/repository/.cache"
12
-
13
  class EndpointHandler():
14
  def __init__(self, path=""):
15
- # Construct the full path to the model directory
16
- model_path = os.path.join("/repository", "") # Add trailing slash
17
-
18
- # Load the pipeline from the model repository using the full path
19
- self.pipeline = Pipeline.from_pretrained(model_path)
20
- self.audio = Audio(sample_rate=SAMPLE_RATE, mono="downmix") # Set mono="downmix"
21
 
22
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
23
  """
@@ -31,7 +24,7 @@ class EndpointHandler():
31
  inputs = data.pop("inputs", data)
32
  parameters = data.pop("parameters", None)
33
 
34
- # Load the audio using pyannote.audio (downmixing to mono)
35
  waveform, sample_rate = self.audio(io.BytesIO(inputs))
36
 
37
  # prepare pyannote input
 
1
  from typing import Dict
2
  from pyannote.audio import Pipeline
 
 
3
  import torch
4
+ import io
5
+ from pyannote.audio import Audio
6
 
7
  SAMPLE_RATE = 16000
8
 
 
 
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
11
+ # Load the pipeline from the model repository (using config.yaml)
12
+ self.pipeline = Pipeline.from_pretrained(path)
13
+ self.audio = Audio(sample_rate=SAMPLE_RATE, mono="downmix")
 
 
 
14
 
15
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
16
  """
 
24
  inputs = data.pop("inputs", data)
25
  parameters = data.pop("parameters", None)
26
 
27
+ # Load the audio using pyannote.audio
28
  waveform, sample_rate = self.audio(io.BytesIO(inputs))
29
 
30
  # prepare pyannote input