ModalTranscriberMCP/tests/test_speaker_segmentation_advanced.py
"""
高级 Speaker Segmentation 测试用例
测试边缘情况、性能和复杂场景
"""
import pytest
import time
import random
from typing import List, Dict
import sys
import os
# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.services.transcription_service import TranscriptionService
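
# Segment shapes assumed throughout these tests, inferred from the fixtures below
# (TranscriptionService._merge_speaker_segments itself is not documented here):
#   transcription segment: {"start": float, "end": float, "text": str}
#   speaker segment:       {"start": float, "end": float, "speaker": str}
#   merged output segment: {"start": float, "end": float, "text": str, "speaker": str | None}
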
class TestSpeakerSegmentationAdvanced:
"""高级说话人分割测试"""
def setup_method(self):
"""设置测试环境"""
self.service = TranscriptionService()
def test_rapid_speaker_changes(self):
"""测试快速说话人切换的情况"""
transcription_segments = [
{
"start": 0.0,
"end": 5.0,
"text": "A B C D E F G H I J K L M N O P Q R S T"
}
]
        # Simulate a speaker change every 0.25 seconds
speaker_segments = []
for i in range(20):
start_time = i * 0.25
end_time = (i + 1) * 0.25
speaker = f"SPEAKER_{i % 2:02d}" # 两个说话人交替
speaker_segments.append({
"start": start_time,
"end": end_time,
"speaker": speaker
})
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        # There should be multiple split segments
assert len(result) > 1
        # Check that the speakers alternate
speakers = [seg["speaker"] for seg in result]
assert "SPEAKER_00" in speakers
assert "SPEAKER_01" in speakers
        # Check text integrity
combined_text = " ".join([seg["text"] for seg in result if seg["text"]])
original_words = "A B C D E F G H I J K L M N O P Q R S T".split()
result_words = combined_text.split()
        # Some text may be lost because of the rapid switching, but most of it should be preserved
preserved_ratio = len(result_words) / len(original_words)
assert preserved_ratio > 0.5, f"Too much text lost: {preserved_ratio:.2f}"
def test_very_short_speaker_segments(self):
"""测试非常短的说话人段"""
transcription_segments = [
{
"start": 0.0,
"end": 2.0,
"text": "Quick short conversation"
}
]
speaker_segments = [
{
"start": 0.0,
"end": 0.1,
"speaker": "SPEAKER_00"
},
{
"start": 0.1,
"end": 0.2,
"speaker": "SPEAKER_01"
},
{
"start": 0.2,
"end": 2.0,
"speaker": "SPEAKER_00"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        # Even very short segments should be handled correctly
assert len(result) >= 1
        # Very short slices may legitimately end up with empty text, so only
        # sanity-check that every returned segment has a valid time range.
        for seg in result:
            assert seg["start"] <= seg["end"], f"Invalid timing: {seg}"
def test_overlapping_segments_complex(self):
"""测试复杂的重叠情况"""
transcription_segments = [
{
"start": 0.0,
"end": 10.0,
"text": "This is a complex conversation with multiple overlapping speakers talking at the same time"
}
]
speaker_segments = [
            # Three speakers talking at the same time
{
"start": 0.0,
"end": 6.0,
"speaker": "SPEAKER_00"
},
{
"start": 2.0,
"end": 8.0,
"speaker": "SPEAKER_01"
},
{
"start": 4.0,
"end": 10.0,
"speaker": "SPEAKER_02"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        # There should be multiple segments
assert len(result) >= 2
        # Check that the speakers are represented
speakers = set(seg["speaker"] for seg in result)
        assert len(speakers) >= 2  # at least two of the speakers should appear
        # Check timing sanity (strict continuity is not required because the turns overlap)
for seg in result:
assert seg["start"] < seg["end"], f"Invalid timing: {seg['start']} >= {seg['end']}"
def test_performance_large_segments(self):
"""测试大量段的性能"""
# 生成大量转录段
transcription_segments = []
for i in range(100):
transcription_segments.append({
"start": i * 1.0,
"end": (i + 1) * 1.0,
"text": f"This is segment number {i} with some text content for testing purposes"
})
        # Generate a large number of speaker segments
speaker_segments = []
for i in range(200):
speaker_segments.append({
"start": i * 0.5,
"end": (i + 1) * 0.5,
"speaker": f"SPEAKER_{i % 5:02d}" # 5个说话人循环
})
        # Measure execution time
start_time = time.time()
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
execution_time = time.time() - start_time
        # Should finish within a reasonable time (< 1 second)
assert execution_time < 1.0, f"Performance too slow: {execution_time:.2f}s"
        # Sanity-check the result
assert len(result) > 0
        assert len(result) <= len(transcription_segments) * 2  # each segment is split at most once (two pieces)
def test_no_overlap_at_all(self):
"""测试完全没有重叠的情况"""
transcription_segments = [
{
"start": 0.0,
"end": 2.0,
"text": "First segment"
},
{
"start": 5.0,
"end": 7.0,
"text": "Second segment"
}
]
speaker_segments = [
{
"start": 3.0,
"end": 4.0,
"speaker": "SPEAKER_00"
},
{
"start": 8.0,
"end": 9.0,
"speaker": "SPEAKER_01"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        # The original segments should be preserved, but without speaker labels
assert len(result) == 2
assert result[0]["speaker"] is None
assert result[1]["speaker"] is None
assert result[0]["text"] == "First segment"
assert result[1]["text"] == "Second segment"
def test_exact_boundary_matching(self):
"""测试精确边界匹配"""
transcription_segments = [
{
"start": 1.0,
"end": 3.0,
"text": "Exact boundary match"
}
]
speaker_segments = [
{
"start": 1.0,
"end": 3.0,
"speaker": "SPEAKER_00"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
assert len(result) == 1
assert result[0]["speaker"] == "SPEAKER_00"
assert result[0]["start"] == 1.0
assert result[0]["end"] == 3.0
assert result[0]["text"] == "Exact boundary match"
def test_floating_point_precision(self):
"""测试浮点数精度问题"""
transcription_segments = [
{
"start": 0.1,
"end": 0.3,
"text": "Precision test"
}
]
speaker_segments = [
{
"start": 0.10000001, # 微小差异
"end": 0.29999999,
"speaker": "SPEAKER_00"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
assert len(result) == 1
assert result[0]["speaker"] == "SPEAKER_00"
def test_text_distribution_accuracy(self):
"""测试文本分配的准确性"""
transcription_segments = [
{
"start": 0.0,
"end": 10.0,
"text": "One two three four five six seven eight nine ten"
}
]
speaker_segments = [
{
"start": 0.0,
"end": 8.0, # 80% 的时间
"speaker": "SPEAKER_00"
},
{
"start": 8.0,
"end": 10.0, # 20% 的时间
"speaker": "SPEAKER_01"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
assert len(result) == 2
        # SPEAKER_00 should receive roughly 80% of the text
total_text_length = len("One two three four five six seven eight nine ten")
speaker_00_length = len(result[0]["text"])
speaker_01_length = len(result[1]["text"])
speaker_00_ratio = speaker_00_length / total_text_length
speaker_01_ratio = speaker_01_length / total_text_length
        # Allow some tolerance
assert 0.6 <= speaker_00_ratio <= 0.9, f"SPEAKER_00 ratio: {speaker_00_ratio:.2f}"
assert 0.1 <= speaker_01_ratio <= 0.4, f"SPEAKER_01 ratio: {speaker_01_ratio:.2f}"
def test_single_word_segments(self):
"""测试单词级别的分割"""
transcription_segments = [
{
"start": 0.0,
"end": 4.0,
"text": "Hello world"
}
]
speaker_segments = [
{
"start": 0.0,
"end": 2.0,
"speaker": "SPEAKER_00"
},
{
"start": 2.0,
"end": 4.0,
"speaker": "SPEAKER_01"
}
]
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
assert len(result) == 2
        # Check that no word was cut apart
all_words = []
for seg in result:
if seg["text"]:
all_words.extend(seg["text"].split())
        # Both original words should be preserved intact
        original_words = ["Hello", "world"]
        for word in original_words:
            assert word in all_words, \
                f"Words not preserved correctly: {all_words}"
def test_empty_speaker_segments(self):
"""测试空的说话人段列表"""
transcription_segments = [
{
"start": 0.0,
"end": 5.0,
"text": "No speakers detected"
}
]
speaker_segments = []
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
assert len(result) == 1
assert result[0]["speaker"] is None
assert result[0]["text"] == "No speakers detected"
def test_malformed_input_handling(self):
"""测试处理格式错误的输入"""
# 缺少必要字段的转录段
transcription_segments = [
{
"start": 0.0,
"end": 5.0,
# 缺少 "text" 字段
}
]
speaker_segments = [
{
"start": 0.0,
"end": 5.0,
"speaker": "SPEAKER_00"
}
]
        # Should be handled without crashing
try:
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
            # If no exception was raised, check the result
assert len(result) >= 0
except Exception as e:
            # If an exception was raised, make sure it is one of the expected types
assert isinstance(e, (KeyError, AttributeError, TypeError))
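
# The proportional-split and word-preservation assertions above assume that
# _merge_speaker_segments hands each transcription segment's words to the
# overlapping speaker turns in proportion to overlap duration, without cutting
# words apart. The real logic lives in TranscriptionService and may differ; the
# helper below is only an illustrative sketch of that assumed behavior, kept as
# documentation of what these tests expect. It is not used by the tests.
def _example_proportional_split(transcription_segment: Dict, speaker_segments: List[Dict]) -> List[Dict]:
    """Illustrative only: split one transcription segment by overlap with speaker turns."""
    words = transcription_segment.get("text", "").split()
    seg_start, seg_end = transcription_segment["start"], transcription_segment["end"]
    duration = max(seg_end - seg_start, 1e-9)  # guard against zero-length segments
    pieces = []
    cursor = 0  # index of the next word still to be assigned
    for turn in speaker_segments:
        # Clip the speaker turn to the transcription segment.
        start = max(seg_start, turn["start"])
        end = min(seg_end, turn["end"])
        if end <= start:
            continue  # no overlap with this turn
        # Assign whole words in proportion to the overlapped duration.
        word_count = round(len(words) * (end - start) / duration)
        piece_words = words[cursor:cursor + word_count]
        cursor += len(piece_words)
        pieces.append({
            "start": start,
            "end": end,
            "text": " ".join(piece_words),
            "speaker": turn["speaker"],
        })
    # Leftover words (from rounding) are simply dropped in this sketch.
    return pieces
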
class TestSpeakerSegmentationBenchmark:
"""性能基准测试"""
def setup_method(self):
"""设置测试环境"""
self.service = TranscriptionService()
@pytest.mark.benchmark
def test_benchmark_typical_podcast(self):
"""基准测试:典型播客场景(30分钟,3个说话人)"""
# 模拟30分钟的播客:每5秒一个转录段
transcription_segments = []
        for i in range(360):  # 30 minutes * 60 seconds / 5 seconds
start_time = i * 5.0
end_time = (i + 1) * 5.0
transcription_segments.append({
"start": start_time,
"end": end_time,
"text": f"This is a podcast segment number {i} with typical conversation content"
})
        # Simulate speaker changes: a new speaker roughly every 30 seconds on average
speaker_segments = []
current_time = 0.0
current_speaker = 0
        while current_time < 1800.0:  # 30 minutes
            segment_duration = random.uniform(15.0, 45.0)  # speaking turns of 15-45 seconds
speaker_segments.append({
"start": current_time,
"end": min(current_time + segment_duration, 1800.0),
"speaker": f"SPEAKER_{current_speaker:02d}"
})
current_time += segment_duration
            current_speaker = (current_speaker + 1) % 3  # cycle through three speakers
        # Run the benchmark
start_time = time.time()
result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
execution_time = time.time() - start_time
print(f"\n📊 Benchmark Results:")
print(f" Input segments: {len(transcription_segments)}")
print(f" Speaker segments: {len(speaker_segments)}")
print(f" Output segments: {len(result)}")
print(f" Execution time: {execution_time:.3f}s")
print(f" Segments per second: {len(result)/execution_time:.1f}")
        # Performance requirement: should keep up with real-time transcription (at least 60 segments per second)
assert execution_time < 2.0, f"Too slow for real-time processing: {execution_time:.3f}s"
assert len(result) > 0, "Should produce some output segments"
if __name__ == "__main__":
    # Run all tests except the benchmarks
pytest.main([__file__, "-v", "--tb=short", "-m", "not benchmark"])
    # Run the benchmark tests separately
print("\n" + "="*60)
print("Running Benchmark Tests...")
print("="*60)
pytest.main([__file__, "-v", "--tb=short", "-m", "benchmark"])