|
""" |
|
高级 Speaker Segmentation 测试用例 |
|
测试边缘情况、性能和复杂场景 |
|
""" |
|
|
|
import pytest |
|
import time |
|
import random |
|
from typing import List, Dict |
|
import sys |
|
import os |
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
from src.services.transcription_service import TranscriptionService |
|
|
|
|
|
class TestSpeakerSegmentationAdvanced:
    """Advanced speaker-segmentation tests.

    Each test feeds hand-built transcription and speaker segments to
    ``TranscriptionService._merge_speaker_segments`` and checks the merged
    output for edge cases, performance and more complex scenarios.
    """

    def setup_method(self) -> None:
        """Create a fresh ``TranscriptionService`` before every test."""
        self.service = TranscriptionService()

    def test_rapid_speaker_changes(self) -> None:
        """One 5 s transcript merged against twenty alternating 0.25 s turns."""
        text = "A B C D E F G H I J K L M N O P Q R S T"
        transcription_segments = [{"start": 0.0, "end": 5.0, "text": text}]

        # Twenty back-to-back turns that alternate SPEAKER_00 / SPEAKER_01.
        speaker_segments = [
            {
                "start": i * 0.25,
                "end": (i + 1) * 0.25,
                "speaker": f"SPEAKER_{i % 2:02d}",
            }
            for i in range(20)
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        # The single transcript must have been split, and both speakers kept.
        assert len(result) > 1
        speakers = [seg["speaker"] for seg in result]
        assert "SPEAKER_00" in speakers
        assert "SPEAKER_01" in speakers

        # More than half of the original words must survive the split.
        combined_text = " ".join(seg["text"] for seg in result if seg["text"])
        preserved_ratio = len(combined_text.split()) / len(text.split())
        assert preserved_ratio > 0.5, f"Too much text lost: {preserved_ratio:.2f}"

    def test_very_short_speaker_segments(self) -> None:
        """Speaker turns as short as 0.1 s still yield at least one segment."""
        transcription_segments = [
            {"start": 0.0, "end": 2.0, "text": "Quick short conversation"},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 0.1, "speaker": "SPEAKER_00"},
            {"start": 0.1, "end": 0.2, "speaker": "SPEAKER_01"},
            {"start": 0.2, "end": 2.0, "speaker": "SPEAKER_00"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) >= 1

        # The previous loop body was a bare ``pass`` and verified nothing.
        # Only a non-negative duration is required here because the old
        # condition skipped segments without a positive duration.
        # NOTE(review): tighten to ``<`` if zero-length segments are never valid.
        for seg in result:
            assert seg["start"] <= seg["end"], (
                f"Invalid timing: {seg['start']} > {seg['end']}"
            )

    def test_overlapping_segments_complex(self) -> None:
        """Three mutually overlapping speaker turns over one 10 s transcript."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 10.0,
                "text": "This is a complex conversation with multiple overlapping "
                        "speakers talking at the same time",
            },
        ]
        # Each turn overlaps the next one by four seconds.
        speaker_segments = [
            {"start": 0.0, "end": 6.0, "speaker": "SPEAKER_00"},
            {"start": 2.0, "end": 8.0, "speaker": "SPEAKER_01"},
            {"start": 4.0, "end": 10.0, "speaker": "SPEAKER_02"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) >= 2

        speakers = {seg["speaker"] for seg in result}
        assert len(speakers) >= 2

        # Every produced segment must have a strictly positive duration.
        for seg in result:
            assert seg["start"] < seg["end"], (
                f"Invalid timing: {seg['start']} >= {seg['end']}"
            )

    def test_performance_large_segments(self) -> None:
        """Merging 100 transcript and 200 speaker segments finishes within 1 s."""
        transcription_segments = [
            {
                "start": i * 1.0,
                "end": (i + 1) * 1.0,
                "text": f"This is segment number {i} with some text content "
                        "for testing purposes",
            }
            for i in range(100)
        ]
        # Half-second turns cycling through five speaker labels.
        speaker_segments = [
            {
                "start": i * 0.5,
                "end": (i + 1) * 0.5,
                "speaker": f"SPEAKER_{i % 5:02d}",
            }
            for i in range(200)
        ]

        # perf_counter is monotonic and high resolution; time.time() can jump
        # with system clock adjustments and is coarse on some platforms.
        started = time.perf_counter()
        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )
        execution_time = time.perf_counter() - started

        assert execution_time < 1.0, f"Performance too slow: {execution_time:.2f}s"

        assert len(result) > 0
        assert len(result) <= len(transcription_segments) * 2

    def test_no_overlap_at_all(self) -> None:
        """Transcripts that overlap no speaker turn keep their text, speaker None."""
        transcription_segments = [
            {"start": 0.0, "end": 2.0, "text": "First segment"},
            {"start": 5.0, "end": 7.0, "text": "Second segment"},
        ]
        # Both turns fall entirely inside the gaps between the transcripts.
        speaker_segments = [
            {"start": 3.0, "end": 4.0, "speaker": "SPEAKER_00"},
            {"start": 8.0, "end": 9.0, "speaker": "SPEAKER_01"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 2
        assert result[0]["speaker"] is None
        assert result[1]["speaker"] is None
        assert result[0]["text"] == "First segment"
        assert result[1]["text"] == "Second segment"

    def test_exact_boundary_matching(self) -> None:
        """A speaker turn with identical bounds labels the transcript unchanged."""
        transcription_segments = [
            {"start": 1.0, "end": 3.0, "text": "Exact boundary match"},
        ]
        speaker_segments = [
            {"start": 1.0, "end": 3.0, "speaker": "SPEAKER_00"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 1
        merged = result[0]
        assert merged["speaker"] == "SPEAKER_00"
        assert merged["start"] == 1.0
        assert merged["end"] == 3.0
        assert merged["text"] == "Exact boundary match"

    def test_floating_point_precision(self) -> None:
        """Bounds that differ by 1e-8 from the transcript still match the speaker."""
        transcription_segments = [
            {"start": 0.1, "end": 0.3, "text": "Precision test"},
        ]
        speaker_segments = [
            {"start": 0.10000001, "end": 0.29999999, "speaker": "SPEAKER_00"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"

    def test_text_distribution_accuracy(self) -> None:
        """An 8 s / 2 s speaker split yields text shares within the expected bands."""
        text = "One two three four five six seven eight nine ten"
        transcription_segments = [{"start": 0.0, "end": 10.0, "text": text}]
        speaker_segments = [
            {"start": 0.0, "end": 8.0, "speaker": "SPEAKER_00"},
            {"start": 8.0, "end": 10.0, "speaker": "SPEAKER_01"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 2

        # Share of the original character count carried by each output segment.
        total_text_length = len(text)
        speaker_00_length = len(result[0]["text"])
        speaker_01_length = len(result[1]["text"])
        speaker_00_ratio = speaker_00_length / total_text_length
        speaker_01_ratio = speaker_01_length / total_text_length

        assert 0.6 <= speaker_00_ratio <= 0.9, f"SPEAKER_00 ratio: {speaker_00_ratio:.2f}"
        assert 0.1 <= speaker_01_ratio <= 0.4, f"SPEAKER_01 ratio: {speaker_01_ratio:.2f}"

    def test_single_word_segments(self) -> None:
        """A two-word transcript split across two speakers keeps both words."""
        transcription_segments = [
            {"start": 0.0, "end": 4.0, "text": "Hello world"},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"},
            {"start": 2.0, "end": 4.0, "speaker": "SPEAKER_01"},
        ]

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 2

        # Collect every word from the segments that carry text.
        all_words = []
        for seg in result:
            if seg["text"]:
                all_words.extend(seg["text"].split())

        # Check each word individually. The earlier
        # ``any(word in all_words for word in original_words)`` shadowed the
        # loop variable, so it passed as soon as any single word survived.
        for word in ("Hello", "world"):
            assert word in all_words, f"Words not preserved correctly: {all_words}"

    def test_empty_speaker_segments(self) -> None:
        """With no speaker turns the transcript is returned with speaker None."""
        transcription_segments = [
            {"start": 0.0, "end": 5.0, "text": "No speakers detected"},
        ]
        speaker_segments = []

        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )

        assert len(result) == 1
        assert result[0]["speaker"] is None
        assert result[0]["text"] == "No speakers detected"

    def test_malformed_input_handling(self) -> None:
        """A transcript without a "text" key is merged or rejected cleanly.

        Either outcome passes: the call returns a sized result, or it raises
        ``KeyError``, ``AttributeError`` or ``TypeError``. Any other exception
        propagates and fails the test.
        """
        # Deliberately malformed: the segment has no "text" key.
        transcription_segments = [
            {"start": 0.0, "end": 5.0},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 5.0, "speaker": "SPEAKER_00"},
        ]

        try:
            result = self.service._merge_speaker_segments(
                transcription_segments, speaker_segments
            )
        except (KeyError, AttributeError, TypeError):
            # Rejecting the input with one of these errors is acceptable.
            pass
        else:
            # Kept outside the ``try`` so that a TypeError raised by ``len``
            # (for example on a ``None`` result) is not mistaken for a clean
            # rejection of the malformed input.
            assert len(result) >= 0
|
|
|
|
|
class TestSpeakerSegmentationBenchmark:
    """Performance benchmark for ``TranscriptionService._merge_speaker_segments``."""

    def setup_method(self) -> None:
        """Create a fresh ``TranscriptionService`` before every test."""
        self.service = TranscriptionService()

    @pytest.mark.benchmark
    def test_benchmark_typical_podcast(self) -> None:
        """Benchmark a typical podcast: 30 minutes of audio, three speakers.

        Builds 360 five-second transcript segments (1800 s in total) and
        speaker turns of 15-45 s that rotate through three speaker labels,
        then requires the merge to finish in under two seconds and to
        produce at least one segment.
        """
        total_duration = 1800.0

        transcription_segments = [
            {
                "start": i * 5.0,
                "end": (i + 1) * 5.0,
                "text": f"This is a podcast segment number {i} with typical "
                        "conversation content",
            }
            for i in range(360)
        ]

        # A dedicated, seeded generator keeps the benchmark input identical
        # across runs and leaves the global ``random`` state untouched.
        rng = random.Random(42)
        speaker_segments = []
        current_time = 0.0
        current_speaker = 0
        while current_time < total_duration:
            segment_duration = rng.uniform(15.0, 45.0)
            speaker_segments.append({
                "start": current_time,
                # Clamp the final turn so it never runs past the recording.
                "end": min(current_time + segment_duration, total_duration),
                "speaker": f"SPEAKER_{current_speaker:02d}",
            })
            current_time += segment_duration
            current_speaker = (current_speaker + 1) % 3

        # perf_counter is monotonic and high resolution, unlike time.time().
        started = time.perf_counter()
        result = self.service._merge_speaker_segments(
            transcription_segments, speaker_segments
        )
        execution_time = time.perf_counter() - started

        # Guard the throughput figure against a zero elapsed time, which
        # would otherwise raise ZeroDivisionError while printing the report.
        if execution_time > 0:
            segments_per_second = len(result) / execution_time
        else:
            segments_per_second = float("inf")

        print("\n📊 Benchmark Results:")
        print(f" Input segments: {len(transcription_segments)}")
        print(f" Speaker segments: {len(speaker_segments)}")
        print(f" Output segments: {len(result)}")
        print(f" Execution time: {execution_time:.3f}s")
        print(f" Segments per second: {segments_per_second:.1f}")

        assert execution_time < 2.0, (
            f"Too slow for real-time processing: {execution_time:.3f}s"
        )
        assert len(result) > 0, "Should produce some output segments"
|
|
|
|
|
if __name__ == "__main__":
    # First pass: every test except those marked ``benchmark``.
    functional_status = pytest.main(
        [__file__, "-v", "--tb=short", "-m", "not benchmark"]
    )

    # Second pass: only the tests marked ``benchmark``.
    print("\n" + "=" * 60)
    print("Running Benchmark Tests...")
    print("=" * 60)
    benchmark_status = pytest.main(
        [__file__, "-v", "--tb=short", "-m", "benchmark"]
    )

    # Propagate the first non-zero pytest exit code so shells and CI see a
    # failure; previously both codes were dropped and the script exited 0.
    sys.exit(int(functional_status) or int(benchmark_status))