# Source metadata (from extraction): 7,895 bytes, revision 76f9cd2
"""
Test for concurrent processing functionality
"""
import asyncio
import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.services.distributed_transcription_service import DistributedTranscriptionService
class TestConcurrentProcessing:
    """Test the concurrent chunk-processing logic of
    ``DistributedTranscriptionService.transcribe_audio_distributed``.

    Each test patches the service's three collaborators
    (``split_audio_locally``, ``transcribe_chunk_distributed``,
    ``merge_chunk_results``) and verifies both the merged result and the
    per-chunk results that the service hands to the merge step.
    """

    # Fake (path, start_sec, end_sec) chunk tuples shared by the tests.
    MOCK_CHUNKS = [
        ("/tmp/chunk1.wav", 0.0, 10.0),
        ("/tmp/chunk2.wav", 10.0, 20.0),
        ("/tmp/chunk3.wav", 20.0, 30.0),
    ]

    async def _run_distributed(self, service, transcribe_side_effect,
                               merge_return, chunks):
        """Patch the service's collaborators and run one distributed transcription.

        Args:
            service: the DistributedTranscriptionService under test.
            transcribe_side_effect: async callable (or exception-raising coroutine
                function) substituted for ``transcribe_chunk_distributed``.
            merge_return: canned dict returned by the mocked ``merge_chunk_results``.
            chunks: list of (path, start, end) tuples returned by the mocked
                ``split_audio_locally``.

        Returns:
            (result, chunk_results): the service's final result dict and the
            list of per-chunk result dicts it passed to ``merge_chunk_results``.
        """
        # Single flat `with` instead of three nested blocks.
        with patch.object(service, 'split_audio_locally',
                          return_value=chunks), \
             patch.object(service, 'transcribe_chunk_distributed',
                          side_effect=transcribe_side_effect), \
             patch.object(service, 'merge_chunk_results') as mock_merge:
            mock_merge.return_value = merge_return
            result = await service.transcribe_audio_distributed(
                audio_file_path="test.wav",
                model_size="turbo",
                chunk_endpoint_url="http://test.com",
            )
            # The service must merge exactly once, with the per-chunk results
            # as the first positional argument.
            mock_merge.assert_called_once()
            chunk_results = mock_merge.call_args[0][0]
        return result, chunk_results

    @pytest.mark.asyncio
    async def test_asyncio_task_creation_and_waiting(self):
        """All chunks succeed: tasks are created, awaited, and merged."""
        service = DistributedTranscriptionService()

        async def mock_transcribe_chunk(*args, **kwargs):
            await asyncio.sleep(0.1)  # simulate processing time
            return {
                "processing_status": "success",
                "text": "Mock transcription",
                "segments": [{"start": 0, "end": 1, "text": "test"}],
            }

        result, chunk_results = await self._run_distributed(
            service,
            mock_transcribe_chunk,
            {
                "processing_status": "success",
                "chunks_processed": 3,
                "chunks_failed": 0,
            },
            self.MOCK_CHUNKS,
        )

        assert result["processing_status"] == "success"
        assert result["chunks_processed"] == 3
        assert result["chunks_failed"] == 0

        # One result per chunk, all successful.
        assert len(chunk_results) == 3
        for chunk_result in chunk_results:
            assert chunk_result["processing_status"] == "success"

    @pytest.mark.asyncio
    async def test_concurrent_processing_with_failures(self):
        """Mixed success/failure chunks are all forwarded to the merge step."""
        service = DistributedTranscriptionService()

        async def mock_transcribe_chunk_mixed(chunk_path, *args, **kwargs):
            await asyncio.sleep(0.1)
            # Only chunk1 succeeds; the other two report failure.
            if "chunk1" in chunk_path:
                return {
                    "processing_status": "success",
                    "text": "Success",
                    "segments": [{"start": 0, "end": 1, "text": "test"}],
                }
            return {
                "processing_status": "failed",
                "error_message": "Mock failure",
            }

        result, chunk_results = await self._run_distributed(
            service,
            mock_transcribe_chunk_mixed,
            {
                "processing_status": "success",
                "chunks_processed": 1,
                "chunks_failed": 2,
            },
            self.MOCK_CHUNKS,
        )

        assert result["processing_status"] == "success"
        assert result["chunks_processed"] == 1
        assert result["chunks_failed"] == 2

        # Exactly 1 success and 2 failures reach the merge step.
        assert len(chunk_results) == 3
        successful_results = [r for r in chunk_results
                              if r["processing_status"] == "success"]
        failed_results = [r for r in chunk_results
                          if r["processing_status"] == "failed"]
        assert len(successful_results) == 1
        assert len(failed_results) == 2

    @pytest.mark.asyncio
    async def test_concurrent_processing_exception_handling(self):
        """Exceptions raised per-chunk become failed chunk results, not crashes."""
        service = DistributedTranscriptionService()

        async def mock_transcribe_chunk_exception(*args, **kwargs):
            await asyncio.sleep(0.1)
            raise Exception("Mock network error")

        result, chunk_results = await self._run_distributed(
            service,
            mock_transcribe_chunk_exception,
            {
                "processing_status": "failed",
                "error_message": "All chunks failed to process",
                "chunks_processed": 0,
                "chunks_failed": 2,
            },
            self.MOCK_CHUNKS[:2],  # two chunks, both raising
        )

        assert result["processing_status"] == "failed"
        assert result["chunks_processed"] == 0
        assert result["chunks_failed"] == 2

        # Every chunk result is a failure carrying the original error text.
        assert len(chunk_results) == 2
        for chunk_result in chunk_results:
            assert chunk_result["processing_status"] == "failed"
            assert "Mock network error" in chunk_result["error_message"]