|
""" |
|
Test remote GPU transcription functionality |
|
ζ΅θ―θΏη¨GPU转ε½εθ½ |
|
""" |
|
|
|
import pytest |
|
import asyncio |
|
import os |
|
import tempfile |
|
from pathlib import Path |
|
from typing import Dict, Any |
|
|
|
from src.tools.transcription_tools import transcribe_audio_file_tool |
|
from src.services.audio_processing_service import AudioProcessingService |
|
|
|
|
|
class TestRemoteTranscription: |
|
"""Test remote GPU transcription integration""" |
|
|
|
def test_transcription_tools_initialization(self): |
|
"""Test transcription tools initialization""" |
|
print("\nπ§ Testing transcription tools initialization...") |
|
|
|
|
|
assert transcribe_audio_file_tool is not None |
|
|
|
print("β
Transcription tools initialized successfully") |
|
|
|
@pytest.mark.asyncio |
|
async def test_create_sample_audio_file(self, temp_dir: str): |
|
"""Test creating a sample audio file for transcription testing""" |
|
print("\nπ΅ Creating sample audio file for testing...") |
|
|
|
import ffmpeg |
|
|
|
|
|
sample_file = os.path.join(temp_dir, "sample_audio.mp3") |
|
|
|
try: |
|
|
|
( |
|
ffmpeg |
|
.input('sine=frequency=440:duration=5', f='lavfi') |
|
.output(sample_file, acodec='mp3', ar=16000) |
|
.overwrite_output() |
|
.run(quiet=True) |
|
) |
|
|
|
assert os.path.exists(sample_file) |
|
assert os.path.getsize(sample_file) > 0 |
|
|
|
print(f"β
Sample audio file created: {sample_file}") |
|
print(f" File size: {os.path.getsize(sample_file)} bytes") |
|
|
|
return sample_file |
|
|
|
except Exception as e: |
|
print(f"β Failed to create sample audio file: {str(e)}") |
|
pytest.skip(f"Failed to create sample audio file: {str(e)}") |
|
|
|
@pytest.mark.asyncio |
|
async def test_remote_transcription_endpoint_connectivity(self): |
|
"""Test connectivity to remote transcription endpoint""" |
|
print("\nπ Testing remote transcription endpoint connectivity...") |
|
|
|
import aiohttp |
|
import json |
|
|
|
|
|
try: |
|
with open("endpoint_config.json", "r") as f: |
|
endpoint_config = json.load(f) |
|
|
|
endpoint_url = endpoint_config["transcribe_audio"] |
|
|
|
async with aiohttp.ClientSession() as session: |
|
|
|
async with session.head(endpoint_url, timeout=10) as response: |
|
print(f"β
Endpoint connectivity test:") |
|
print(f" URL: {endpoint_url}") |
|
print(f" Status: {response.status}") |
|
print(f" Headers: {dict(response.headers)}") |
|
|
|
|
|
assert response.status in [200, 405, 404], f"Unexpected status: {response.status}" |
|
|
|
except asyncio.TimeoutError: |
|
print(f"β οΈ Endpoint connectivity timeout (expected if Modal is sleeping)") |
|
pytest.skip("Endpoint connectivity timeout") |
|
except Exception as e: |
|
print(f"β οΈ Endpoint connectivity test failed: {str(e)}") |
|
print(" This might be expected if Modal endpoint is not running") |
|
|
|
@pytest.mark.asyncio |
|
async def test_transcription_tool_interface(self, temp_dir: str): |
|
"""Test transcription tool interface with sample audio""" |
|
print("\nπ€ Testing transcription tool interface...") |
|
|
|
|
|
sample_file = await self.test_create_sample_audio_file(temp_dir) |
|
|
|
try: |
|
|
|
result = await transcribe_audio_file_tool( |
|
audio_file_path=sample_file, |
|
model_size="base", |
|
language="en", |
|
output_format="srt", |
|
enable_speaker_diarization=False |
|
) |
|
|
|
print(f"π Transcription tool result:") |
|
print(f" Status: {result.get('processing_status', 'unknown')}") |
|
print(f" Audio file: {result.get('audio_file', 'N/A')}") |
|
print(f" Model used: {result.get('model_used', 'N/A')}") |
|
print(f" Duration: {result.get('audio_duration', 0):.2f}s") |
|
|
|
if result.get("processing_status") == "success": |
|
print(f" TXT file: {result.get('txt_file_path', 'N/A')}") |
|
print(f" SRT file: {result.get('srt_file_path', 'N/A')}") |
|
print(f" Segments: {result.get('segment_count', 0)}") |
|
print("β
Transcription tool interface test successful") |
|
|
|
|
|
if result.get('txt_file_path'): |
|
assert os.path.exists(result['txt_file_path']) |
|
if result.get('srt_file_path'): |
|
assert os.path.exists(result['srt_file_path']) |
|
|
|
else: |
|
print(f" Error: {result.get('error_message', 'Unknown error')}") |
|
print("β οΈ Transcription failed (might be due to remote endpoint)") |
|
|
|
except Exception as e: |
|
print(f"β Transcription tool test failed: {str(e)}") |
|
print(" This might be expected if remote endpoint is not available") |
|
|
|
@pytest.mark.asyncio |
|
async def test_transcription_with_speaker_diarization(self, temp_dir: str): |
|
"""Test transcription with speaker diarization enabled""" |
|
print("\nπ₯ Testing transcription with speaker diarization...") |
|
|
|
|
|
sample_file = await self.test_create_sample_audio_file(temp_dir) |
|
|
|
try: |
|
|
|
result = await transcribe_audio_file_tool( |
|
audio_file_path=sample_file, |
|
model_size="base", |
|
language="auto", |
|
output_format="srt", |
|
enable_speaker_diarization=True |
|
) |
|
|
|
print(f"π Speaker diarization result:") |
|
print(f" Status: {result.get('processing_status', 'unknown')}") |
|
print(f" Speaker diarization enabled: {result.get('speaker_diarization_enabled', False)}") |
|
print(f" Global speaker count: {result.get('global_speaker_count', 0)}") |
|
|
|
if result.get("processing_status") == "success": |
|
speaker_summary = result.get('speaker_summary', {}) |
|
print(f" Speaker summary: {speaker_summary}") |
|
print("β
Speaker diarization test successful") |
|
else: |
|
print(f" Error: {result.get('error_message', 'Unknown error')}") |
|
print("β οΈ Speaker diarization failed (might be due to remote endpoint or HF token)") |
|
|
|
except Exception as e: |
|
print(f"β Speaker diarization test failed: {str(e)}") |
|
print(" This might be expected if HF token is not configured or endpoint unavailable") |
|
|
|
@pytest.mark.asyncio |
|
async def test_different_transcription_models(self, temp_dir: str): |
|
"""Test transcription with different models""" |
|
print("\nπ§ Testing different transcription models...") |
|
|
|
sample_file = await self.test_create_sample_audio_file(temp_dir) |
|
|
|
models_to_test = ["tiny", "base", "small"] |
|
|
|
for model in models_to_test: |
|
print(f"\n Testing model: {model}") |
|
try: |
|
result = await transcribe_audio_file_tool( |
|
audio_file_path=sample_file, |
|
model_size=model, |
|
language="auto", |
|
output_format="txt", |
|
enable_speaker_diarization=False |
|
) |
|
|
|
if result.get("processing_status") == "success": |
|
print(f" β
{model} model: Success") |
|
print(f" Segments: {result.get('segment_count', 0)}") |
|
print(f" Duration: {result.get('audio_duration', 0):.2f}s") |
|
else: |
|
print(f" β οΈ {model} model: Failed - {result.get('error_message', 'Unknown')}") |
|
|
|
except Exception as e: |
|
print(f" β {model} model: Exception - {str(e)}") |
|
|
|
@pytest.mark.asyncio |
|
async def test_transcription_output_formats(self, temp_dir: str): |
|
"""Test different transcription output formats""" |
|
print("\nπ Testing different output formats...") |
|
|
|
sample_file = await self.test_create_sample_audio_file(temp_dir) |
|
|
|
formats_to_test = ["txt", "srt", "json"] |
|
|
|
for format_type in formats_to_test: |
|
print(f"\n Testing format: {format_type}") |
|
try: |
|
result = await transcribe_audio_file_tool( |
|
audio_file_path=sample_file, |
|
model_size="base", |
|
language="auto", |
|
output_format=format_type, |
|
enable_speaker_diarization=False |
|
) |
|
|
|
if result.get("processing_status") == "success": |
|
print(f" β
{format_type} format: Success") |
|
|
|
|
|
if format_type == "txt" and result.get('txt_file_path'): |
|
assert os.path.exists(result['txt_file_path']) |
|
elif format_type == "srt" and result.get('srt_file_path'): |
|
assert os.path.exists(result['srt_file_path']) |
|
|
|
else: |
|
print(f" β οΈ {format_type} format: Failed - {result.get('error_message', 'Unknown')}") |
|
|
|
except Exception as e: |
|
print(f" β {format_type} format: Exception - {str(e)}") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
pytest.main([__file__, "-v", "-s"]) |