"""
Integration tests for Speaker Embedding functionality

Tests the complete workflow of speaker unification in distributed transcription
"""

import asyncio
import json
import shutil
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import numpy as np
import pytest
import torch

from src.services.distributed_transcription_service import DistributedTranscriptionService
from src.services.speaker_embedding_service import SpeakerEmbeddingService, SpeakerIdentificationService
from src.utils.config import AudioProcessingConfig


class TestSpeakerEmbeddingIntegration:
    """Integration tests for speaker embedding with distributed transcription"""

    def setup_method(self):
        """Set up the test environment"""
        self.temp_dir = tempfile.mkdtemp()
        self.service = DistributedTranscriptionService(cache_dir=self.temp_dir)

    def teardown_method(self):
        """Clean up the test environment"""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @pytest.mark.asyncio
    @patch.dict('os.environ', {'HF_TOKEN': 'test_token'})
    async def test_merge_chunk_results_with_speaker_unification(self):
        """Test complete speaker unification workflow in merge_chunk_results"""
        chunk_results = [
            {
                "processing_status": "success",
                "chunk_start_time": 0.0,
                "audio_duration": 60.0,
                "language_detected": "en",
                "model_used": "turbo",
                "segments": [
                    {"start": 0.0, "end": 5.0, "text": "Hello everyone", "speaker": "SPEAKER_00"},
                    {"start": 5.0, "end": 10.0, "text": "How are you?", "speaker": "SPEAKER_01"},
                    {"start": 10.0, "end": 15.0, "text": "I'm doing well", "speaker": "SPEAKER_00"},
                ]
            },
            {
                "processing_status": "success",
                "chunk_start_time": 60.0,
                "audio_duration": 60.0,
                "language_detected": "en",
                "model_used": "turbo",
                "segments": [
                    {"start": 0.0, "end": 5.0, "text": "Let's continue", "speaker": "SPEAKER_00"},
                    {"start": 5.0, "end": 10.0, "text": "Great idea", "speaker": "SPEAKER_01"},
                    {"start": 10.0, "end": 15.0, "text": "New person here", "speaker": "SPEAKER_02"},
                ]
            }
        ]

        with patch('src.services.speaker_embedding_service.SpeakerIdentificationService') as mock_service_class:
            mock_service = Mock()
            mock_service.unify_distributed_speakers = AsyncMock()
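
            # Chunk-local speaker labels resolved to unified global speaker IDs.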
            mock_speaker_mapping = {
                "chunk_0_SPEAKER_00": "SPEAKER_GLOBAL_001",
                "chunk_0_SPEAKER_01": "SPEAKER_GLOBAL_002",
                "chunk_1_SPEAKER_00": "SPEAKER_GLOBAL_001",
                "chunk_1_SPEAKER_01": "SPEAKER_GLOBAL_002",
                "chunk_1_SPEAKER_02": "SPEAKER_GLOBAL_003",
            }

            mock_service.unify_distributed_speakers.return_value = mock_speaker_mapping
            mock_service_class.return_value = mock_service

            result = await self.service.merge_chunk_results(
                chunk_results=chunk_results,
                output_format="srt",
                enable_speaker_diarization=True,
                audio_file_path="test_audio.wav"
            )
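
            # Verify merge metadata and that unification was applied.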
            assert result["processing_status"] == "success"
            assert result["speaker_diarization_enabled"] is True
            assert result["speaker_embedding_unified"] is True
            assert result["distributed_processing"] is True
            assert result["chunks_processed"] == 2
            assert result["chunks_failed"] == 0
            assert result["global_speaker_count"] == 3
            expected_speakers = {"SPEAKER_GLOBAL_001", "SPEAKER_GLOBAL_002", "SPEAKER_GLOBAL_003"}
            assert set(result["speakers_detected"]) == expected_speakers
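
            # All six segments survive the merge, each with a global speaker label.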
            segments = result["segments"]
            assert len(segments) == 6

            for segment in segments:
                assert "speaker" in segment
                assert segment["speaker"] in expected_speakers

    @pytest.mark.asyncio
    async def test_merge_chunk_results_without_speaker_diarization(self):
        """Test merge_chunk_results when speaker diarization is disabled"""
        chunk_results = [
            {
                "processing_status": "success",
                "chunk_start_time": 0.0,
                "audio_duration": 60.0,
                "language_detected": "en",
                "model_used": "turbo",
                "segments": [
                    {"start": 0.0, "end": 5.0, "text": "Hello everyone"},
                    {"start": 5.0, "end": 10.0, "text": "How are you?"},
                ]
            }
        ]

        result = await self.service.merge_chunk_results(
            chunk_results=chunk_results,
            output_format="srt",
            enable_speaker_diarization=False,
            audio_file_path="test_audio.wav"
        )
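
        # With diarization disabled, no speaker unification metadata is set.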
        assert result["processing_status"] == "success"
        assert result["speaker_diarization_enabled"] is False
        assert "speaker_embedding_unified" not in result or result["speaker_embedding_unified"] is False

    @pytest.mark.asyncio
    async def test_merge_chunk_results_speaker_unification_failure(self):
        """Test merge_chunk_results when speaker unification fails"""
        chunk_results = [
            {
                "processing_status": "success",
                "chunk_start_time": 0.0,
                "audio_duration": 60.0,
                "language_detected": "en",
                "model_used": "turbo",
                "segments": [
                    {"start": 0.0, "end": 5.0, "text": "Hello", "speaker": "SPEAKER_00"},
                ]
            }
        ]
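
        # Simulate unification failure: the mocked service raises during unify.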
        with patch('src.services.speaker_embedding_service.SpeakerIdentificationService') as mock_service_class:
            mock_service = Mock()
            mock_service.unify_distributed_speakers = AsyncMock(side_effect=Exception("Model not available"))
            mock_service_class.return_value = mock_service

            result = await self.service.merge_chunk_results(
                chunk_results=chunk_results,
                output_format="srt",
                enable_speaker_diarization=True,
                audio_file_path="test_audio.wav"
            )
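
            # The merge still succeeds; unification is simply not applied.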
            assert result["processing_status"] == "success"
            assert result["speaker_diarization_enabled"] is True
            assert "speaker_embedding_unified" not in result or result["speaker_embedding_unified"] is False
            segments = result["segments"]
            assert len(segments) == 1
            assert segments[0]["speaker"] == "SPEAKER_CHUNK_0_SPEAKER_00"

    @pytest.mark.asyncio
    async def test_merge_chunk_results_unknown_speaker_filtering(self):
        """Test that UNKNOWN speakers are properly filtered from output"""
        service = DistributedTranscriptionService()
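
        # Five segments total; the one without a speaker label should be filtered.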
        chunk_results = [
            {
                "processing_status": "success",
                "chunk_start_time": 0.0,
                "chunk_end_time": 30.0,
                "segments": [
                    {"start": 0.0, "end": 2.0, "text": "Hello world", "speaker": "SPEAKER_00"},
                    {"start": 2.0, "end": 4.0, "text": "This has no speaker"},
                    {"start": 4.0, "end": 6.0, "text": "Another good segment", "speaker": "SPEAKER_01"},
                ]
            },
            {
                "processing_status": "success",
                "chunk_start_time": 30.0,
                "chunk_end_time": 60.0,
                "segments": [
                    {"start": 0.0, "end": 2.0, "text": "", "speaker": "SPEAKER_00"},
                    {"start": 2.0, "end": 4.0, "text": "Good segment from chunk 2", "speaker": "SPEAKER_00"},
                ]
            }
        ]

        with patch('src.services.distributed_transcription_service.SpeakerIdentificationService', create=True) as mock_service_class:
            with patch('src.services.distributed_transcription_service.SpeakerEmbeddingService', create=True) as mock_manager_class:
                result = await service.merge_chunk_results(
                    chunk_results=chunk_results,
                    output_format="srt",
                    enable_speaker_diarization=False,
                    audio_file_path="test.wav"
                )
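
                # Both chunks merge successfully.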
                assert result["processing_status"] == "success"
                assert result["chunks_processed"] == 2
                assert result["chunks_failed"] == 0
                assert "total_segments_collected" in result
                assert "unknown_segments_filtered" in result
                assert result["total_segments_collected"] == 5
                assert result["unknown_segments_filtered"] == 1
                assert result["segment_count"] == 4

                segments = result["segments"]
                assert len(segments) == 4

                for segment in segments:
                    assert "speaker" in segment
                    assert segment["speaker"] != "UNKNOWN"
                full_text = result["text"]
                assert "This has no speaker" not in full_text
                assert "Hello world" in full_text
                assert "Another good segment" in full_text
                assert "Good segment from chunk 2" in full_text


class TestSpeakerEmbeddingWorkflow:
    """Test complete workflow scenarios"""

    def setup_method(self):
        """Set up the test environment"""
        self.temp_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Clean up the test environment"""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @pytest.mark.asyncio
    @patch.dict('os.environ', {'HF_TOKEN': 'test_token'})
    async def test_end_to_end_speaker_unification(self):
        """Test complete end-to-end speaker unification workflow"""
        embedding_service = SpeakerEmbeddingService(
            storage_path=str(Path(self.temp_dir) / "speakers.json")
        )
        speaker_service = SpeakerIdentificationService(embedding_service)
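
        # Stand in for the real embedding model, which is never loaded in tests.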
        speaker_service.embedding_model = Mock()
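
        # Two chunks, each with the same two chunk-local speakers.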
        chunk_results = [
            {
                "processing_status": "success",
                "chunk_start_time": 0,
                "segments": [
                    {"start": 0, "end": 3, "text": "Hi there", "speaker": "SPEAKER_00"},
                    {"start": 3, "end": 6, "text": "Hello", "speaker": "SPEAKER_01"},
                ]
            },
            {
                "processing_status": "success",
                "chunk_start_time": 60,
                "segments": [
                    {"start": 0, "end": 3, "text": "How are things?", "speaker": "SPEAKER_00"},
                    {"start": 3, "end": 6, "text": "Going well", "speaker": "SPEAKER_01"},
                ]
            }
        ]
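
        # Orthogonal base embeddings, one per real person, so the two speakers
        # are trivially separable by cosine similarity.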
        person_a_base = np.zeros(512)
        person_a_base[0] = 1.0

        person_b_base = np.zeros(512)
        person_b_base[256] = 1.0
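
        # The mocked Inference.crop returns person A's embedding for segments
        # starting in [0, 3) or [60, 63), and person B's otherwise; small noise
        # keeps the vectors distinct while preserving speaker identity.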
        def mock_crop_side_effect(waveform, segment):
            segment_start = float(segment.start)
            if segment_start < 3 or 60 <= segment_start < 63:
                return torch.tensor(person_a_base + np.random.normal(0, 0.005, 512))
            else:
                return torch.tensor(person_b_base + np.random.normal(0, 0.005, 512))

        mock_inference = Mock()
        mock_inference.crop.side_effect = mock_crop_side_effect
        mock_waveform = torch.rand(1, 96000)
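
        # Patch audio loading and the pyannote Inference model with the mocks.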
        with patch('torchaudio.load', return_value=(mock_waveform, 16000)), \
             patch('pyannote.audio.core.inference.Inference', return_value=mock_inference):
            mapping = await speaker_service.unify_distributed_speakers(
                chunk_results, "test_audio.wav"
            )
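
            # All four chunk-local labels receive a global assignment.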
            assert len(mapping) == 4

            chunk_0_speaker_00 = mapping["chunk_0_SPEAKER_00"]
            chunk_1_speaker_00 = mapping["chunk_1_SPEAKER_00"]
            chunk_0_speaker_01 = mapping["chunk_0_SPEAKER_01"]
            chunk_1_speaker_01 = mapping["chunk_1_SPEAKER_01"]
            assert chunk_0_speaker_00 == chunk_1_speaker_00
            assert chunk_0_speaker_01 == chunk_1_speaker_01
            assert chunk_0_speaker_00 != chunk_0_speaker_01
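
            # Unified labels use the global naming scheme.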
            assert chunk_0_speaker_00.startswith("SPEAKER_GLOBAL_")
            assert chunk_0_speaker_01.startswith("SPEAKER_GLOBAL_")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])