|
""" |
|
Tests for the speaker segmentation logic in TranscriptionService,

in particular the _merge_speaker_segments and _split_transcription_segment methods.
|
""" |
|
|
|
import pytest |
|
import tempfile |
|
import os |
|
from typing import List, Dict |
|
from unittest.mock import Mock, patch |
|
|
|
|
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
from src.services.transcription_service import TranscriptionService |
|
|
|
|
|
class TestSpeakerSegmentation:
    """Tests for attaching speaker labels to transcription segments.

    Exercises ``TranscriptionService._merge_speaker_segments`` and
    ``TranscriptionService._split_transcription_segment`` using hand-built
    segment dictionaries.
    """

    @staticmethod
    def _normalize_whitespace(text):
        """Collapse every run of whitespace in *text* to a single space.

        Leading and trailing whitespace is removed as well, so texts that
        were re-joined from several pieces can be compared with the
        original sentence regardless of padding around the pieces.
        """
        return " ".join(text.split())

    def setup_method(self):
        """Create a fresh service instance before each test."""
        self.service = TranscriptionService()

    def test_single_speaker_segment(self):
        """One speaker covering the whole transcription segment."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 5.0,
                "text": "Hello, this is a test message.",
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 5.0,
                "speaker": "SPEAKER_00",
            }
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[0]["text"] == "Hello, this is a test message."
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 5.0

    def test_no_speaker_detected(self):
        """No speaker segments at all: the speaker is expected to be None."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 5.0,
                "text": "Hello, this is a test message.",
            }
        ]
        speaker_segments = []

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] is None
        assert result[0]["text"] == "Hello, this is a test message."

    def test_multiple_speakers_in_single_segment(self):
        """Several speakers inside a single transcription segment."""
        original_text = "Hello there how are you today I am doing well thank you for asking"
        transcription_segments = [
            {
                "start": 0.0,
                "end": 10.0,
                "text": original_text,
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 4.0,
                "speaker": "SPEAKER_00",
            },
            {
                "start": 4.0,
                "end": 7.0,
                "speaker": "SPEAKER_01",
            },
            {
                "start": 7.0,
                "end": 10.0,
                "speaker": "SPEAKER_00",
            },
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # One output segment per speaker turn.
        assert len(result) == 3

        # Speakers come back in chronological order.
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[1]["speaker"] == "SPEAKER_01"
        assert result[2]["speaker"] == "SPEAKER_00"

        # Each piece stays within the bounds of its speaker turn.
        assert result[0]["start"] == 0.0
        assert result[0]["end"] <= 4.0
        assert result[1]["start"] >= 4.0
        assert result[1]["end"] <= 7.0
        assert result[2]["start"] >= 7.0
        assert result[2]["end"] <= 10.0

        # Re-joining the pieces must give back the original sentence.
        # Whitespace is collapsed first so padding around a piece does not
        # matter (the previous ``replace(" ", " ")`` was a no-op).
        combined_text = " ".join(seg["text"] for seg in result)
        assert self._normalize_whitespace(combined_text) == original_text

    def test_overlapping_speakers(self):
        """Speaker turns that overlap each other in time."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 6.0,
                "text": "This is a conversation between two people talking simultaneously",
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 4.0,
                "speaker": "SPEAKER_00",
            },
            {
                "start": 2.0,
                "end": 6.0,
                "speaker": "SPEAKER_01",
            },
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # Both speakers must be represented, in order.
        assert len(result) == 2
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[1]["speaker"] == "SPEAKER_01"

        # Output times stay inside the respective speaker turns.
        assert result[0]["start"] == 0.0
        assert result[0]["end"] <= 4.0
        assert result[1]["start"] >= 2.0
        assert result[1]["end"] == 6.0

    def test_partial_speaker_overlap(self):
        """Transcription segment lying strictly inside one speaker turn."""
        transcription_segments = [
            {
                "start": 1.0,
                "end": 4.0,
                "text": "This is in the middle of speaker segment",
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 5.0,
                "speaker": "SPEAKER_00",
            }
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"
        # The transcription timing is kept, not the wider speaker timing.
        assert result[0]["start"] == 1.0
        assert result[0]["end"] == 4.0

    def test_multiple_transcription_segments_with_speakers(self):
        """Several transcription segments combined with several speakers."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 3.0,
                "text": "Hello how are you",
            },
            {
                "start": 3.0,
                "end": 6.0,
                "text": "I am fine thank you",
            },
            {
                "start": 6.0,
                "end": 10.0,
                "text": "That is great to hear from you today",
            },
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 3.0,
                "speaker": "SPEAKER_00",
            },
            {
                "start": 3.0,
                "end": 6.0,
                "speaker": "SPEAKER_01",
            },
            {
                "start": 6.0,
                "end": 8.0,
                "speaker": "SPEAKER_00",
            },
            {
                "start": 8.0,
                "end": 10.0,
                "speaker": "SPEAKER_01",
            },
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        # Two untouched segments plus the third one split in two.
        assert len(result) == 4

        # The first two segments map one-to-one onto a speaker.
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[0]["text"] == "Hello how are you"
        assert result[1]["speaker"] == "SPEAKER_01"
        assert result[1]["text"] == "I am fine thank you"

        # The third segment is shared by two speakers.
        assert result[2]["speaker"] == "SPEAKER_00"
        assert result[3]["speaker"] == "SPEAKER_01"

        # The two halves together still contain the original sentence;
        # whitespace is collapsed so padding around the halves is ignored.
        combined_third_segment_text = result[2]["text"] + " " + result[3]["text"]
        assert "That is great to hear from you today" in self._normalize_whitespace(
            combined_third_segment_text
        )

    def test_word_boundary_preservation(self):
        """Splitting a text must never cut through the middle of a word."""
        source_text = "The quick brown fox jumps over the lazy dog"
        transcription_segments = [
            {
                "start": 0.0,
                "end": 8.0,
                "text": source_text,
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 4.0,
                "speaker": "SPEAKER_00",
            },
            {
                "start": 4.0,
                "end": 8.0,
                "speaker": "SPEAKER_01",
            },
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 2

        known_words = set(source_text.split())
        for segment in result:
            text = segment["text"]
            if not text:
                continue

            words = text.split()
            assert len(words) > 0, f"Segment should contain complete words: '{text}'"

            assert not any(word.endswith('-') or word.startswith('-') for word in words), \
                f"Should not contain partial words: {words}"

            # A mid-word cut yields fragments that are not words of the
            # source sentence; the hyphen check alone cannot detect that.
            fragments = [word for word in words if word not in known_words]
            assert not fragments, f"Words were cut in the middle: {fragments}"

    def test_empty_text_handling(self):
        """A transcription segment with empty text is passed through."""
        transcription_segments = [
            {
                "start": 0.0,
                "end": 2.0,
                "text": "",
            }
        ]
        speaker_segments = [
            {
                "start": 0.0,
                "end": 2.0,
                "speaker": "SPEAKER_00",
            }
        ]

        result = self.service._merge_speaker_segments(transcription_segments, speaker_segments)

        assert len(result) == 1
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[0]["text"] == ""

    def test_split_transcription_segment_direct(self):
        """Call ``_split_transcription_segment`` directly with two equal turns."""
        trans_seg = {
            "start": 0.0,
            "end": 6.0,
            "text": "Hello there how are you doing today",
        }
        overlapping_speakers = [
            {
                "speaker": "SPEAKER_00",
                "start": 0.0,
                "end": 3.0,
                "overlap_start": 0.0,
                "overlap_end": 3.0,
                "overlap_duration": 3.0,
            },
            {
                "speaker": "SPEAKER_01",
                "start": 3.0,
                "end": 6.0,
                "overlap_start": 3.0,
                "overlap_end": 6.0,
                "overlap_duration": 3.0,
            },
        ]

        result = self.service._split_transcription_segment(
            trans_seg, overlapping_speakers, trans_seg["text"]
        )

        assert len(result) == 2
        assert result[0]["speaker"] == "SPEAKER_00"
        assert result[1]["speaker"] == "SPEAKER_01"

        # Timing follows the overlap windows.
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 3.0
        assert result[1]["start"] == 3.0
        assert result[1]["end"] == 6.0

        # Both halves together still contain the full sentence; whitespace
        # is collapsed first (the previous ``replace(" ", " ")`` was a no-op).
        combined_text = result[0]["text"] + " " + result[1]["text"]
        assert "Hello there how are you doing today" in self._normalize_whitespace(combined_text)

    def test_unequal_speaker_durations(self):
        """The speaker with the longer overlap should receive more text."""
        trans_seg = {
            "start": 0.0,
            "end": 10.0,
            "text": "This is a longer sentence with one speaker talking much longer than the other speaker",
        }
        overlapping_speakers = [
            {
                "speaker": "SPEAKER_00",
                "start": 0.0,
                "end": 8.0,
                "overlap_start": 0.0,
                "overlap_end": 8.0,
                "overlap_duration": 8.0,
            },
            {
                "speaker": "SPEAKER_01",
                "start": 8.0,
                "end": 10.0,
                "overlap_start": 8.0,
                "overlap_end": 10.0,
                "overlap_duration": 2.0,
            },
        ]

        result = self.service._split_transcription_segment(
            trans_seg, overlapping_speakers, trans_seg["text"]
        )

        assert len(result) == 2

        # SPEAKER_00 talks for 8 of the 10 seconds and should get more text.
        speaker_00_text_length = len(result[0]["text"])
        speaker_01_text_length = len(result[1]["text"])

        assert speaker_00_text_length > speaker_01_text_length, \
            f"SPEAKER_00 should have more text. Got {speaker_00_text_length} vs {speaker_01_text_length}"

        # The split point sits at the speaker change.
        assert result[0]["end"] == 8.0
        assert result[1]["start"] == 8.0

    @pytest.mark.integration
    def test_full_transcription_with_speaker_splitting(self):
        """Integration test: full transcription flow with speaker splitting."""
        # Placeholder: always skipped until a real audio fixture is available.
        pytest.skip("Integration test requires actual audio file")
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the exit code
    # of the run; forward it so a failing run is visible to the shell / CI
    # instead of the process always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))