# Viewer metadata (not part of the program): file size 15,029 bytes; identifier 76f9cd2.
"""
高级 Speaker Segmentation 测试用例
测试边缘情况、性能和复杂场景
"""
import pytest
import time
import random
from typing import List, Dict
import sys
import os
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.services.transcription_service import TranscriptionService
class TestSpeakerSegmentationAdvanced:
    """Advanced speaker-segmentation tests: edge cases, performance, complex inputs.

    Every test feeds hand-built transcription segments (dicts with ``start``,
    ``end`` and ``text``) and speaker segments (dicts with ``start``, ``end``
    and ``speaker``) to ``TranscriptionService._merge_speaker_segments`` and
    checks properties of the returned list of segment dicts.
    """

    def setup_method(self) -> None:
        """Create a fresh ``TranscriptionService`` before each test."""
        self.service = TranscriptionService()

    def test_rapid_speaker_changes(self) -> None:
        """Two speakers alternating every 0.25 s over one 5 s transcript segment."""
        source_text = "A B C D E F G H I J K L M N O P Q R S T"
        transcription_segments = [
            {"start": 0.0, "end": 5.0, "text": source_text},
        ]

        # Twenty consecutive 0.25 s slices, alternating SPEAKER_00 / SPEAKER_01.
        speaker_segments = [
            {
                "start": i * 0.25,
                "end": (i + 1) * 0.25,
                "speaker": f"SPEAKER_{i % 2:02d}",
            }
            for i in range(20)
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # The single transcript segment is expected to be split.
        assert len(result) > 1

        # Both speakers must show up in the output.
        speakers = [seg["speaker"] for seg in result]
        assert "SPEAKER_00" in speakers
        assert "SPEAKER_01" in speakers

        # Text integrity: some loss is tolerated because of the rapid switching,
        # but more than half of the words must survive.
        combined_text = " ".join(seg["text"] for seg in result if seg["text"])
        original_words = source_text.split()
        result_words = combined_text.split()
        preserved_ratio = len(result_words) / len(original_words)
        assert preserved_ratio > 0.5, f"Too much text lost: {preserved_ratio:.2f}"

    def test_very_short_speaker_segments(self) -> None:
        """Speaker slices as short as 0.1 s must not break the merge."""
        transcription_segments = [
            {"start": 0.0, "end": 2.0, "text": "Quick short conversation"},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 0.1, "speaker": "SPEAKER_00"},
            {"start": 0.1, "end": 0.2, "speaker": "SPEAKER_01"},
            {"start": 0.2, "end": 2.0, "speaker": "SPEAKER_00"},
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # Even with very short slices there must be at least one output segment.
        assert len(result) >= 1

        # Deliberately no text assertion: a slice with a valid duration may still
        # carry empty text because it is so short. The loop only verifies that
        # the timing fields exist and are comparable.
        for seg in result:
            if seg["start"] < seg["end"]:
                pass

    def test_overlapping_segments_complex(self) -> None:
        """Three speakers whose time ranges overlap each other."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 10.0,
                "text": "This is a complex conversation with multiple overlapping speakers talking at the same time",
            },
        ]
        # Each speaker overlaps the next one by four seconds.
        speaker_segments = [
            {"start": 0.0, "end": 6.0, "speaker": "SPEAKER_00"},
            {"start": 2.0, "end": 8.0, "speaker": "SPEAKER_01"},
            {"start": 4.0, "end": 10.0, "speaker": "SPEAKER_02"},
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # Multiple output segments are expected.
        assert len(result) >= 2

        # At least two distinct speakers must be represented.
        speakers = {seg["speaker"] for seg in result}
        assert len(speakers) >= 2

        # Timing sanity only; strict continuity is not required because of overlap.
        for seg in result:
            assert seg["start"] < seg["end"], f"Invalid timing: {seg['start']} >= {seg['end']}"

    def test_performance_large_segments(self) -> None:
        """100 transcript segments against 200 speaker slices must merge in under 1 s."""
        transcription_segments = [
            {
                "start": i * 1.0,
                "end": (i + 1) * 1.0,
                "text": f"This is segment number {i} with some text content for testing purposes",
            }
            for i in range(100)
        ]
        # Five speakers cycling, one slice every half second.
        speaker_segments = [
            {
                "start": i * 0.5,
                "end": (i + 1) * 0.5,
                "speaker": f"SPEAKER_{i % 5:02d}",
            }
            for i in range(200)
        ]

        # perf_counter is monotonic and high-resolution, unlike the wall clock,
        # so the measured duration cannot be skewed by clock adjustments.
        started = time.perf_counter()
        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        execution_time = time.perf_counter() - started

        assert execution_time < 1.0, f"Performance too slow: {execution_time:.2f}s"

        # Result sanity: each transcript segment is split at most once.
        assert len(result) > 0
        assert len(result) <= len(transcription_segments) * 2

    def test_no_overlap_at_all(self) -> None:
        """Speaker slices that never intersect any transcript segment."""
        transcription_segments = [
            {"start": 0.0, "end": 2.0, "text": "First segment"},
            {"start": 5.0, "end": 7.0, "text": "Second segment"},
        ]
        speaker_segments = [
            {"start": 3.0, "end": 4.0, "speaker": "SPEAKER_00"},
            {"start": 8.0, "end": 9.0, "speaker": "SPEAKER_01"},
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # The original segments are kept, without speaker information.
        assert len(result) == 2
        assert result[0]["speaker"] is None
        assert result[1]["speaker"] is None
        assert result[0]["text"] == "First segment"
        assert result[1]["text"] == "Second segment"

    def test_exact_boundary_matching(self) -> None:
        """A speaker slice whose bounds equal the transcript segment's bounds."""
        transcription_segments = [
            {"start": 1.0, "end": 3.0, "text": "Exact boundary match"},
        ]
        speaker_segments = [
            {"start": 1.0, "end": 3.0, "speaker": "SPEAKER_00"},
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[0]["start"] == 1.0
        assert result[0]["end"] == 3.0
        assert result[0]["text"] == "Exact boundary match"

    def test_floating_point_precision(self) -> None:
        """Speaker bounds that differ from the transcript bounds by about 1e-8."""
        transcription_segments = [
            {"start": 0.1, "end": 0.3, "text": "Precision test"},
        ]
        speaker_segments = [
            {
                "start": 0.10000001,  # tiny offset from 0.1
                "end": 0.29999999,  # tiny offset from 0.3
                "speaker": "SPEAKER_00",
            },
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"

    def test_text_distribution_accuracy(self) -> None:
        """Text should be shared between speakers roughly in proportion to time."""
        source_text = "One two three four five six seven eight nine ten"
        transcription_segments = [
            {"start": 0.0, "end": 10.0, "text": source_text},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 8.0, "speaker": "SPEAKER_00"},  # 80% of the time
            {"start": 8.0, "end": 10.0, "speaker": "SPEAKER_01"},  # 20% of the time
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 2

        # Compare each output segment's character count with the full text length.
        total_text_length = len(source_text)
        speaker_00_ratio = len(result[0]["text"]) / total_text_length
        speaker_01_ratio = len(result[1]["text"]) / total_text_length

        # Generous tolerance around the nominal 80% / 20% split.
        assert 0.6 <= speaker_00_ratio <= 0.9, f"SPEAKER_00 ratio: {speaker_00_ratio:.2f}"
        assert 0.1 <= speaker_01_ratio <= 0.4, f"SPEAKER_01 ratio: {speaker_01_ratio:.2f}"

    def test_single_word_segments(self) -> None:
        """A two-word segment split between two speakers must keep both words whole."""
        transcription_segments = [
            {"start": 0.0, "end": 4.0, "text": "Hello world"},
        ]
        speaker_segments = [
            {"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"},
            {"start": 2.0, "end": 4.0, "speaker": "SPEAKER_01"},
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 2

        # Gather every whitespace-separated word from the non-empty output texts.
        all_words = []
        for seg in result:
            if seg["text"]:
                all_words.extend(seg["text"].split())

        # Each original word must be present intact. The check is per word: a
        # combined any() over all words would pass even if one word were lost
        # or cut in half.
        original_words = ["Hello", "world"]
        for word in original_words:
            assert word in all_words, \
                f"Words not preserved correctly: {all_words}"

    def test_empty_speaker_segments(self) -> None:
        """An empty speaker list leaves the transcript segment without a speaker."""
        transcription_segments = [
            {"start": 0.0, "end": 5.0, "text": "No speakers detected"},
        ]
        speaker_segments = []

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] is None
        assert result[0]["text"] == "No speakers detected"

    def test_malformed_input_handling(self) -> None:
        """A transcript segment without a ``text`` key must be handled predictably.

        The merge may either return a result or raise ``KeyError``,
        ``AttributeError`` or ``TypeError``; any other exception fails the test.
        """
        transcription_segments = [
            {
                "start": 0.0,
                "end": 5.0,
                # the "text" field is intentionally missing
            },
        ]
        speaker_segments = [
            {"start": 0.0, "end": 5.0, "speaker": "SPEAKER_00"},
        ]

        try:
            result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        except (KeyError, AttributeError, TypeError):
            # Rejecting the malformed segment with one of these types is acceptable.
            return

        # No exception was raised: the call must still hand back a sized result.
        assert len(result) >= 0
class TestSpeakerSegmentationBenchmark:
    """Performance benchmark for ``TranscriptionService._merge_speaker_segments``."""

    def setup_method(self) -> None:
        """Create a fresh ``TranscriptionService`` before each test."""
        self.service = TranscriptionService()

    @pytest.mark.benchmark
    def test_benchmark_typical_podcast(self) -> None:
        """Benchmark a simulated 30-minute podcast with three rotating speakers.

        Builds 360 five-second transcript segments and speaker turns lasting
        15-45 s each, times a single merge call, prints the figures and
        requires the call to finish in under two seconds.
        """
        total_duration = 1800.0  # 30 minutes, in seconds

        # One transcript segment every 5 seconds: 30 min * 60 s / 5 s = 360.
        transcription_segments = [
            {
                "start": i * 5.0,
                "end": (i + 1) * 5.0,
                "text": f"This is a podcast segment number {i} with typical conversation content",
            }
            for i in range(360)
        ]

        # A locally seeded generator keeps the workload identical between runs,
        # so timings are comparable and failures reproducible.
        rng = random.Random(42)

        # Speaker turns of 15-45 s (about 30 s on average), cycling through 3 speakers.
        speaker_segments = []
        current_time = 0.0
        current_speaker = 0
        while current_time < total_duration:
            segment_duration = rng.uniform(15.0, 45.0)
            speaker_segments.append({
                "start": current_time,
                "end": min(current_time + segment_duration, total_duration),
                "speaker": f"SPEAKER_{current_speaker:02d}",
            })
            current_time += segment_duration
            current_speaker = (current_speaker + 1) % 3

        # perf_counter is monotonic and high-resolution; a wall-clock delta can
        # come out as exactly 0.0 on coarse timers.
        started = time.perf_counter()
        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)
        execution_time = time.perf_counter() - started

        # Guard the throughput figure so reporting can never raise ZeroDivisionError.
        throughput = len(result) / execution_time if execution_time > 0 else float("inf")

        print(f"\n📊 Benchmark Results:")
        print(f" Input segments: {len(transcription_segments)}")
        print(f" Speaker segments: {len(speaker_segments)}")
        print(f" Output segments: {len(result)}")
        print(f" Execution time: {execution_time:.3f}s")
        print(f" Segments per second: {throughput:.1f}")

        # Performance requirement: fast enough for real-time transcription.
        assert execution_time < 2.0, f"Too slow for real-time processing: {execution_time:.3f}s"
        assert len(result) > 0, "Should produce some output segments"
if __name__ == "__main__":
    # Run every test that is not marked as a benchmark.
    regular_status = pytest.main([__file__, "-v", "--tb=short", "-m", "not benchmark"])

    # Run the benchmark-marked tests separately, under their own banner.
    print("\n" + "="*60)
    print("Running Benchmark Tests...")
    print("="*60)
    benchmark_status = pytest.main([__file__, "-v", "--tb=short", "-m", "benchmark"])

    # Propagate the first non-zero pytest exit code so callers (CI, shells)
    # see a failure instead of an unconditional exit status of 0.
    sys.exit(int(regular_status) or int(benchmark_status))