import asyncio
import logging
import re
import threading
from typing import Any, List, Optional, Tuple

import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource, MicrophoneAudioSource
from pyannote.core import Annotation
from rx.core import Observer

from timed_objects import SpeakerSegment

logger = logging.getLogger(__name__)

def extract_number(s: str) -> Optional[int]:
    """Extract the first integer found in ``s``, or None if it contains no digits."""
    m = re.search(r"\d+", s)
    return int(m.group()) if m else None
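
# For illustration (label formats are diart's, e.g. "speaker0"):
# extract_number("speaker0") == 0, while a label without digits yields None.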


class DiarizationObserver(Observer):
    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""

    def __init__(self):
        self.speaker_segments = []
        self.processed_time = 0
        self.segment_lock = threading.Lock()

    def on_next(self, value: Tuple[Annotation, Any]):
        annotation, audio = value
        logger.debug("\n--- New Diarization Result ---")
        duration = audio.extent.end - audio.extent.start
        logger.debug(
            f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s "
            f"(duration: {duration:.2f}s)"
        )
        logger.debug(f"Audio shape: {audio.data.shape}")
        with self.segment_lock:
            if audio.extent.end > self.processed_time:
                self.processed_time = audio.extent.end
            if annotation and annotation.labels():
                logger.debug("\nSpeaker segments:")
                # Iterate pyannote's public API instead of the private _labels /
                # segments_boundaries_ attributes; pairing consecutive boundaries
                # would also have recorded the gaps between segments as speech.
                for segment, _, speaker in annotation.itertracks(yield_label=True):
                    logger.debug(f"  {speaker}: {segment.start:.2f}s-{segment.end:.2f}s")
                    self.speaker_segments.append(SpeakerSegment(
                        speaker=speaker,
                        start=segment.start,
                        end=segment.end,
                    ))
            else:
                logger.debug("\nNo speakers detected in this segment")
    def get_segments(self) -> List[SpeakerSegment]:
        """Return a copy of the current speaker segments."""
        with self.segment_lock:
            return self.speaker_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """Drop segments that ended more than ``older_than`` seconds ago."""
        with self.segment_lock:
            current_time = self.processed_time
            self.speaker_segments = [
                segment for segment in self.speaker_segments
                if current_time - segment.end < older_than
            ]

    def on_error(self, error):
        """Handle an error in the stream."""
        logger.error(f"Error in diarization stream: {error}")

    def on_completed(self):
        """Handle the completion of the stream."""
        logger.debug("Diarization stream completed")


class WebSocketAudioSource(AudioSource):
    """
    Custom AudioSource that blocks in read() until close() is called.
    Use push_audio() to inject PCM chunks.
    """

    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
        super().__init__(uri, sample_rate)
        self._closed = False
        self._close_event = threading.Event()

    def read(self):
        # Block the reader thread until close() signals the end of the stream.
        self._close_event.wait()

    def close(self):
        if not self._closed:
            self._closed = True
            self.stream.on_completed()
            self._close_event.set()

    def push_audio(self, chunk: np.ndarray):
        if not self._closed:
            # diart expects audio shaped (channel, samples).
            new_audio = np.expand_dims(chunk, axis=0)
            logger.debug("Adding new chunk with shape: %s", new_audio.shape)
            self.stream.on_next(new_audio)
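
# Usage sketch (chunk length and dtype are assumptions; diart sources
# typically carry float32 mono samples at the configured sample rate):
#
#     source = WebSocketAudioSource(sample_rate=16000)
#     source.push_audio(np.zeros(8000, dtype=np.float32))  # 0.5 s of silence
#     source.close()  # completes the stream and unblocks read()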


class DiartDiarization:
    def __init__(
        self,
        sample_rate: int = 16000,
        config: Optional[SpeakerDiarizationConfig] = None,
        use_microphone: bool = False,
    ):
        self.pipeline = SpeakerDiarization(config=config)
        self.observer = DiarizationObserver()

        if use_microphone:
            self.source = MicrophoneAudioSource()
            self.custom_source = None
        else:
            self.custom_source = WebSocketAudioSource(
                uri="websocket_source", sample_rate=sample_rate
            )
            self.source = self.custom_source

        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        self.inference.attach_observers(self.observer)
        # Run the blocking inference loop in a background thread so the event
        # loop stays free to receive audio.
        asyncio.get_event_loop().run_in_executor(None, self.inference)

    async def diarize(self, pcm_array: np.ndarray):
        """
        Process audio data for diarization.
        Only used when working with WebSocketAudioSource.
        """
        if self.custom_source:
            self.custom_source.push_audio(pcm_array)
        self.observer.clear_old_segments()
        return self.observer.get_segments()

    def close(self):
        """Close the audio source."""
        if self.custom_source:
            self.custom_source.close()

    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
        """
        Assign speakers to tokens based on timing overlap with speaker segments.
        Uses the segments collected by the observer.
        """
        segments = self.observer.get_segments()
        for token in tokens:
            for segment in segments:
                # Attribute a speaker when the token and segment overlap in time.
                if not (segment.end <= token.start or segment.start >= token.end):
                    speaker_number = extract_number(segment.speaker)
                    if speaker_number is not None:
                        token.speaker = speaker_number + 1
                        end_attributed_speaker = max(token.end, end_attributed_speaker)
        return end_attributed_speaker
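

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the module's public API).
    # It assumes diart's default models are available locally; the chunk of
    # silence below is illustrative and will yield no speaker segments.
    async def _demo():
        diarizer = DiartDiarization(sample_rate=16000)
        pcm = np.zeros(8000, dtype=np.float32)  # 0.5 s of audio at 16 kHz
        for seg in await diarizer.diarize(pcm):
            print(f"{seg.speaker}: {seg.start:.2f}s-{seg.end:.2f}s")
        diarizer.close()

    asyncio.run(_demo())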