KIFF committed on
Commit
6f744e3
·
verified ·
1 Parent(s): a4109dd

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +54 -31
handler.py CHANGED
@@ -1,49 +1,72 @@
1
- import os
2
- from pyannote.audio import Pipeline, Audio
3
  import torch
 
 
 
4
 
 
5
 
6
- class EndpointHandler:
7
  def __init__(self, path=""):
8
- # Get the Hugging Face authentication token from the environment variable
9
- auth_token = os.getenv("MY_KEY")
10
- if not auth_token:
11
  raise ValueError("Hugging Face authentication token (MY_KEY) is missing.")
12
 
13
- # Initialize pretrained pipeline with the token
14
- self._pipeline = Pipeline.from_pretrained(
15
- "pyannote/speaker-diarization-3.1", use_auth_token=auth_token
16
  )
17
 
18
- # Send pipeline to GPU if available
19
- if torch.cuda.is_available():
20
- self._pipeline.to(torch.device("cuda"))
21
 
22
- # Initialize audio reader
23
- self._io = Audio()
24
 
25
- def __call__(self, data):
26
- # Extract inputs from request data
27
- inputs = data.pop("inputs", data)
28
- waveform, sample_rate = self._io(inputs)
 
 
 
 
 
 
 
29
 
30
- # Extract pipeline parameters if provided
31
- parameters = data.pop("parameters", dict())
 
32
 
33
- # Run speaker diarization
34
- diarization = self._pipeline(
35
- {"waveform": waveform, "sample_rate": sample_rate}, **parameters
36
- )
 
 
 
 
37
 
38
- # Process diarization results
 
 
 
 
 
 
 
 
 
39
  processed_diarization = [
40
  {
41
- "speaker": speaker,
42
- "start": f"{turn.start:.3f}",
43
- "end": f"{turn.end:.3f}",
44
  }
45
- for turn, _, speaker in diarization.itertracks(yield_label=True)
46
  ]
47
-
48
- # Return results as JSON
49
  return {"diarization": processed_diarization}
 
 
1
+ from typing import Dict
2
+ from pyannote.audio import Pipeline
3
  import torch
4
+ import base64
5
+ import numpy as np
6
+ import os
7
 
8
+ SAMPLE_RATE = 16000
9
 
10
class EndpointHandler:
    """Custom inference handler that runs pyannote speaker diarization.

    The handler loads the ``pyannote/speaker-diarization-3.1`` pipeline once
    at start-up and, per request, turns base64-encoded 16-bit PCM audio into
    a list of labelled speech segments.
    """

    def __init__(self, path=""):
        """Load the diarization pipeline and move it to GPU when available.

        Args:
            path: Accepted for interface compatibility; not used by this
                handler.

        Raises:
            ValueError: If the ``MY_KEY`` environment variable is unset or
                empty.
            RuntimeError: If the pretrained pipeline could not be loaded.
        """
        # Retrieve the Hugging Face authentication token from the environment.
        hf_token = os.getenv("MY_KEY")
        if not hf_token:
            raise ValueError("Hugging Face authentication token (MY_KEY) is missing.")

        # Initialize the pipeline with the authentication token.
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
        )
        # NOTE(review): pyannote.audio 3.x is understood to return None (not
        # raise) when the gated model cannot be fetched -- fail loudly here
        # instead of with an AttributeError on the next line.
        if pipeline is None:
            raise RuntimeError(
                "Could not load pyannote/speaker-diarization-3.1; check that "
                "the MY_KEY token is valid and has access to the model."
            )

        # Move the pipeline to the appropriate device (CPU or GPU).
        pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        # No explicit ``instantiate`` call: ``from_pretrained`` already applies
        # the hyper-parameters shipped with the checkpoint, and
        # ``Pipeline.parameters`` is a method, so handing it un-called to
        # ``instantiate`` (which expects a dict) fails.
        self.pipeline = pipeline

    @staticmethod
    def _decode_pcm16(encoded) -> torch.Tensor:
        """Decode base64 16-bit PCM into a ``(1, num_samples)`` float tensor.

        Samples are scaled from the int16 range to [-1.0, 1.0).

        Raises:
            ValueError: If ``encoded`` is missing, is not valid base64, or
                does not decode to a non-empty whole number of 16-bit samples.
        """
        if not isinstance(encoded, (str, bytes, bytearray)) or not encoded:
            raise ValueError(
                "'inputs' must be a non-empty base64-encoded string or bytes"
            )

        try:
            pcm_bytes = base64.b64decode(encoded)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise ValueError("'inputs' is not valid base64") from exc

        if not pcm_bytes or len(pcm_bytes) % 2:
            raise ValueError(
                "'inputs' must decode to a non-empty, even number of bytes "
                "(16-bit PCM)"
            )

        # ``frombuffer`` yields a read-only 1-D view; ``astype`` makes the
        # writable float32 copy that ``torch.from_numpy`` expects.
        samples = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.float32)
        samples /= 32768.0
        return torch.from_numpy(samples).unsqueeze(0)

    def __call__(self, data: Dict) -> Dict:
        """Diarize one request.

        Args:
            data (Dict):
                'inputs': Base64-encoded raw 16-bit PCM audio, interpreted as
                    a single channel sampled at ``SAMPLE_RATE``.
                'parameters': Additional diarization parameters (currently
                    ignored).

        Return:
            Dict: ``{"diarization": [{"label", "start", "stop"}, ...]}`` on
            success, or ``{"error": message}`` when the input is unusable or
            the pipeline fails.
        """
        try:
            audio_tensor = self._decode_pcm16(data.get("inputs"))
        except ValueError as exc:
            return {"error": str(exc)}

        pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}

        # Run diarization pipeline. This is the request boundary, so any
        # failure is reported to the client instead of propagating.
        try:
            diarization = self.pipeline(pyannote_input)  # No num_speakers parameter
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return {"error": "Diarization failed unexpectedly"}

        # Build a friendly JSON response.
        processed_diarization = [
            {
                "label": str(label),
                "start": str(segment.start),
                "stop": str(segment.end),
            }
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]
        return {"diarization": processed_diarization}
72
+