diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c425f970670ebfe180e9b6beaa70f2b76fdd0fe6 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,13 @@ +""" +PodcastMcpGradio - Podcast Processing and Analysis Framework + +A comprehensive framework for podcast downloading, transcription, and analysis +with MCP (Model Context Protocol) integration and Gradio UI. +""" + +__version__ = "2.0.0" +__author__ = "PodcastMcpGradio Team" +__description__ = "Podcast Processing and Analysis Framework" + +# Core modules will be imported as needed +__all__ = [] \ No newline at end of file diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff586ad659f2833145fe0d432f51765f8c56ab6d Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/__pycache__/app.cpython-310.pyc b/src/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0de4d7fc55655330140818deac63e58669cb3049 Binary files /dev/null and b/src/__pycache__/app.cpython-310.pyc differ diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c93a6a0f6131c07e66825f0cfb90ab117d64584 --- /dev/null +++ b/src/adapters/__init__.py @@ -0,0 +1,13 @@ +""" +Adapters for different transcription backends +""" + +from .transcription_adapter_factory import TranscriptionAdapterFactory +from .local_adapter import LocalTranscriptionAdapter +from .modal_adapter import ModalTranscriptionAdapter + +__all__ = [ + "TranscriptionAdapterFactory", + "LocalTranscriptionAdapter", + "ModalTranscriptionAdapter" +] \ No newline at end of file diff --git a/src/adapters/__pycache__/__init__.cpython-310.pyc b/src/adapters/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4395b0a5cf4c095200e2f386ff2b3b16929253f Binary files /dev/null and b/src/adapters/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/adapters/__pycache__/local_adapter.cpython-310.pyc b/src/adapters/__pycache__/local_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9eb37d6c4d6621e6363583704d0d969055778f Binary files /dev/null and b/src/adapters/__pycache__/local_adapter.cpython-310.pyc differ diff --git a/src/adapters/__pycache__/modal_adapter.cpython-310.pyc b/src/adapters/__pycache__/modal_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eede1d673851705dce84545257611772f07aec22 Binary files /dev/null and b/src/adapters/__pycache__/modal_adapter.cpython-310.pyc differ diff --git a/src/adapters/__pycache__/transcription_adapter_factory.cpython-310.pyc b/src/adapters/__pycache__/transcription_adapter_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..769a2ddf141ee72d6bad898313b7ebfe50eeb894 Binary files /dev/null and b/src/adapters/__pycache__/transcription_adapter_factory.cpython-310.pyc differ diff --git a/src/adapters/local_adapter.py b/src/adapters/local_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..21935112df0d2d2c181523e90b1371b94f2a27b4 --- /dev/null +++ b/src/adapters/local_adapter.py @@ -0,0 +1,93 @@ +""" +Local transcription adapter for direct processing +""" + +import asyncio +from typing import List, Optional + +from ..interfaces.transcriber import 
ITranscriber, TranscriptionResult +from ..utils.config import AudioProcessingConfig +from ..utils.errors import TranscriptionError + + +class LocalTranscriptionAdapter(ITranscriber): + """Adapter for local transcription processing""" + + def __init__(self, config: Optional[AudioProcessingConfig] = None): + self.config = config or AudioProcessingConfig() + + async def transcribe( + self, + audio_file_path: str, + model_size: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> TranscriptionResult: + """Transcribe audio using local processing""" + + try: + # Use the new AudioProcessingService instead of old methods + from ..services.audio_processing_service import AudioProcessingService + from ..models.services import AudioProcessingRequest + + print(f"🔄 Starting local transcription for: {audio_file_path}") + print(f"🚀 Running transcription with {model_size} model...") + + # Create service and request + audio_service = AudioProcessingService() + request = AudioProcessingRequest( + audio_file_path=audio_file_path, + model_size=model_size, + language=language, + output_format="json", + enable_speaker_diarization=enable_speaker_diarization + ) + + # Process transcription + result = audio_service.transcribe_full_audio(request) + + # Convert service result to adapter format + return self._convert_service_result(result) + + except Exception as e: + raise TranscriptionError( + f"Local transcription failed: {str(e)}", + model=model_size, + audio_file=audio_file_path + ) + + def get_supported_models(self) -> List[str]: + """Get list of supported model sizes""" + return list(self.config.whisper_models.keys()) + + def get_supported_languages(self) -> List[str]: + """Get list of supported language codes""" + # This would normally come from Whisper's supported languages + return ["en", "zh", "ja", "ko", "es", "fr", "de", "ru", "auto"] + + def _convert_service_result(self, service_result) -> TranscriptionResult: + """Convert service result format to TranscriptionResult""" + from ..interfaces.transcriber import TranscriptionSegment + + # Extract segments from service result if available + segments = [] + if hasattr(service_result, 'segments') and service_result.segments: + for seg in service_result.segments: + segments.append(TranscriptionSegment( + start=getattr(seg, 'start', 0), + end=getattr(seg, 'end', 0), + text=getattr(seg, 'text', ''), + speaker=getattr(seg, 'speaker', None) + )) + + return TranscriptionResult( + text=getattr(service_result, 'text', ''), + segments=segments, + language=getattr(service_result, 'language_detected', 'unknown'), + model_used=getattr(service_result, 'model_used', 'unknown'), + audio_duration=getattr(service_result, 'audio_duration', 0), + processing_time=getattr(service_result, 'processing_time', 0), + speaker_diarization_enabled=getattr(service_result, 'speaker_diarization_enabled', False), + global_speaker_count=getattr(service_result, 'global_speaker_count', 0), + error_message=getattr(service_result, 'error_message', None) + ) \ No newline at end of file diff --git a/src/adapters/modal_adapter.py b/src/adapters/modal_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b1aa1674bd90fd3262396f7cd4635564a6503d --- /dev/null +++ b/src/adapters/modal_adapter.py @@ -0,0 +1,126 @@ +""" +Modal transcription adapter for remote processing +""" + +import requests +import base64 +import pathlib +from typing import List, Optional + +from ..interfaces.transcriber import ITranscriber, TranscriptionResult, 
TranscriptionSegment +from ..utils.config import AudioProcessingConfig +from ..utils.errors import TranscriptionError + + +class ModalTranscriptionAdapter(ITranscriber): + """Adapter for Modal remote transcription processing""" + + def __init__(self, config: Optional[AudioProcessingConfig] = None, endpoint_url: Optional[str] = None): + self.config = config or AudioProcessingConfig() + self.endpoint_url = endpoint_url + + async def transcribe( + self, + audio_file_path: str, + model_size: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> TranscriptionResult: + """Transcribe audio using Modal endpoint""" + + if not self.endpoint_url: + raise TranscriptionError( + "Modal endpoint URL not configured", + model=model_size, + audio_file=audio_file_path + ) + + try: + # Read and encode audio file + audio_path = pathlib.Path(audio_file_path) + if not audio_path.exists(): + raise TranscriptionError( + f"Audio file not found: {audio_file_path}", + audio_file=audio_file_path + ) + + with open(audio_path, 'rb') as f: + audio_data = f.read() + + audio_base64 = base64.b64encode(audio_data).decode('utf-8') + + # Prepare request data + request_data = { + "audio_file_data": audio_base64, + "audio_file_name": audio_path.name, + "model_size": model_size, + "language": language, + "output_format": "json", + "enable_speaker_diarization": enable_speaker_diarization + } + + print(f"🔄 Sending transcription request to Modal endpoint") + print(f"📁 File: {audio_file_path} ({len(audio_data) / (1024*1024):.2f} MB)") + print(f"🔧 Model: {model_size}, Speaker diarization: {enable_speaker_diarization}") + + # Make request to Modal endpoint + response = requests.post( + self.endpoint_url, + json=request_data, + timeout=1800 # 30 minutes timeout + ) + + response.raise_for_status() + result = response.json() + + print(f"✅ Modal transcription completed") + + # Convert result to TranscriptionResult format + return self._convert_modal_result(result) + + except requests.exceptions.RequestException as e: + raise TranscriptionError( + f"Failed to call Modal endpoint: {str(e)}", + model=model_size, + audio_file=audio_file_path + ) + except Exception as e: + raise TranscriptionError( + f"Modal transcription failed: {str(e)}", + model=model_size, + audio_file=audio_file_path + ) + + def get_supported_models(self) -> List[str]: + """Get list of supported model sizes""" + return list(self.config.whisper_models.keys()) + + def get_supported_languages(self) -> List[str]: + """Get list of supported language codes""" + return ["en", "zh", "ja", "ko", "es", "fr", "de", "ru", "auto"] + + def _convert_modal_result(self, modal_result: dict) -> TranscriptionResult: + """Convert Modal result format to TranscriptionResult""" + + # Extract segments if available + segments = [] + if "segments" in modal_result: + for seg in modal_result["segments"]: + segments.append(TranscriptionSegment( + start=seg.get("start", 0), + end=seg.get("end", 0), + text=seg.get("text", ""), + speaker=seg.get("speaker") + )) + + return TranscriptionResult( + text=modal_result.get("text", ""), + segments=segments, + language=modal_result.get("language_detected", "unknown"), + model_used=modal_result.get("model_used", "unknown"), + audio_duration=modal_result.get("audio_duration", 0), + processing_time=modal_result.get("processing_time", 0), + speaker_diarization_enabled=modal_result.get("speaker_diarization_enabled", False), + global_speaker_count=modal_result.get("global_speaker_count", 0), + 
error_message=modal_result.get("error_message") + ) \ No newline at end of file diff --git a/src/adapters/transcription_adapter_factory.py b/src/adapters/transcription_adapter_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..469eb57b8c4a3412bbe9a90a49109b7cc57467d4 --- /dev/null +++ b/src/adapters/transcription_adapter_factory.py @@ -0,0 +1,77 @@ +""" +Factory for creating transcription adapters +""" + +import os +from typing import Optional + +from ..interfaces.transcriber import ITranscriber +from ..utils.config import AudioProcessingConfig +from ..utils.errors import ConfigurationError +from .local_adapter import LocalTranscriptionAdapter +from .modal_adapter import ModalTranscriptionAdapter + + +class TranscriptionAdapterFactory: + """Factory for creating appropriate transcription adapters""" + + @staticmethod + def create_adapter( + deployment_mode: str = "auto", + config: Optional[AudioProcessingConfig] = None, + endpoint_url: Optional[str] = None + ) -> ITranscriber: + """ + Create transcription adapter based on deployment mode + + Args: + deployment_mode: "local", "modal", or "auto" + config: Configuration object + endpoint_url: Modal endpoint URL (for modal/auto mode) + + Returns: + ITranscriber: Appropriate transcription adapter + """ + + config = config or AudioProcessingConfig() + + # Auto mode: decide based on environment and endpoint availability + if deployment_mode == "auto": + if endpoint_url: + print(f"🌐 Auto mode: Using Modal adapter with endpoint {endpoint_url}") + return ModalTranscriptionAdapter(config=config, endpoint_url=endpoint_url) + else: + print(f"🏠 Auto mode: Using Local adapter (no endpoint configured)") + return LocalTranscriptionAdapter(config=config) + + # Explicit local mode + elif deployment_mode == "local": + print(f"🏠 Using Local transcription adapter") + return LocalTranscriptionAdapter(config=config) + + # Explicit modal mode + elif deployment_mode == "modal": + if not endpoint_url: + raise ConfigurationError( + "Modal endpoint URL is required for modal mode", + config_key="endpoint_url" + ) + print(f"🌐 Using Modal transcription adapter with endpoint {endpoint_url}") + return ModalTranscriptionAdapter(config=config, endpoint_url=endpoint_url) + + else: + raise ConfigurationError( + f"Unsupported deployment mode: {deployment_mode}. 
Use 'local', 'modal', or 'auto'", + config_key="deployment_mode" + ) + + @staticmethod + def _detect_deployment_mode() -> str: + """Auto-detect deployment mode based on environment""" + import os + + # Check if running in Modal environment + if os.environ.get("MODAL_TASK_ID"): + return "local" # We're inside Modal, use local processing + else: + return "modal" # We're outside Modal, use remote endpoint \ No newline at end of file diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d23c0c8bbd447d9e8714c13870224d8b410afa --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,5 @@ +""" +API Module - External interfaces and endpoints +""" + +__all__ = [] \ No newline at end of file diff --git a/src/api/__pycache__/__init__.cpython-310.pyc b/src/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..022495b0e5c3aa105f67210e6c768440138b0599 Binary files /dev/null and b/src/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/api/__pycache__/transcription_api.cpython-310.pyc b/src/api/__pycache__/transcription_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..093dfbbfd7d6d47d5b2fc7087add548eb5132fbc Binary files /dev/null and b/src/api/__pycache__/transcription_api.cpython-310.pyc differ diff --git a/src/api/transcription_api.py b/src/api/transcription_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4cae98f0af21481681e51e1ec8c2744903b1bca9 --- /dev/null +++ b/src/api/transcription_api.py @@ -0,0 +1,112 @@ +""" +Transcription API module +""" + +import os +from typing import Optional, Dict, Any + +from ..adapters import TranscriptionAdapterFactory +from ..services import TranscriptionService +from ..core import FFmpegAudioSplitter +from ..utils import AudioProcessingConfig, AudioProcessingError + + +class TranscriptionAPI: + """High-level API for transcription operations""" + + def __init__(self, config: Optional[AudioProcessingConfig] = None): + self.config = config or AudioProcessingConfig() + self.transcription_service = None + self._initialize_service() + + def _initialize_service(self): + """Initialize transcription service with appropriate adapter""" + try: + # Get endpoint URL from config file if available + endpoint_url = self._get_endpoint_url() + + # Create appropriate adapter + transcriber = TranscriptionAdapterFactory.create_adapter( + deployment_mode="auto", + config=self.config, + endpoint_url=endpoint_url + ) + + # Create audio splitter + audio_splitter = FFmpegAudioSplitter() + + # Create transcription service + self.transcription_service = TranscriptionService( + transcriber=transcriber, + audio_splitter=audio_splitter, + speaker_detector=None, # TODO: Add speaker detector when implemented + config=self.config + ) + + except Exception as e: + print(f"⚠️ Failed to initialize transcription service: {e}") + raise AudioProcessingError(f"Service initialization failed: {e}") + + def _get_endpoint_url(self) -> Optional[str]: + """Get Modal endpoint URL from configuration""" + try: + import json + config_file = "endpoint_config.json" + if os.path.exists(config_file): + with open(config_file, 'r') as f: + config = json.load(f) + return config.get("transcribe_audio") + except Exception: + pass + return None + + async def transcribe_audio_file( + self, + audio_file_path: str, + model_size: str = "turbo", + language: Optional[str] = None, + output_format: str = "srt", + 
enable_speaker_diarization: bool = False + ) -> Dict[str, Any]: + """Transcribe audio file using the configured service""" + + if not self.transcription_service: + raise AudioProcessingError("Transcription service not initialized") + + return await self.transcription_service.transcribe_audio_file( + audio_file_path=audio_file_path, + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization + ) + + +# Create global API instance +_api_instance = None + +def get_transcription_api() -> TranscriptionAPI: + """Get global transcription API instance""" + global _api_instance + if _api_instance is None: + _api_instance = TranscriptionAPI() + return _api_instance + +async def transcribe_audio_adaptive_sync( + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False +) -> Dict[str, Any]: + """ + Adaptive transcription function that routes to appropriate backend + """ + api = get_transcription_api() + return await api.transcribe_audio_file( + audio_file_path=audio_file_path, + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization + ) \ No newline at end of file diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000000000000000000000000000000000000..034ff50ddb702e3b33c0c562cf497319a79e5a7f --- /dev/null +++ b/src/app.py @@ -0,0 +1,169 @@ +# FastAPI + Gradio + FastMCP MCP server main entry point + +import modal +from contextlib import asynccontextmanager +from fastapi import FastAPI +from gradio.routes import mount_gradio_app +import os +from dotenv import load_dotenv +import uvicorn +from mcp.server.fastmcp import FastMCP + +# Import modules +from .tools import mcp_tools # Import the module, not get_mcp_server function +from .ui.gradio_ui import create_gradio_interface +from .config.config import is_modal_mode, is_local_mode + +# Always import modal config since this module might be imported in modal context +try: + from .config.modal_config import app, image, volume, cache_dir, secrets + _modal_available = True +except ImportError: + _modal_available = False + +# ==================== Application Creation Function ==================== + +def create_app(): + """Create and return complete Gradio + MCP application""" + + print("🚀 Starting Gradio + FastMCP server") + + # Create FastMCP server with new tools + mcp = FastMCP("Podcast MCP") + + # Register tools using the new service architecture + @mcp.tool(description="Transcribe audio files to text using Whisper model with speaker diarization support") + async def transcribe_audio_file_tool( + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False + ): + return await mcp_tools.transcribe_audio_file( + audio_file_path, model_size, language, output_format, enable_speaker_diarization + ) + + @mcp.tool(description="Download Apple Podcast audio files") + async def download_apple_podcast_tool(url: str): + return await mcp_tools.download_apple_podcast(url) + + @mcp.tool(description="Download XiaoYuZhou podcast audio files") + async def download_xyz_podcast_tool(url: str): + return await mcp_tools.download_xyz_podcast(url) + + @mcp.tool(description="Scan directory for MP3 audio files") + async def get_mp3_files_tool(directory: str): + return await mcp_tools.get_mp3_files(directory) + + @mcp.tool(description="Get basic 
file information") + async def get_file_info_tool(file_path: str): + return await mcp_tools.get_file_info(file_path) + + @mcp.tool(description="Read text file content in segments") + async def read_text_file_segments_tool( + file_path: str, + chunk_size: int = 65536, + start_position: int = 0 + ): + return await mcp_tools.read_text_file_segments(file_path, chunk_size, start_position) + + # Create FastAPI wrapper + fastapi_wrapper = FastAPI( + title="Modal AudioTranscriber MCP", + description="Gradio UI + FastMCP Tool + Modal Integration AudioTranscriber MCP", + version="1.0.0", + lifespan=lambda app: mcp.session_manager.run() + ) + + # Get FastMCP's streamable HTTP app + mcp_app = mcp.streamable_http_app() + + # Mount FastMCP application to /api path + fastapi_wrapper.mount("/api", mcp_app) + + # Create Gradio interface + ui_app = create_gradio_interface() + + # Use Gradio's standard mounting approach + final_app = mount_gradio_app( + app=fastapi_wrapper, + blocks=ui_app, + path="/", + app_kwargs={ + "docs_url": "/docs", + "redoc_url": "/redoc", + } + ) + + print("✅ Server startup completed") + print("🎨 Gradio UI: /") + print("🔧 MCP Streamable HTTP: /api/mcp") + print(f"📝 Server name: {mcp.name}") + + return final_app + +# ==================== Modal Deployment Configuration ==================== + +# Create a separate Modal app for the Gradio interface +if _modal_available: + gradio_mcp_app = modal.App(name="gradio-mcp-ui") + + @gradio_mcp_app.function( + image=image, + cpu=2, # Adequate CPU for UI operations + memory=4096, # 4GB memory for stable UI performance + max_containers=5, # Reduced to control resource usage + min_containers=1, # Keep minimum containers for faster response + scaledown_window=600, # 20 minutes before scaling down + timeout=1800, # 30 minutes timeout to prevent preemption + volumes={cache_dir: volume}, + secrets=secrets, + ) + @modal.concurrent(max_inputs=100) + @modal.asgi_app() + def app_entry(): + """Modal deployment function - create and return complete Gradio + MCP application""" + return create_app() + +# ==================== Main Entry Point ==================== + +def main(): + """Main entry point for all deployment modes""" + + if is_modal_mode(): + print("☁️ Modal mode: Use 'modal deploy src.app::gradio_mcp_app'") + return None + else: + print("🏠 Starting in local mode") + print("💡 GPU functions will be routed to Modal endpoints") + + app = create_app() + return app + +def run_local(): + """Run local server with uvicorn (for direct execution)""" + app = main() + if app: + uvicorn.run( + app, + host="0.0.0.0", + port=8000, + reload=False + ) + +# ==================== Hugging Face Spaces Support ==================== + +# For Hugging Face Spaces, directly create the app +def get_app(): + """Get app instance for HF Spaces""" + if "DEPLOYMENT_MODE" not in os.environ: + os.environ["DEPLOYMENT_MODE"] = "local" + return main() + +# Create app for HF Spaces when imported +if __name__ != "__main__": + app = get_app() + +if __name__ == "__main__": + run_local() \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1816ab891fa7988e69637482bb63e0d1852e2d83 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,5 @@ +""" +Config Module - Configuration management +""" + +__all__ = [] \ No newline at end of file diff --git a/src/config/__pycache__/__init__.cpython-310.pyc b/src/config/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..70370d6cc45630fed58a924efb71a9cc50579e1f Binary files /dev/null and b/src/config/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/config/__pycache__/config.cpython-310.pyc b/src/config/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..890c9d6faf3950d21fd609e953336afba4b2eca9 Binary files /dev/null and b/src/config/__pycache__/config.cpython-310.pyc differ diff --git a/src/config/__pycache__/modal_config.cpython-310.pyc b/src/config/__pycache__/modal_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cc47086458533ae4d608a5caa1c4d8b20ca0476 Binary files /dev/null and b/src/config/__pycache__/modal_config.cpython-310.pyc differ diff --git a/src/config/config.py b/src/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c63cb90015be035b6e05d1f4c387fd3518a39a9 --- /dev/null +++ b/src/config/config.py @@ -0,0 +1,81 @@ +""" +Deployment configuration for Gradio + MCP Server +Supports two deployment modes: +1. Local mode: Gradio runs locally, GPU functions call Modal endpoints +2. Modal mode: Gradio runs on Modal, GPU functions run locally on Modal +""" + +import os +from enum import Enum +from typing import Optional + +class DeploymentMode(Enum): + LOCAL = "local" # Local Gradio + Remote GPU (Modal endpoints) + MODAL = "modal" # Modal Gradio + Local GPU (Modal functions) + +# Get deployment mode from environment variable +DEPLOYMENT_MODE = DeploymentMode(os.getenv("DEPLOYMENT_MODE", "local")) + +# Modal endpoints configuration +MODAL_APP_NAME = "gradio-mcp-server" + +# Endpoint URLs (will be set when deployed) +ENDPOINTS = { + "transcribe_audio": None, # Will be filled with actual endpoint URL +} + +def get_deployment_mode() -> DeploymentMode: + """Get current deployment mode""" + return DEPLOYMENT_MODE + +def is_local_mode() -> bool: + """Check if running in local mode""" + return DEPLOYMENT_MODE == DeploymentMode.LOCAL + +def is_modal_mode() -> bool: + """Check if running in modal mode""" + return DEPLOYMENT_MODE == DeploymentMode.MODAL + +def set_endpoint_url(endpoint_name: str, url: str): + """Set endpoint URL for local mode""" + global ENDPOINTS + ENDPOINTS[endpoint_name] = url + +def get_endpoint_url(endpoint_name: str) -> Optional[str]: + """Get endpoint URL for local mode""" + return ENDPOINTS.get(endpoint_name) + +def get_transcribe_endpoint_url() -> Optional[str]: + """Get transcription endpoint URL""" + return get_endpoint_url("transcribe_audio") + +# Environment-specific cache directory +def get_cache_dir() -> str: + """Get cache directory based on deployment mode""" + if is_modal_mode(): + return "/root/cache" + else: + # Local mode - use user's home directory + home_dir = os.path.expanduser("~") + cache_dir = os.path.join(home_dir, ".gradio_mcp_cache") + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + +# Auto-load endpoint configuration in local mode +if is_local_mode(): + import json + config_file = "endpoint_config.json" + if os.path.exists(config_file): + try: + with open(config_file, 'r') as f: + config = json.load(f) + for endpoint_name, url in config.items(): + set_endpoint_url(endpoint_name, url) + print(f"✅ Loaded endpoint configuration from {config_file}") + except Exception as e: + print(f"⚠️ Failed to load endpoint configuration: {e}") + else: + print(f"⚠️ No endpoint configuration found. 
Run 'python deploy_endpoints.py deploy' first.") + +print(f"🚀 Deployment mode: {DEPLOYMENT_MODE.value}") +print(f"📁 Cache directory: {get_cache_dir()}") \ No newline at end of file diff --git a/src/config/modal_config.py b/src/config/modal_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5d404daa089bac8e9e6b0ce21bab834c34e50889 --- /dev/null +++ b/src/config/modal_config.py @@ -0,0 +1,210 @@ +import modal +import os + +# Create Modal application +app = modal.App(name="gradio-mcp-server") + +# Try to get Hugging Face token from Modal secrets (required for speaker diarization) +try: + hf_secret = modal.Secret.from_name("huggingface-secret") + print("✅ Found Hugging Face secret configuration") +except Exception: + hf_secret = None + print("⚠️ Hugging Face secret not found, speaker diarization will be disabled") + +# Create mounted volume +volume = modal.Volume.from_name("cache-volume", create_if_missing=True) +cache_dir = "/root/cache" + +# Model preloading function +def download_models() -> None: + """Download and cache Whisper and speaker diarization models""" + import whisper + import os + from pathlib import Path + + # Create model cache directory + model_cache_dir = Path("/model") + model_cache_dir.mkdir(exist_ok=True) + + print("📥 Downloading Whisper turbo model...") + # Download and cache Whisper turbo model + whisper_model = whisper.load_model("turbo", download_root="/model") + print("✅ Whisper turbo model downloaded and cached") + + # Download speaker diarization models if HF token is available + if os.environ.get("HF_TOKEN"): + try: + print("📥 Downloading speaker diarization models...") + from pyannote.audio import Pipeline, Model + from pyannote.audio.core.inference import Inference + import torch + + # Set proper cache directory for pyannote + os.environ["PYANNOTE_CACHE"] = "/model/speaker-diarization" + + # Download and cache speaker diarization pipeline + # This will automatically cache to the PYANNOTE_CACHE directory + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=os.environ["HF_TOKEN"], + cache_dir="/model/speaker-diarization" + ) + + # Preload speaker embedding model for speaker identification + print("📥 Downloading speaker embedding model...") + embedding_model = Model.from_pretrained( + "pyannote/embedding", + use_auth_token=os.environ["HF_TOKEN"], + cache_dir="/model/speaker-embedding" + ) + + # Set device for models + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + embedding_model.to(device) + embedding_model.eval() + + # Create inference object for embedding extraction + inference = Inference(embedding_model, window="whole") + + # Verify the pipeline works + print("🧪 Testing speaker diarization pipeline...") + + # Create a simple marker file to indicate successful download + import json + speaker_dir = Path("/model/speaker-diarization") + speaker_dir.mkdir(exist_ok=True, parents=True) + + embedding_dir = Path("/model/speaker-embedding") + embedding_dir.mkdir(exist_ok=True, parents=True) + + config = { + "model_name": "pyannote/speaker-diarization-3.1", + "embedding_model_name": "pyannote/embedding", + "cached_at": str(speaker_dir), + "embedding_cached_at": str(embedding_dir), + "cache_complete": True, + "embedding_cache_complete": True, + "pyannote_cache_env": "/model/speaker-diarization", + "device": str(device) + } + with open(speaker_dir / "download_complete.json", "w") as f: + json.dump(config, f) + + print("✅ Speaker diarization and embedding models downloaded and 
cached") + except Exception as e: + print(f"⚠️ Failed to download speaker diarization models: {e}") + else: + print("⚠️ No HF_TOKEN found, skipping speaker diarization model download") + +# Create image environment with model preloading +image = modal.Image.debian_slim(python_version="3.11").apt_install( + # Basic tools + "ffmpeg", + "wget", + "curl", + "unzip", + "gnupg2", + "git", # Required by Whisper + # Chrome dependencies + "libglib2.0-0", + "libnss3", + "libatk-bridge2.0-0", + "libdrm2", + "libxkbcommon0", + "libxcomposite1", + "libxdamage1", + "libxrandr2", + "libgbm1", + "libxss1", + "libasound2" +).run_commands( + # Download and install Chrome directly (faster method) + "wget -q https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb", + "apt-get install -y ./google-chrome-stable_current_amd64.deb || apt-get install -y -f", + "rm google-chrome-stable_current_amd64.deb" +).pip_install( + # Web frameworks and basic libraries + "gradio>=5.31.0", + "fastapi", + "pydantic", + "python-dotenv", + # MCP related + "mcp[cli]", + "fastmcp>=2.7.0", + "starlette", + # Network and parsing + "beautifulsoup4", + "selenium", + "requests", + # Whisper and audio processing related + "git+https://github.com/openai/whisper.git", + "ffmpeg-python", + "torchaudio==2.1.0", + "numpy<2", + # Audio processing dependencies + "librosa", + "soundfile", + # Other Whisper ecosystem dependencies + "dacite", + "jiwer", + "pandas", + "loguru==0.6.0", + # GraphQL client (if needed) + "gql[all]~=3.0.0a5", + # Speaker diarization related dependencies + "pyannote.audio==3.1.0", + # System monitoring + "psutil", +).run_function( + download_models, + secrets=[hf_secret] if hf_secret else [] +) + +# Update file paths to reflect new structure +image = image.add_local_dir("../src", remote_path="/root/src") +secrets = [hf_secret] if hf_secret else [] + +# ==================== Modal Endpoints Configuration ==================== + +@app.function( + image=image, + volumes={cache_dir: volume}, + cpu=4, # Increased CPU for better performance + memory=8192, # 8GB memory for stable transcription + gpu="A10G", + timeout=1800, # 30 minutes timeout for speaker diarization support + scaledown_window=40, # 15 minutes before scaling down + secrets=secrets, +) +@modal.fastapi_endpoint(method="POST", label="transcribe-audio-chunk-endpoint") +def transcribe_audio_chunk_endpoint(request_data: dict): + """FastAPI endpoint for transcribing a single audio chunk (for distributed processing)""" + import sys + sys.path.append('/root') + + from src.services.modal_transcription_service import ModalTranscriptionService + + modal_service = ModalTranscriptionService(cache_dir="/root/cache", use_direct_modal_calls=True) + return modal_service.process_chunk_request(request_data) + +@app.function( + image=image, + cpu=2, # Increased CPU for better health check performance + memory=2048, # 2GB memory for stability + timeout=300, # 5 minutes timeout for health checks + scaledown_window=600, # 10 minutes before scaling down + secrets=secrets, +) +@modal.fastapi_endpoint(method="GET", label="health-check-endpoint") +def health_check_endpoint(): + """Health check endpoint to verify service status""" + import sys + sys.path.append('/root') + + from src.services.health_service import HealthService + + health_service = HealthService() + return health_service.get_health_status() + + diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b23965522701ff643382235025880cd5d62c43d0 
--- /dev/null
+++ b/src/core/__init__.py
@@ -0,0 +1,29 @@
+"""
+Core components for application and audio processing
+"""
+
+# Original core components
+from .config import AppConfig, app_config, get_deployment_mode, is_local_mode, is_modal_mode
+from .exceptions import AppError, ConfigError, ValidationError
+
+# Audio processing core components
+from .audio_splitter import FFmpegAudioSplitter
+from .whisper_transcriber import WhisperTranscriber
+from .speaker_diarization import PyannoteSpeikerDetector
+
+__all__ = [
+    # Original core
+    "AppConfig",
+    "app_config",
+    "get_deployment_mode",
+    "is_local_mode",
+    "is_modal_mode",
+    "AppError",
+    "ConfigError",
+    "ValidationError",
+
+    # Audio processing core
+    "FFmpegAudioSplitter",
+    "WhisperTranscriber",
+    "PyannoteSpeikerDetector"
+]
\ No newline at end of file
diff --git a/src/core/__pycache__/__init__.cpython-310.pyc b/src/core/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdb317ae1bd17aedc43446392194361190025fd6
Binary files /dev/null and b/src/core/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/core/__pycache__/audio_splitter.cpython-310.pyc b/src/core/__pycache__/audio_splitter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5df9911e84ea95344805787c6a4639116a6d969
Binary files /dev/null and b/src/core/__pycache__/audio_splitter.cpython-310.pyc differ
diff --git a/src/core/__pycache__/config.cpython-310.pyc b/src/core/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b211224af35fe5753530ac786e6465989b4046f7
Binary files /dev/null and b/src/core/__pycache__/config.cpython-310.pyc differ
diff --git a/src/core/__pycache__/exceptions.cpython-310.pyc b/src/core/__pycache__/exceptions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2adfff7af112c5089fac16ffd2804f5be5d169f2
Binary files /dev/null and b/src/core/__pycache__/exceptions.cpython-310.pyc differ
diff --git a/src/core/__pycache__/speaker_diarization.cpython-310.pyc b/src/core/__pycache__/speaker_diarization.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..909e451abe401910a97a903ada9b75f833e32b32
Binary files /dev/null and b/src/core/__pycache__/speaker_diarization.cpython-310.pyc differ
diff --git a/src/core/__pycache__/whisper_transcriber.cpython-310.pyc b/src/core/__pycache__/whisper_transcriber.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1d9410090a496396f0b8edfdb269b3e30aca0ef
Binary files /dev/null and b/src/core/__pycache__/whisper_transcriber.cpython-310.pyc differ
diff --git a/src/core/audio_splitter.py b/src/core/audio_splitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb029ac0eebdf16486e846ae891fb7f2f701ba61
--- /dev/null
+++ b/src/core/audio_splitter.py
@@ -0,0 +1,90 @@
+"""
+Audio splitter implementation using FFmpeg
+"""
+
+import re
+from typing import Iterator
+import ffmpeg
+
+from ..interfaces.audio_splitter import IAudioSplitter, AudioSegment
+from ..utils.errors import AudioSplittingError
+
+
+class FFmpegAudioSplitter(IAudioSplitter):
+    """Audio splitter using FFmpeg's silence detection"""
+
+    def split_audio(
+        self,
+        audio_path: str,
+        min_segment_length: float = 30.0,
+        min_silence_length: float = 1.0
+    ) -> Iterator[AudioSegment]:
+        """Split audio by silence detection"""
+
+        try:
+            silence_end_re = re.compile(
+                r" silence_end: (?P<end>[0-9]+(\.?[0-9]*)) \| silence_duration: (?P<dur>[0-9]+(\.?[0-9]*))"
+            )
+
+            # Get audio duration
+            duration = self.get_audio_duration(audio_path)
+
+            # Use silence detection filter
+            reader = (
+                ffmpeg.input(str(audio_path))
+                .filter("silencedetect", n="-10dB", d=min_silence_length)
+                .output("pipe:", format="null")
+                .run_async(pipe_stderr=True)
+            )
+
+            cur_start = 0.0
+            segment_count = 0
+
+            while True:
+                line = reader.stderr.readline().decode("utf-8")
+                if not line:
+                    break
+
+                match = silence_end_re.search(line)
+                if match:
+                    silence_end, silence_dur = match.group("end"), match.group("dur")
+                    split_at = float(silence_end) - (float(silence_dur) / 2)
+
+                    if (split_at - cur_start) < min_segment_length:
+                        continue
+
+                    yield AudioSegment(
+                        start=cur_start,
+                        end=split_at,
+                        duration=split_at - cur_start
+                    )
+                    cur_start = split_at
+                    segment_count += 1
+
+            # Handle the last segment
+            if duration > cur_start:
+                yield AudioSegment(
+                    start=cur_start,
+                    end=duration,
+                    duration=duration - cur_start
+                )
+                segment_count += 1
+
+            print(f"Audio split into {segment_count} segments")
+
+        except Exception as e:
+            raise AudioSplittingError(
+                f"Failed to split audio: {str(e)}",
+                audio_file=audio_path
+            )
+
+    def get_audio_duration(self, audio_path: str) -> float:
+        """Get total duration of audio file"""
+        try:
+            metadata = ffmpeg.probe(audio_path)
+            return float(metadata["format"]["duration"])
+        except Exception as e:
+            raise AudioSplittingError(
+                f"Failed to get audio duration: {str(e)}",
+                audio_file=audio_path
+            )
\ No newline at end of file
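A minimal usage sketch for the splitter above (illustrative only, not part of the patch; it assumes the src package is importable and that an episode.mp3 exists locally):

    from src.core.audio_splitter import FFmpegAudioSplitter

    splitter = FFmpegAudioSplitter()
    # split_audio() is a generator: segments are yielded as ffmpeg's silencedetect
    # filter reports silence boundaries, so long recordings can be consumed lazily.
    for segment in splitter.split_audio("episode.mp3", min_segment_length=30.0, min_silence_length=1.0):
        print(f"{segment.start:.1f}s -> {segment.end:.1f}s ({segment.duration:.1f}s)")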
diff --git a/src/core/config.py b/src/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb2162ffd35c0966634b155e293552db5867ea96
--- /dev/null
+++ b/src/core/config.py
@@ -0,0 +1,150 @@
+"""
+Configuration management for PodcastMCP
+"""
+
+import os
+import json
+from enum import Enum
+from typing import Optional, Dict, Any
+from pathlib import Path
+
+
+class DeploymentMode(Enum):
+    """Deployment mode enumeration"""
+    LOCAL = "local"      # Local Gradio + Modal GPU endpoints
+    MODAL = "modal"      # Runs entirely on the Modal platform
+    HF_SPACES = "hf"     # Hugging Face Spaces deployment
+
+
+class AppConfig:
+    """Application configuration manager"""
+
+    def __init__(self):
+        self._deployment_mode = self._detect_deployment_mode()
+        self._cache_dir = self._get_cache_directory()
+        self._endpoints = self._load_endpoints()
+
+    @property
+    def deployment_mode(self) -> DeploymentMode:
+        """Get the current deployment mode"""
+        return self._deployment_mode
+
+    @property
+    def cache_dir(self) -> str:
+        """Get the cache directory"""
+        return self._cache_dir
+
+    @property
+    def is_local_mode(self) -> bool:
+        """Whether running in local mode"""
+        return self._deployment_mode == DeploymentMode.LOCAL
+
+    @property
+    def is_modal_mode(self) -> bool:
+        """Whether running in Modal mode"""
+        return self._deployment_mode == DeploymentMode.MODAL
+
+    @property
+    def is_hf_spaces_mode(self) -> bool:
+        """Whether running in HF Spaces mode"""
+        return self._deployment_mode == DeploymentMode.HF_SPACES
+
+    def get_transcribe_endpoint_url(self) -> Optional[str]:
+        """Get the transcription endpoint URL"""
+        return self._endpoints.get("transcribe_audio")
+
+    def set_endpoint_url(self, service: str, url: str):
+        """Set an endpoint URL"""
+        self._endpoints[service] = url
+        self._save_endpoints()
+
+    def _detect_deployment_mode(self) -> DeploymentMode:
+        """Auto-detect the deployment mode"""
+        # Check the environment variable
+        mode = os.environ.get("DEPLOYMENT_MODE", "").lower()
+        if mode == "modal":
+            return DeploymentMode.MODAL
+        elif mode == "hf":
+            return DeploymentMode.HF_SPACES
+
+        # Check for the HF Spaces environment
+        if os.environ.get("SPACE_ID") or os.environ.get("SPACES_ZERO_GPU"):
+            return DeploymentMode.HF_SPACES
+
+        # Check for the Modal environment
+        if os.environ.get("MODAL_TASK_ID") or os.environ.get("MODAL_IS_INSIDE_CONTAINER"):
+            return DeploymentMode.MODAL
+
+        # Default to local mode
+        return DeploymentMode.LOCAL
+
+    def _get_cache_directory(self) -> str:
+        """Get the cache directory path"""
+        if self.is_modal_mode:
+            return "/root/cache"
+        else:
+            # Local mode and HF Spaces use the user's cache directory
+            home_dir = Path.home()
+            cache_dir = home_dir / ".gradio_mcp_cache"
+            cache_dir.mkdir(exist_ok=True)
+            return str(cache_dir)
+
+    def _load_endpoints(self) -> Dict[str, str]:
+        """Load endpoint configuration"""
+        config_file = Path("endpoint_config.json")
+        if config_file.exists():
+            try:
+                with open(config_file, 'r') as f:
+                    endpoints = json.load(f)
+                print(f"✅ Loaded endpoint configuration from {config_file}")
+                return endpoints
+            except Exception as e:
+                print(f"⚠️ Failed to load endpoint config: {e}")
+        else:
+            print("⚠️ No endpoint configuration found. Run deployment first.")
+
+        return {}
+
+    def _save_endpoints(self):
+        """Save endpoint configuration"""
+        config_file = Path("endpoint_config.json")
+        try:
+            with open(config_file, 'w') as f:
+                json.dump(self._endpoints, f, indent=2)
+            print(f"💾 Endpoint configuration saved to {config_file}")
+        except Exception as e:
+            print(f"⚠️ Failed to save endpoint config: {e}")
+
+
+# Global configuration instance
+app_config = AppConfig()
+
+# Backward-compatible function interface
+def get_deployment_mode() -> str:
+    """Get the deployment mode as a string"""
+    return app_config.deployment_mode.value
+
+def is_local_mode() -> bool:
+    """Whether running in local mode"""
+    return app_config.is_local_mode
+
+def is_modal_mode() -> bool:
+    """Whether running in Modal mode"""
+    return app_config.is_modal_mode
+
+def get_cache_dir() -> str:
+    """Get the cache directory"""
+    return app_config.cache_dir
+
+def get_transcribe_endpoint_url() -> Optional[str]:
+    """Get the transcription endpoint URL"""
+    return app_config.get_transcribe_endpoint_url()
+
+def set_endpoint_url(service: str, url: str):
+    """Set an endpoint URL"""
+    app_config.set_endpoint_url(service, url)
+
+
+# Print configuration info
+print(f"🚀 Deployment mode: {app_config.deployment_mode.value}")
+print(f"📁 Cache directory: {app_config.cache_dir}")
\ No newline at end of file
diff --git a/src/core/exceptions.py b/src/core/exceptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8a7250b6bda83c052ac1c9b9722a72191ea7210
--- /dev/null
+++ b/src/core/exceptions.py
@@ -0,0 +1,43 @@
+"""
+Custom exceptions for PodcastMCP
+"""
+
+
+class PodcastMCPError(Exception):
+    """Base exception class for PodcastMCP"""
+    pass
+
+
+class AppError(PodcastMCPError):
+    """Application error"""
+    pass
+
+
+class ConfigError(PodcastMCPError):
+    """Configuration-related error"""
+    pass
+
+
+class ValidationError(PodcastMCPError):
+    """Validation-related error"""
+    pass
+
+
+class TranscriptionError(PodcastMCPError):
+    """Transcription-related error"""
+    pass
+
+
+class DeploymentError(PodcastMCPError):
+    """Deployment-related error"""
+    pass
+
+
+class FileNotFoundError(PodcastMCPError):
+    """File-not-found error"""
+    pass
+
+
+class EndpointError(PodcastMCPError):
+    """Endpoint-related error"""
+    pass
\ No newline at end of file
diff --git a/src/core/speaker_diarization.py b/src/core/speaker_diarization.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb624ccf0f8b2c6f5a05ad11f860bbfd9ae9ff6
--- /dev/null
+++ b/src/core/speaker_diarization.py
@@ -0,0 +1,126 @@
+"""
+Speaker diarization implementation using pyannote.audio
+"""
+
+import os
+import torch
+from typing import Optional, List, Dict, Any
+
+from ..interfaces.speaker_detector import ISpeakerDetector
+from ..utils.config import AudioProcessingConfig
+from ..utils.errors import SpeakerDiarizationError, ModelLoadError
+
+
+class PyannoteSpeikerDetector(ISpeakerDetector):
+    """Speaker diarization using pyannote.audio"""
+
+    def __init__(self, 
config: Optional[AudioProcessingConfig] = None): + self.config = config or AudioProcessingConfig() + self.device = self._setup_device() + self.pipeline = None + self.auth_token = os.environ.get(self.config.hf_token_env_var) + + if not self.auth_token: + print("⚠️ No Hugging Face token found. Speaker diarization will be disabled.") + + def _setup_device(self) -> torch.device: + """Setup and return the best available device""" + if torch.cuda.is_available(): + return torch.device("cuda") + else: + return torch.device("cpu") + + async def detect_speakers( + self, + audio_file_path: str, + num_speakers: Optional[int] = None, + min_speakers: int = 1, + max_speakers: int = 10 + ) -> Dict[str, Any]: + """Detect speakers in audio file""" + + if not self.auth_token: + raise SpeakerDiarizationError( + "Speaker diarization requires Hugging Face token", + audio_file=audio_file_path + ) + + try: + # Load pipeline if not already loaded + if self.pipeline is None: + self.pipeline = self._load_pipeline() + + # Perform diarization + diarization = self.pipeline(audio_file_path) + + # Convert to our format + speakers = {} + segments = [] + + for turn, _, speaker in diarization.itertracks(yield_label=True): + speaker_id = f"SPEAKER_{speaker.split('_')[-1].zfill(2)}" + segments.append({ + "start": turn.start, + "end": turn.end, + "speaker": speaker_id + }) + + if speaker_id not in speakers: + speakers[speaker_id] = { + "id": speaker_id, + "total_time": 0.0, + "segments": [] + } + + speakers[speaker_id]["total_time"] += turn.end - turn.start + speakers[speaker_id]["segments"].append({ + "start": turn.start, + "end": turn.end + }) + + return { + "speaker_count": len(speakers), + "speakers": speakers, + "segments": segments, + "audio_file": audio_file_path + } + + except Exception as e: + raise SpeakerDiarizationError( + f"Speaker detection failed: {str(e)}", + audio_file=audio_file_path + ) + + def _load_pipeline(self): + """Load pyannote speaker diarization pipeline""" + try: + # Suppress warnings + import warnings + warnings.filterwarnings("ignore", category=UserWarning, module="pyannote") + warnings.filterwarnings("ignore", category=UserWarning, module="pytorch_lightning") + warnings.filterwarnings("ignore", category=FutureWarning, module="pytorch_lightning") + + from pyannote.audio import Pipeline + + print("📥 Loading speaker diarization pipeline...") + pipeline = Pipeline.from_pretrained( + self.config.speaker_diarization_model, + use_auth_token=self.auth_token + ) + pipeline.to(self.device) + + return pipeline + + except Exception as e: + raise ModelLoadError( + f"Failed to load speaker diarization pipeline: {str(e)}", + model_name=self.config.speaker_diarization_model + ) + + def get_supported_models(self) -> List[str]: + """Get list of supported speaker diarization models""" + return [self.config.speaker_diarization_model] + + def is_available(self) -> bool: + """Check if speaker diarization is available""" + return self.auth_token is not None \ No newline at end of file diff --git a/src/core/whisper_transcriber.py b/src/core/whisper_transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..dd61d7a25a70a4ecfa5357e49a08a2a9bed2e799 --- /dev/null +++ b/src/core/whisper_transcriber.py @@ -0,0 +1,113 @@ +""" +Local Whisper transcriber implementation +""" + +import whisper +import torch +import pathlib +import time +from typing import Optional, List + +from ..interfaces.transcriber import ITranscriber, TranscriptionResult, TranscriptionSegment +from ..utils.config import 
AudioProcessingConfig +from ..utils.errors import TranscriptionError, ModelLoadError + + +class WhisperTranscriber(ITranscriber): + """Local Whisper transcriber implementation""" + + def __init__(self, config: Optional[AudioProcessingConfig] = None): + self.config = config or AudioProcessingConfig() + self.model_cache = {} + self.device = self._setup_device() + + def _setup_device(self) -> str: + """Setup and return the best available device""" + if torch.cuda.is_available(): + return "cuda" + else: + return "cpu" + + async def transcribe( + self, + audio_file_path: str, + model_size: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> TranscriptionResult: + """Transcribe audio using local Whisper model""" + + try: + # Validate audio file + audio_path = pathlib.Path(audio_file_path) + if not audio_path.exists(): + raise TranscriptionError( + f"Audio file not found: {audio_file_path}", + audio_file=audio_file_path + ) + + # Load model + model = self._load_model(model_size) + + # Transcribe + start_time = time.time() + result = model.transcribe( + str(audio_path), + language=language, + verbose=False + ) + processing_time = time.time() - start_time + + # Convert to our format + segments = [] + for seg in result.get("segments", []): + segments.append(TranscriptionSegment( + start=seg["start"], + end=seg["end"], + text=seg["text"].strip(), + confidence=seg.get("avg_logprob") + )) + + return TranscriptionResult( + text=result.get("text", "").strip(), + segments=segments, + language=result.get("language", "unknown"), + model_used=model_size, + audio_duration=result.get("duration", 0), + processing_time=processing_time, + speaker_diarization_enabled=enable_speaker_diarization, + global_speaker_count=0, + error_message=None + ) + + except Exception as e: + raise TranscriptionError( + f"Whisper transcription failed: {str(e)}", + model=model_size, + audio_file=audio_file_path + ) + + def _load_model(self, model_size: str): + """Load Whisper model with caching""" + if model_size not in self.model_cache: + try: + print(f"📥 Loading Whisper model: {model_size}") + self.model_cache[model_size] = whisper.load_model( + model_size, + device=self.device + ) + except Exception as e: + raise ModelLoadError( + f"Failed to load model {model_size}: {str(e)}", + model_name=model_size + ) + + return self.model_cache[model_size] + + def get_supported_models(self) -> List[str]: + """Get list of supported model sizes""" + return list(self.config.whisper_models.keys()) + + def get_supported_languages(self) -> List[str]: + """Get list of supported language codes""" + return ["en", "zh", "ja", "ko", "es", "fr", "de", "ru", "auto"] \ No newline at end of file diff --git a/src/deployment/__init__.py b/src/deployment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef00f5af0fc212a1b7c3cda43a6c9e827a8ada83 --- /dev/null +++ b/src/deployment/__init__.py @@ -0,0 +1,8 @@ +""" +Deployment management for audio processing services +""" + +from .modal_deployer import ModalDeployer +from .endpoint_manager import EndpointManager + +__all__ = ["ModalDeployer", "EndpointManager"] \ No newline at end of file diff --git a/src/deployment/deployment_manager.py b/src/deployment/deployment_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a8aac28a245f0331a53174294306ec80b2be164c --- /dev/null +++ b/src/deployment/deployment_manager.py @@ -0,0 +1,153 @@ +""" +Simplified deployment manager +This replaces the complex deploy_endpoints.py 
with a cleaner interface +""" + +import argparse +import sys +from typing import Optional + +from ..audio_processing.deployment import ModalDeployer, EndpointManager +from ..audio_processing.utils.config import AudioProcessingConfig +from ..audio_processing.utils.errors import DeploymentError + + +class DeploymentManager: + """Simplified deployment manager for audio processing services""" + + def __init__(self): + self.config = AudioProcessingConfig() + self.modal_deployer = ModalDeployer(self.config) + self.endpoint_manager = EndpointManager() + + def deploy(self) -> bool: + """Deploy transcription service""" + try: + print("🚀 Starting deployment process...") + endpoint_url = self.modal_deployer.deploy_transcription_service() + + if endpoint_url: + print(f"✅ Deployment successful!") + print(f"🌐 Endpoint URL: {endpoint_url}") + return True + else: + print("❌ Deployment failed: Could not get endpoint URL") + return False + + except DeploymentError as e: + print(f"❌ Deployment failed: {e.message}") + if e.details: + print(f"📋 Details: {e.details}") + return False + except Exception as e: + print(f"❌ Unexpected deployment error: {str(e)}") + return False + + def status(self) -> bool: + """Check deployment status""" + print("🔍 Checking deployment status...") + + endpoints = self.endpoint_manager.list_endpoints() + if not endpoints: + print("❌ No endpoints configured") + return False + + print(f"📋 Configured endpoints:") + for name, url in endpoints.items(): + print(f" • {name}: {url}") + + # Check health + return self.modal_deployer.check_deployment_status() + + def undeploy(self): + """Remove deployment configuration""" + print("🗑️ Removing deployment configuration...") + self.modal_deployer.undeploy_transcription_service() + + def list_endpoints(self): + """List all configured endpoints""" + endpoints = self.endpoint_manager.list_endpoints() + + if not endpoints: + print("📋 No endpoints configured") + return + + print("📋 Configured endpoints:") + for name, url in endpoints.items(): + health_status = "✅ Healthy" if self.endpoint_manager.check_endpoint_health(name) else "❌ Unhealthy" + print(f" • {name}: {url} ({health_status})") + + def set_endpoint(self, name: str, url: str): + """Manually set an endpoint""" + self.endpoint_manager.set_endpoint(name, url) + + def remove_endpoint(self, name: str): + """Remove an endpoint""" + self.endpoint_manager.remove_endpoint(name) + + +def main(): + """Command line interface for deployment manager""" + parser = argparse.ArgumentParser(description="Audio Processing Deployment Manager") + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # Deploy command + subparsers.add_parser("deploy", help="Deploy transcription service to Modal") + + # Status command + subparsers.add_parser("status", help="Check deployment status") + + # Undeploy command + subparsers.add_parser("undeploy", help="Remove deployment configuration") + + # List endpoints command + subparsers.add_parser("list", help="List all configured endpoints") + + # Set endpoint command + set_parser = subparsers.add_parser("set", help="Set endpoint URL manually") + set_parser.add_argument("name", help="Endpoint name") + set_parser.add_argument("url", help="Endpoint URL") + + # Remove endpoint command + remove_parser = subparsers.add_parser("remove", help="Remove endpoint") + remove_parser.add_argument("name", help="Endpoint name") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + manager = DeploymentManager() + + try: + if 
args.command == "deploy": + success = manager.deploy() + sys.exit(0 if success else 1) + + elif args.command == "status": + success = manager.status() + sys.exit(0 if success else 1) + + elif args.command == "undeploy": + manager.undeploy() + + elif args.command == "list": + manager.list_endpoints() + + elif args.command == "set": + manager.set_endpoint(args.name, args.url) + + elif args.command == "remove": + manager.remove_endpoint(args.name) + + except KeyboardInterrupt: + print("\n⚠️ Operation cancelled by user") + sys.exit(1) + except Exception as e: + print(f"❌ Error: {str(e)}") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/deployment/endpoint_manager.py b/src/deployment/endpoint_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..247ab24a64da388b6a7ede2c3cc5883522eeafeb --- /dev/null +++ b/src/deployment/endpoint_manager.py @@ -0,0 +1,76 @@ +""" +Endpoint manager for handling Modal endpoints +""" + +import json +import os +from typing import Dict, Optional + +from ..utils.errors import ConfigurationError + + +class EndpointManager: + """Manager for Modal endpoint configuration""" + + def __init__(self, config_file: str = "endpoint_config.json"): + self.config_file = config_file + self._endpoints = self._load_endpoints() + + def _load_endpoints(self) -> Dict[str, str]: + """Load endpoints from configuration file""" + if not os.path.exists(self.config_file): + return {} + + try: + with open(self.config_file, 'r') as f: + return json.load(f) + except Exception as e: + print(f"⚠️ Failed to load endpoint configuration: {e}") + return {} + + def save_endpoints(self): + """Save endpoints to configuration file""" + try: + with open(self.config_file, 'w') as f: + json.dump(self._endpoints, f, indent=2) + print(f"💾 Endpoint configuration saved to {self.config_file}") + except Exception as e: + raise ConfigurationError(f"Failed to save endpoint configuration: {e}") + + def set_endpoint(self, name: str, url: str): + """Set endpoint URL""" + self._endpoints[name] = url + self.save_endpoints() + print(f"✅ Endpoint '{name}' set to: {url}") + + def get_endpoint(self, name: str) -> Optional[str]: + """Get endpoint URL""" + return self._endpoints.get(name) + + def remove_endpoint(self, name: str): + """Remove endpoint""" + if name in self._endpoints: + del self._endpoints[name] + self.save_endpoints() + print(f"🗑️ Endpoint '{name}' removed") + else: + print(f"⚠️ Endpoint '{name}' not found") + + def list_endpoints(self) -> Dict[str, str]: + """List all endpoints""" + return self._endpoints.copy() + + def check_endpoint_health(self, name: str) -> bool: + """Check if endpoint is healthy""" + url = self.get_endpoint(name) + if not url: + return False + + try: + import requests + # Try a simple health check (adjust based on your endpoint) + health_url = url.replace("/transcribe", "/health") + response = requests.get(health_url, timeout=10) + return response.status_code == 200 + except Exception: + return False \ No newline at end of file diff --git a/src/deployment/modal_deployer.py b/src/deployment/modal_deployer.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a0b95ab9b0cc606c8dc7ec1a2e3606b29cad6e --- /dev/null +++ b/src/deployment/modal_deployer.py @@ -0,0 +1,97 @@ +""" +Modal deployer for deploying transcription services +""" + +import subprocess +from typing import Optional + +from ..utils.config import AudioProcessingConfig +from ..utils.errors import DeploymentError +from .endpoint_manager 
import EndpointManager + + +class ModalDeployer: + """Deployer for Modal transcription services""" + + def __init__(self, config: Optional[AudioProcessingConfig] = None): + self.config = config or AudioProcessingConfig() + self.endpoint_manager = EndpointManager() + + def deploy_transcription_service(self) -> Optional[str]: + """Deploy transcription service to Modal""" + + print("🚀 Deploying transcription service to Modal...") + + try: + # Deploy the Modal app + print("🚀 Running modal deploy command...") + result = subprocess.run( + ["modal", "deploy", "modal_config.py"], + capture_output=True, + text=True + ) + + if result.returncode == 0: + # Extract or construct endpoint URL + endpoint_url = self._extract_endpoint_url(result.stdout) + + if endpoint_url: + # Save endpoint configuration + self.endpoint_manager.set_endpoint("transcribe_audio", endpoint_url) + print(f"✅ Transcription service deployed: {endpoint_url}") + return endpoint_url + else: + print("⚠️ Could not extract endpoint URL from deployment output") + return None + else: + raise DeploymentError( + f"Modal deployment failed: {result.stderr}", + service="transcription" + ) + + except FileNotFoundError: + raise DeploymentError( + "Modal CLI not found. Please install Modal: pip install modal", + service="transcription" + ) + except Exception as e: + raise DeploymentError( + f"Failed to deploy transcription service: {str(e)}", + service="transcription" + ) + + def _extract_endpoint_url(self, output: str) -> Optional[str]: + """Extract endpoint URL from deployment output""" + + # Look for URL in output + for line in output.split('\n'): + if 'https://' in line and 'modal.run' in line: + # Extract URL from line + parts = line.split() + for part in parts: + if part.startswith('https://') and 'modal.run' in part: + return part + + # Fallback to constructed URL + return f"https://{self.config.modal_app_name}--transcribe-audio-endpoint.modal.run" + + def check_deployment_status(self) -> bool: + """Check if transcription service is deployed and healthy""" + + endpoint_url = self.endpoint_manager.get_endpoint("transcribe_audio") + if not endpoint_url: + print("❌ No transcription endpoint configured") + return False + + if self.endpoint_manager.check_endpoint_health("transcribe_audio"): + print(f"✅ Transcription service is healthy: {endpoint_url}") + return True + else: + print(f"❌ Transcription service is not responding: {endpoint_url}") + return False + + def undeploy_transcription_service(self): + """Remove transcription service endpoint""" + self.endpoint_manager.remove_endpoint("transcribe_audio") + print("🗑️ Transcription service endpoint removed from configuration") + print("💡 Note: The actual Modal deployment may still be active. 
Use 'modal app stop' to stop it.") \ No newline at end of file diff --git a/src/interfaces/__init__.py b/src/interfaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ecaf7ba07a89b572fc51d4a43b6c0e45eb01e01 --- /dev/null +++ b/src/interfaces/__init__.py @@ -0,0 +1,38 @@ +""" +Interfaces for audio processing components +""" + +from .transcriber import ITranscriber +from .speaker_detector import ISpeakerDetector +from .audio_splitter import IAudioSplitter +from .audio_processor import IAudioProcessor, AudioSegment +from .podcast_downloader import IPodcastDownloader, PodcastInfo, DownloadResult, PodcastPlatform +from .speaker_manager import ( + ISpeakerEmbeddingManager, + ISpeakerIdentificationService, + SpeakerEmbedding, + SpeakerSegment +) + +__all__ = [ + # Core interfaces + "ITranscriber", + "ISpeakerDetector", + "IAudioSplitter", + + # New service interfaces + "IAudioProcessor", + "IPodcastDownloader", + "ISpeakerEmbeddingManager", + "ISpeakerIdentificationService", + + # Data classes + "AudioSegment", + "PodcastInfo", + "DownloadResult", + "SpeakerEmbedding", + "SpeakerSegment", + + # Enums + "PodcastPlatform" +] \ No newline at end of file diff --git a/src/interfaces/__pycache__/__init__.cpython-310.pyc b/src/interfaces/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd190ed9e7d3b774530985b210bc0ef1be694bec Binary files /dev/null and b/src/interfaces/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/audio_processor.cpython-310.pyc b/src/interfaces/__pycache__/audio_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20dcf790aded637fea60da8d4714f7048bd858c9 Binary files /dev/null and b/src/interfaces/__pycache__/audio_processor.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/audio_splitter.cpython-310.pyc b/src/interfaces/__pycache__/audio_splitter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de8cc50bc010ac35dd4f3cddf8eedfc28f3e9886 Binary files /dev/null and b/src/interfaces/__pycache__/audio_splitter.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/podcast_downloader.cpython-310.pyc b/src/interfaces/__pycache__/podcast_downloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83b82b4efca88c0633650cbc6fd3074e41742fc0 Binary files /dev/null and b/src/interfaces/__pycache__/podcast_downloader.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/speaker_detector.cpython-310.pyc b/src/interfaces/__pycache__/speaker_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a9aca09cdc1c34355a8e67d4daae793cdcf7775 Binary files /dev/null and b/src/interfaces/__pycache__/speaker_detector.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/speaker_manager.cpython-310.pyc b/src/interfaces/__pycache__/speaker_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a3990fa1d978e6f2304cd30fb8545f911e9c688 Binary files /dev/null and b/src/interfaces/__pycache__/speaker_manager.cpython-310.pyc differ diff --git a/src/interfaces/__pycache__/transcriber.cpython-310.pyc b/src/interfaces/__pycache__/transcriber.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..172b9a9f229f6cd926a586ca2de3037bdb770734 Binary files /dev/null and b/src/interfaces/__pycache__/transcriber.cpython-310.pyc differ diff --git 
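The deployment pieces above (the DeploymentManager CLI, EndpointManager, and ModalDeployer) can also be driven programmatically. A minimal sketch, assuming the modules are importable as src.deployment.* and that modal_config.py sits in the working directory where deploy_transcription_service() shells out to the Modal CLI; none of this is part of the diff itself:

# Hypothetical programmatic deployment flow (illustrative only; module paths
# are assumptions based on the file layout in this diff).
from src.deployment.modal_deployer import ModalDeployer
from src.deployment.endpoint_manager import EndpointManager

deployer = ModalDeployer()  # falls back to the default AudioProcessingConfig
endpoint_url = deployer.deploy_transcription_service()  # runs `modal deploy modal_config.py`

if endpoint_url:
    # deploy_transcription_service() already saved the URL, so this simply
    # re-reads endpoint_config.json and pings the derived /health endpoint.
    print(EndpointManager().get_endpoint("transcribe_audio"))
    deployer.check_deployment_status()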
a/src/interfaces/audio_processor.py b/src/interfaces/audio_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc8b4b632ad6b949452b3ea5be4bbbd35a31bad --- /dev/null +++ b/src/interfaces/audio_processor.py @@ -0,0 +1,53 @@ +""" +Audio processing interface definitions +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple, Iterator, Optional +from dataclasses import dataclass + + +@dataclass +class AudioSegment: + """Audio segment representation""" + start: float + end: float + file_path: str + duration: float + + +class IAudioProcessor(ABC): + """Interface for audio processing operations""" + + @abstractmethod + async def split_audio_by_silence( + self, + audio_path: str, + min_segment_length: float = 30.0, + min_silence_length: float = 1.0 + ) -> List[AudioSegment]: + """Split audio file by silence detection""" + pass + + @abstractmethod + async def process_audio_segment( + self, + segment: AudioSegment, + model_name: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> Dict[str, Any]: + """Process a single audio segment""" + pass + + @abstractmethod + async def process_complete_audio( + self, + audio_path: str, + model_name: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False, + min_segment_length: float = 30.0 + ) -> Dict[str, Any]: + """Process complete audio file""" + pass \ No newline at end of file diff --git a/src/interfaces/audio_splitter.py b/src/interfaces/audio_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..45ddc36009f0198dfa17d8f3f97e295c2951e7fc --- /dev/null +++ b/src/interfaces/audio_splitter.py @@ -0,0 +1,48 @@ +""" +Audio splitter interface definition +""" + +from abc import ABC, abstractmethod +from typing import Iterator, Tuple +from dataclasses import dataclass + + +@dataclass +class AudioSegment: + """Audio segment data class""" + start: float + end: float + duration: float + + def __post_init__(self): + if self.duration <= 0: + self.duration = self.end - self.start + + +class IAudioSplitter(ABC): + """Interface for audio splitting""" + + @abstractmethod + def split_audio( + self, + audio_path: str, + min_segment_length: float = 30.0, + min_silence_length: float = 1.0 + ) -> Iterator[AudioSegment]: + """ + Split audio into segments + + Args: + audio_path: Path to audio file + min_segment_length: Minimum segment length in seconds + min_silence_length: Minimum silence length for splitting + + Yields: + AudioSegment objects + """ + pass + + @abstractmethod + def get_audio_duration(self, audio_path: str) -> float: + """Get total duration of audio file""" + pass \ No newline at end of file diff --git a/src/interfaces/podcast_downloader.py b/src/interfaces/podcast_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..ea485c584a41d87d1b516dcd680696cbae2bac47 --- /dev/null +++ b/src/interfaces/podcast_downloader.py @@ -0,0 +1,66 @@ +""" +Podcast downloading interface definitions +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Tuple +from dataclasses import dataclass +from enum import Enum + + +class PodcastPlatform(Enum): + """Podcast platform enumeration""" + APPLE = "apple" + XIAOYUZHOU = "xyz" + SPOTIFY = "spotify" + GENERIC = "generic" + + +@dataclass +class PodcastInfo: + """Podcast episode information""" + title: str + audio_url: str + episode_id: str + platform: PodcastPlatform + duration: Optional[float] = None + 
description: Optional[str] = None + + +@dataclass +class DownloadResult: + """Download operation result""" + success: bool + file_path: Optional[str] + podcast_info: Optional[PodcastInfo] + error_message: Optional[str] = None + + +class IPodcastDownloader(ABC): + """Interface for podcast downloading operations""" + + @abstractmethod + async def extract_podcast_info(self, url: str) -> PodcastInfo: + """Extract podcast information from URL""" + pass + + @abstractmethod + async def download_podcast( + self, + url: str, + output_folder: str = "downloads", + convert_to_mp3: bool = False, + keep_original: bool = False + ) -> DownloadResult: + """Download podcast from URL""" + pass + + @abstractmethod + def get_supported_platforms(self) -> list[PodcastPlatform]: + """Get list of supported platforms""" + pass + + @abstractmethod + def can_handle_url(self, url: str) -> bool: + """Check if this downloader can handle the given URL""" + pass \ No newline at end of file diff --git a/src/interfaces/speaker_detector.py b/src/interfaces/speaker_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..f8152345977c23aab529a8e7a8e070de39c325a5 --- /dev/null +++ b/src/interfaces/speaker_detector.py @@ -0,0 +1,71 @@ +""" +Speaker detector interface definition +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional +from dataclasses import dataclass +import numpy as np + + +@dataclass +class SpeakerSegment: + """Speaker segment data class""" + start: float + end: float + speaker_id: str + confidence: Optional[float] = None + + +@dataclass +class SpeakerProfile: + """Speaker profile data class""" + speaker_id: str + embedding: np.ndarray + segments: List[SpeakerSegment] + total_duration: float + + +class ISpeakerDetector(ABC): + """Interface for speaker detection and diarization""" + + @abstractmethod + async def detect_speakers( + self, + audio_file_path: str, + audio_segments: Optional[List] = None + ) -> Dict[str, SpeakerProfile]: + """ + Detect and identify speakers in audio + + Args: + audio_file_path: Path to audio file + audio_segments: Optional pre-segmented audio + + Returns: + Dictionary mapping speaker IDs to SpeakerProfile objects + """ + pass + + @abstractmethod + def map_to_global_speakers( + self, + local_speakers: Dict[str, SpeakerProfile], + source_file: str + ) -> Dict[str, str]: + """ + Map local speakers to global speaker identities + + Args: + local_speakers: Local speaker profiles + source_file: Source audio file path + + Returns: + Mapping from local speaker ID to global speaker ID + """ + pass + + @abstractmethod + def get_speaker_summary(self) -> Dict: + """Get summary of all detected speakers""" + pass \ No newline at end of file diff --git a/src/interfaces/speaker_manager.py b/src/interfaces/speaker_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d88816e6bb287c4371fc48219c14a5984c9e36e1 --- /dev/null +++ b/src/interfaces/speaker_manager.py @@ -0,0 +1,113 @@ +""" +Speaker identification and embedding management interfaces +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List, Tuple +from dataclasses import dataclass +import numpy as np + + +@dataclass +class SpeakerEmbedding: + """Speaker embedding data structure""" + speaker_id: str + embedding: np.ndarray + confidence: float + source_files: List[str] + sample_count: int + created_at: str + updated_at: str + + +@dataclass +class SpeakerSegment: + """Speaker segment information""" + start: float + end: float + 
speaker_id: str + confidence: float + + +class ISpeakerEmbeddingManager(ABC): + """Interface for speaker embedding management""" + + @abstractmethod + async def find_matching_speaker( + self, + embedding: np.ndarray, + source_file: str + ) -> Optional[str]: + """Find matching speaker from existing embeddings""" + pass + + @abstractmethod + async def add_or_update_speaker( + self, + embedding: np.ndarray, + source_file: str, + confidence: float = 1.0, + original_label: Optional[str] = None + ) -> str: + """Add new speaker or update existing speaker""" + pass + + @abstractmethod + async def map_local_to_global_speakers( + self, + local_embeddings: Dict[str, np.ndarray], + source_file: str + ) -> Dict[str, str]: + """Map local speaker labels to global speaker IDs""" + pass + + @abstractmethod + async def get_speaker_info(self, speaker_id: str) -> Optional[SpeakerEmbedding]: + """Get speaker information by ID""" + pass + + @abstractmethod + async def get_all_speakers_summary(self) -> Dict[str, Any]: + """Get summary of all speakers""" + pass + + @abstractmethod + async def save_speakers(self) -> None: + """Save speaker data to storage""" + pass + + @abstractmethod + async def load_speakers(self) -> None: + """Load speaker data from storage""" + pass + + +class ISpeakerIdentificationService(ABC): + """Interface for speaker identification operations""" + + @abstractmethod + async def extract_speaker_embeddings( + self, + audio_path: str, + segments: List[SpeakerSegment] + ) -> Dict[str, np.ndarray]: + """Extract speaker embeddings from audio segments""" + pass + + @abstractmethod + async def identify_speakers_in_audio( + self, + audio_path: str, + transcription_segments: List[Dict[str, Any]] + ) -> List[SpeakerSegment]: + """Identify speakers in audio file""" + pass + + @abstractmethod + async def map_transcription_to_speakers( + self, + transcription_segments: List[Dict[str, Any]], + speaker_segments: List[SpeakerSegment] + ) -> List[Dict[str, Any]]: + """Map transcription segments to speaker information""" + pass \ No newline at end of file diff --git a/src/interfaces/transcriber.py b/src/interfaces/transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..b943fc49a2fc51760317764ff34309d90bc5c715 --- /dev/null +++ b/src/interfaces/transcriber.py @@ -0,0 +1,67 @@ +""" +Transcriber interface definition +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + + +@dataclass +class TranscriptionSegment: + """Transcription segment data class""" + start: float + end: float + text: str + speaker: Optional[str] = None + confidence: Optional[float] = None + + +@dataclass +class TranscriptionResult: + """Transcription result data class""" + text: str + segments: List[TranscriptionSegment] + language: str + model_used: str + audio_duration: float + processing_time: float + speaker_diarization_enabled: bool = False + global_speaker_count: int = 0 + error_message: Optional[str] = None + + +class ITranscriber(ABC): + """Interface for audio transcription""" + + @abstractmethod + async def transcribe( + self, + audio_file_path: str, + model_size: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> TranscriptionResult: + """ + Transcribe audio file + + Args: + audio_file_path: Path to audio file + model_size: Model size to use + language: Language code (None for auto-detect) + enable_speaker_diarization: Whether to enable speaker detection + + Returns: + 
TranscriptionResult object + """ + pass + + @abstractmethod + def get_supported_models(self) -> List[str]: + """Get list of supported model sizes""" + pass + + @abstractmethod + def get_supported_languages(self) -> List[str]: + """Get list of supported language codes""" + pass \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a00c0739ea159911eacdab6ce0e5e1a1146767a1 --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,96 @@ +""" +Data models for audio processing +""" + +from .base import BaseRequest, BaseResponse, OperationStatus +from .transcription import ( + TranscriptionRequest, + TranscriptionResponse, + TranscriptionSegment, + SpeakerInfo, + TranscriptionFiles, + TranscriptionMetrics, + ModelSize +) +from .download import ( + DownloadRequest, + DownloadResponse, + PodcastPlatform +) +from .file_operations import ( + FileInfoRequest, + FileInfoResponse, + FileReadRequest, + FileReadResponse, + DirectoryListRequest, + DirectoryListResponse, + FileMetadata +) +from .services import ( + AudioProcessingTask, + FileOperationType, + AudioProcessingRequest, + AudioProcessingResult, + PodcastDownloadRequest, + PodcastDownloadResult, + SpeakerEmbeddingRequest, + SpeakerEmbeddingResult, + FileManagementRequest, + FileManagementResult, + ServiceError, + ServiceHealthCheck +) +from .converters import ( + TranscriptionConverter, + DownloadConverter, + FileOperationConverter +) + +__all__ = [ + # Base + "BaseRequest", + "BaseResponse", + "OperationStatus", + + # Transcription + "TranscriptionRequest", + "TranscriptionResponse", + "TranscriptionSegment", + "SpeakerInfo", + "TranscriptionFiles", + "TranscriptionMetrics", + "ModelSize", + + # Download + "DownloadRequest", + "DownloadResponse", + "PodcastPlatform", + + # File Operations + "FileInfoRequest", + "FileInfoResponse", + "FileReadRequest", + "FileReadResponse", + "DirectoryListRequest", + "DirectoryListResponse", + "FileMetadata", + + # Service layer models + "AudioProcessingTask", + "FileOperationType", + "AudioProcessingRequest", + "AudioProcessingResult", + "PodcastDownloadRequest", + "PodcastDownloadResult", + "SpeakerEmbeddingRequest", + "SpeakerEmbeddingResult", + "FileManagementRequest", + "FileManagementResult", + "ServiceError", + "ServiceHealthCheck", + + # Converters + "TranscriptionConverter", + "DownloadConverter", + "FileOperationConverter", +] \ No newline at end of file diff --git a/src/models/__pycache__/__init__.cpython-310.pyc b/src/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d109bcab60db6246180ea0a2c963d952dd8f3458 Binary files /dev/null and b/src/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/models/__pycache__/base.cpython-310.pyc b/src/models/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ff3d540416553925cdc64a8cdccc7bf361adfba Binary files /dev/null and b/src/models/__pycache__/base.cpython-310.pyc differ diff --git a/src/models/__pycache__/converters.cpython-310.pyc b/src/models/__pycache__/converters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d9abf1059235c519fd6ca2a9937df4abb7e13b Binary files /dev/null and b/src/models/__pycache__/converters.cpython-310.pyc differ diff --git a/src/models/__pycache__/download.cpython-310.pyc b/src/models/__pycache__/download.cpython-310.pyc new file mode 100644 index 
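Because ITranscriber is the seam the adapters and services plug into, a throwaway in-memory implementation makes the contract concrete, for example when wiring up tests. A sketch under the assumption that the package is importable as src.interfaces; EchoTranscriber itself is hypothetical:

# Hypothetical stub implementation of ITranscriber (illustrative only).
from typing import List, Optional

from src.interfaces.transcriber import ITranscriber, TranscriptionResult, TranscriptionSegment


class EchoTranscriber(ITranscriber):
    async def transcribe(
        self,
        audio_file_path: str,
        model_size: str = "turbo",
        language: Optional[str] = None,
        enable_speaker_diarization: bool = False,
    ) -> TranscriptionResult:
        # Return a fixed one-second segment instead of running a real model.
        segment = TranscriptionSegment(start=0.0, end=1.0, text="hello")
        return TranscriptionResult(
            text="hello",
            segments=[segment],
            language=language or "en",
            model_used=model_size,
            audio_duration=1.0,
            processing_time=0.0,
        )

    def get_supported_models(self) -> List[str]:
        return ["turbo"]

    def get_supported_languages(self) -> List[str]:
        return ["en"]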
0000000000000000000000000000000000000000..5dfe46746a57d46127de15defaad6392a258b0f3 Binary files /dev/null and b/src/models/__pycache__/download.cpython-310.pyc differ diff --git a/src/models/__pycache__/file_operations.cpython-310.pyc b/src/models/__pycache__/file_operations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8d1e058ea8274fe5e1f3191b6b86b707c7b97e4 Binary files /dev/null and b/src/models/__pycache__/file_operations.cpython-310.pyc differ diff --git a/src/models/__pycache__/services.cpython-310.pyc b/src/models/__pycache__/services.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42f4ec27d6a0209e36308ae3612a426e2f55312 Binary files /dev/null and b/src/models/__pycache__/services.cpython-310.pyc differ diff --git a/src/models/__pycache__/transcription.cpython-310.pyc b/src/models/__pycache__/transcription.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..173a07f5e058dc046592e82b3f160cb1deceabf5 Binary files /dev/null and b/src/models/__pycache__/transcription.cpython-310.pyc differ diff --git a/src/models/base.py b/src/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9c4b24a490c7131ccc5490a0d924bfdf6f54f7 --- /dev/null +++ b/src/models/base.py @@ -0,0 +1,59 @@ +""" +Base models for common data structures +""" + +from dataclasses import dataclass, asdict +from enum import Enum +from typing import Optional, Dict, Any +import json + + +class OperationStatus(str, Enum): + """Standard operation status""" + SUCCESS = "success" + FAILED = "failed" + PENDING = "pending" + IN_PROGRESS = "in_progress" + + +@dataclass +class BaseResponse: + """Base response model for all operations""" + status: OperationStatus + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + result = asdict(self) + # Convert enum to string + result["status"] = self.status.value + return result + + def to_json(self) -> str: + """Convert to JSON string""" + return json.dumps(self.to_dict(), ensure_ascii=False, indent=2) + + @property + def is_success(self) -> bool: + """Check if operation was successful""" + return self.status == OperationStatus.SUCCESS + + @property + def is_failed(self) -> bool: + """Check if operation failed""" + return self.status == OperationStatus.FAILED + + +@dataclass +class BaseRequest: + """Base request model""" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return asdict(self) + + def to_json(self) -> str: + """Convert to JSON string""" + return json.dumps(self.to_dict(), ensure_ascii=False, indent=2) \ No newline at end of file diff --git a/src/models/converters.py b/src/models/converters.py new file mode 100644 index 0000000000000000000000000000000000000000..9d33381c4f99a21ff76663f3ccd849740cd033d3 --- /dev/null +++ b/src/models/converters.py @@ -0,0 +1,235 @@ +""" +Model converters for transforming between different formats +""" + +from typing import Dict, Any, List +from datetime import datetime +import os + +from .transcription import ( + TranscriptionResponse, TranscriptionFiles, TranscriptionSegment, + TranscriptionMetrics, SpeakerInfo, ModelSize, OutputFormat +) +from .download import DownloadResponse, PodcastPlatform +from .file_operations import ( + FileInfoResponse, FileReadResponse, DirectoryListResponse, + FileMetadata, FileReadProgress +) +from .base import OperationStatus + + 
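Before the converter classes, a quick behavioural sketch of the base models they build on: every response dataclass carries an OperationStatus plus optional error fields and serializes itself via to_dict()/to_json(). Illustrative only, assuming the package is importable as src.models:

# Minimal demonstration of BaseResponse/OperationStatus (illustrative only).
from src.models.base import BaseResponse, OperationStatus

ok = BaseResponse(status=OperationStatus.SUCCESS, message="done")
print(ok.is_success)   # True
print(ok.to_json())    # pretty-printed JSON with "status": "success"

err = BaseResponse(status=OperationStatus.FAILED, message="boom", error_code="E_FAIL")
print(err.is_failed)   # True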
+class TranscriptionConverter: + """Converter for transcription models""" + + @staticmethod + def from_legacy_dict(data: Dict[str, Any]) -> TranscriptionResponse: + """Convert legacy dictionary format to TranscriptionResponse""" + + # Parse files + files = TranscriptionFiles( + txt_file_path=data.get("txt_file_path"), + srt_file_path=data.get("srt_file_path"), + json_file_path=data.get("json_file_path") + ) + + # Parse segments + segments = [] + if "segments" in data: + for seg_data in data["segments"]: + segments.append(TranscriptionSegment( + start=seg_data.get("start", 0), + end=seg_data.get("end", 0), + text=seg_data.get("text", ""), + speaker=seg_data.get("speaker"), + confidence=seg_data.get("confidence") + )) + + # Parse metrics + metrics = TranscriptionMetrics( + audio_duration=data.get("audio_duration", 0), + processing_time=data.get("processing_time", 0), + segment_count=data.get("segment_count", 0), + model_used=data.get("model_used", ""), + language_detected=data.get("language_detected", "unknown") + ) + + # Parse speaker info + speaker_info = SpeakerInfo( + enabled=data.get("speaker_diarization_enabled", False), + global_speaker_count=data.get("global_speaker_count", 0), + speaker_mapping=data.get("speaker_summary", {}).get("speaker_mapping", {}), + speaker_summary=data.get("speaker_summary", {}) + ) + + # Determine status + status = OperationStatus.SUCCESS if data.get("processing_status") == "success" else OperationStatus.FAILED + + return TranscriptionResponse( + status=status, + message=data.get("error_message") if status == OperationStatus.FAILED else "转录完成", + audio_file=data.get("audio_file", ""), + files=files, + segments=segments, + speaker_info=speaker_info, + metrics=metrics + ) + + @staticmethod + def to_legacy_dict(response: TranscriptionResponse) -> Dict[str, Any]: + """Convert TranscriptionResponse to legacy dictionary format""" + + return { + "txt_file_path": response.files.txt_file_path, + "srt_file_path": response.files.srt_file_path, + "audio_file": response.audio_file, + "model_used": response.metrics.model_used, + "segment_count": response.metrics.segment_count, + "audio_duration": response.metrics.audio_duration, + "processing_status": response.status.value, + "processing_time": response.metrics.processing_time, + "saved_files": response.files.all_files, + "speaker_diarization_enabled": response.speaker_info.enabled, + "global_speaker_count": response.speaker_info.global_speaker_count, + "speaker_summary": response.speaker_info.speaker_summary, + "language_detected": response.metrics.language_detected, + "error_message": response.message if response.is_failed else None + } + + +class DownloadConverter: + """Converter for download models""" + + @staticmethod + def from_legacy_dict(data: Dict[str, Any]) -> DownloadResponse: + """Convert legacy dictionary format to DownloadResponse""" + + status = OperationStatus.SUCCESS if data.get("status") == "success" else OperationStatus.FAILED + + # Calculate file size if available + file_size_mb = None + if data.get("audio_file_path") and os.path.exists(data["audio_file_path"]): + try: + file_size_bytes = os.path.getsize(data["audio_file_path"]) + file_size_mb = file_size_bytes / (1024 * 1024) + except: + pass + + return DownloadResponse( + status=status, + message=data.get("error_message") if status == OperationStatus.FAILED else "下载成功", + original_url=data.get("original_url", ""), + audio_file_path=data.get("audio_file_path"), + file_size_mb=file_size_mb + ) + + @staticmethod + def to_legacy_dict(response: 
DownloadResponse) -> Dict[str, Any]: + """Convert DownloadResponse to legacy dictionary format""" + + return { + "status": response.status.value, + "original_url": response.original_url, + "audio_file_path": response.audio_file_path, + "error_message": response.message if response.is_failed else None + } + + +class FileOperationConverter: + """Converter for file operation models""" + + @staticmethod + def from_legacy_file_info(data: Dict[str, Any]) -> FileInfoResponse: + """Convert legacy file info format to FileInfoResponse""" + + status = OperationStatus.SUCCESS if data.get("status") == "success" else OperationStatus.FAILED + + metadata = None + if status == OperationStatus.SUCCESS: + metadata = FileMetadata( + filename=data.get("filename", ""), + full_path=data.get("file_path", ""), + file_size=data.get("file_size", 0), + file_size_mb=data.get("file_size_mb", 0.0), + created_time=datetime.fromtimestamp(data.get("created_time", 0)), + modified_time=datetime.fromtimestamp(data.get("modified_time", 0)), + file_extension=data.get("file_extension", ""), + is_audio_file=data.get("file_extension", "").lower() in ['.mp3', '.wav', '.m4a', '.flac'] + ) + + return FileInfoResponse( + status=status, + message=data.get("error_message") if status == OperationStatus.FAILED else "文件信息获取成功", + file_path=data.get("file_path", ""), + file_exists=data.get("file_exists", False), + metadata=metadata + ) + + @staticmethod + def from_legacy_file_read(data: Dict[str, Any]) -> FileReadResponse: + """Convert legacy file read format to FileReadResponse""" + + status = OperationStatus.SUCCESS if data.get("status") == "success" else OperationStatus.FAILED + + progress = None + if status == OperationStatus.SUCCESS: + progress = FileReadProgress( + current_position=data.get("current_position", 0), + file_size=data.get("file_size", 0), + bytes_read=data.get("bytes_read", 0), + content_length=data.get("content_length", 0), + progress_percentage=data.get("progress_percentage", 0.0), + end_of_file_reached=data.get("end_of_file_reached", False), + actual_boundary=data.get("actual_boundary", "") + ) + + return FileReadResponse( + status=status, + message=data.get("error_message") if status == OperationStatus.FAILED else "文件读取成功", + file_path=data.get("file_path", ""), + content=data.get("content", ""), + progress=progress + ) + + @staticmethod + def from_legacy_directory_list(data: Dict[str, Any]) -> DirectoryListResponse: + """Convert legacy directory list format to DirectoryListResponse""" + + status = OperationStatus.SUCCESS if not data.get("error_message") else OperationStatus.FAILED + + file_list = [] + if "file_list" in data: + for file_data in data["file_list"]: + # Handle time format conversion - original format is string, not ISO format + created_time_str = file_data.get("created_time", "") + modified_time_str = file_data.get("modified_time", "") + + try: + # Try to parse as ISO format first + created_time = datetime.fromisoformat(created_time_str.replace("T", " ")) if created_time_str else datetime.fromtimestamp(0) + except: + # Fallback to default time + created_time = datetime.fromtimestamp(0) + + try: + modified_time = datetime.fromisoformat(modified_time_str.replace("T", " ")) if modified_time_str else datetime.fromtimestamp(0) + except: + modified_time = datetime.fromtimestamp(0) + + file_list.append(FileMetadata( + filename=file_data.get("filename", ""), + full_path=file_data.get("full_path", ""), + file_size=file_data.get("file_size", 0), + file_size_mb=file_data.get("file_size_mb", 0.0), + 
created_time=created_time, + modified_time=modified_time, + file_extension=os.path.splitext(file_data.get("filename", ""))[1], + is_audio_file=file_data.get("filename", "").lower().endswith(('.mp3', '.wav', '.m4a', '.flac')) + )) + + return DirectoryListResponse( + status=status, + message=data.get("error_message", "目录扫描成功"), + directory=data.get("scanned_directory", ""), + total_files=data.get("total_files", 0), + file_list=file_list + ) \ No newline at end of file diff --git a/src/models/download.py b/src/models/download.py new file mode 100644 index 0000000000000000000000000000000000000000..6f59c9541eb5bfd907d8dc6c60dff5a8b6d35c87 --- /dev/null +++ b/src/models/download.py @@ -0,0 +1,70 @@ +""" +Download models +""" + +from dataclasses import dataclass +from typing import Optional, Dict, Any +from enum import Enum + +from .base import BaseRequest, BaseResponse, OperationStatus + + +class PodcastPlatform(str, Enum): + """Supported podcast platforms""" + APPLE_PODCAST = "apple_podcast" + XIAOYUZHOU = "xiaoyuzhou" + + +@dataclass +class DownloadRequest(BaseRequest): + """Request model for podcast download""" + url: str + platform: PodcastPlatform + output_directory: Optional[str] = None + auto_transcribe: bool = False + enable_speaker_diarization: bool = False + + +@dataclass +class DownloadResponse(BaseResponse): + """Response model for podcast download""" + original_url: str = "" + audio_file_path: Optional[str] = None + file_size_mb: Optional[float] = None + duration_seconds: Optional[float] = None + + @classmethod + def success( + cls, + original_url: str, + audio_file_path: str, + file_size_mb: Optional[float] = None, + duration_seconds: Optional[float] = None, + message: str = "下载成功" + ) -> "DownloadResponse": + """Create successful response""" + return cls( + status=OperationStatus.SUCCESS, + message=message, + original_url=original_url, + audio_file_path=audio_file_path, + file_size_mb=file_size_mb, + duration_seconds=duration_seconds + ) + + @classmethod + def failed( + cls, + original_url: str, + error_message: str, + error_code: str = "DOWNLOAD_ERROR", + error_details: Optional[Dict[str, Any]] = None + ) -> "DownloadResponse": + """Create failed response""" + return cls( + status=OperationStatus.FAILED, + message=error_message, + error_code=error_code, + error_details=error_details, + original_url=original_url + ) \ No newline at end of file diff --git a/src/models/file_operations.py b/src/models/file_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..c722129d69e9de4cb7ada33daaed111af8d94612 --- /dev/null +++ b/src/models/file_operations.py @@ -0,0 +1,180 @@ +""" +File operation models +""" + +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any +from datetime import datetime + +from .base import BaseRequest, BaseResponse, OperationStatus + + +@dataclass +class FileMetadata: + """File metadata information""" + filename: str + full_path: str + file_size: int + file_size_mb: float + created_time: datetime + modified_time: datetime + file_extension: str + is_audio_file: bool = False + + +@dataclass +class FileInfoRequest(BaseRequest): + """Request model for file information""" + file_path: str + + +@dataclass +class FileInfoResponse(BaseResponse): + """Response model for file information""" + file_path: str = "" + file_exists: bool = False + metadata: Optional[FileMetadata] = None + + @classmethod + def success( + cls, + file_path: str, + metadata: FileMetadata, + message: str = "文件信息获取成功" + ) -> "FileInfoResponse": 
+ """Create successful response""" + return cls( + status=OperationStatus.SUCCESS, + message=message, + file_path=file_path, + file_exists=True, + metadata=metadata + ) + + @classmethod + def failed( + cls, + file_path: str, + error_message: str, + error_code: str = "FILE_ERROR", + error_details: Optional[Dict[str, Any]] = None + ) -> "FileInfoResponse": + """Create failed response""" + return cls( + status=OperationStatus.FAILED, + message=error_message, + error_code=error_code, + error_details=error_details, + file_path=file_path, + file_exists=False + ) + + +@dataclass +class FileReadRequest(BaseRequest): + """Request model for file reading""" + file_path: str + chunk_size: int = 64 * 1024 # 64KB + start_position: int = 0 + + +@dataclass +class FileReadProgress: + """File reading progress information""" + current_position: int + file_size: int + bytes_read: int + content_length: int + progress_percentage: float + end_of_file_reached: bool + actual_boundary: str + + +@dataclass +class FileReadResponse(BaseResponse): + """Response model for file reading""" + file_path: str = "" + content: str = "" + progress: Optional[FileReadProgress] = None + + @classmethod + def success( + cls, + file_path: str, + content: str, + progress: FileReadProgress, + message: str = "文件读取成功" + ) -> "FileReadResponse": + """Create successful response""" + return cls( + status=OperationStatus.SUCCESS, + message=message, + file_path=file_path, + content=content, + progress=progress + ) + + @classmethod + def failed( + cls, + file_path: str, + error_message: str, + error_code: str = "FILE_READ_ERROR", + error_details: Optional[Dict[str, Any]] = None + ) -> "FileReadResponse": + """Create failed response""" + return cls( + status=OperationStatus.FAILED, + message=error_message, + error_code=error_code, + error_details=error_details, + file_path=file_path + ) + + +@dataclass +class DirectoryListRequest(BaseRequest): + """Request model for directory listing""" + directory: str + file_extension_filter: Optional[str] = None # e.g., ".mp3" + + +@dataclass +class DirectoryListResponse(BaseResponse): + """Response model for directory listing""" + directory: str = "" + total_files: int = 0 + file_list: List[FileMetadata] = field(default_factory=list) + + @classmethod + def success( + cls, + directory: str, + file_list: List[FileMetadata], + message: str = "目录扫描成功" + ) -> "DirectoryListResponse": + """Create successful response""" + return cls( + status=OperationStatus.SUCCESS, + message=message, + directory=directory, + total_files=len(file_list), + file_list=file_list + ) + + @classmethod + def failed( + cls, + directory: str, + error_message: str, + error_code: str = "DIRECTORY_ERROR", + error_details: Optional[Dict[str, Any]] = None + ) -> "DirectoryListResponse": + """Create failed response""" + return cls( + status=OperationStatus.FAILED, + message=error_message, + error_code=error_code, + error_details=error_details, + directory=directory + ) \ No newline at end of file diff --git a/src/models/services.py b/src/models/services.py new file mode 100644 index 0000000000000000000000000000000000000000..de66516d6554947d1cb99620894d1828be7e2494 --- /dev/null +++ b/src/models/services.py @@ -0,0 +1,159 @@ +""" +Service layer specific data models +""" + +from dataclasses import dataclass, field +from typing import Dict, Any, List, Optional, Union +from enum import Enum +import numpy as np + +from .base import BaseRequest, BaseResponse, OperationStatus + + +class AudioProcessingTask(Enum): + """Audio processing task types""" + 
TRANSCRIPTION = "transcription" + SPEAKER_IDENTIFICATION = "speaker_identification" + AUDIO_SEGMENTATION = "audio_segmentation" + COMPLETE_PROCESSING = "complete_processing" + + +class FileOperationType(Enum): + """File operation types""" + SCAN = "scan" + READ = "read" + WRITE = "write" + ORGANIZE = "organize" + CONVERT = "convert" + + +@dataclass +class AudioProcessingRequest(BaseRequest): + """Request for audio processing operations""" + audio_path: str + task: AudioProcessingTask + model_name: str = "turbo" + language: Optional[str] = None + enable_speaker_diarization: bool = False + min_segment_length: float = 30.0 + output_path: Optional[str] = None + + +@dataclass +class AudioProcessingResult: + """Result of audio processing operations""" + audio_path: str + task: AudioProcessingTask + text: str + segments: List[Dict[str, Any]] + audio_duration: float + segment_count: int + language_detected: str + model_used: str + speaker_diarization_enabled: bool + status: OperationStatus = OperationStatus.SUCCESS + processing_time: Optional[float] = None + speakers_detected: Optional[int] = None + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None + + +@dataclass +class PodcastDownloadRequest(BaseRequest): + """Request for podcast download operations""" + url: str + output_folder: str = "downloads" + convert_to_mp3: bool = False + keep_original: bool = False + extract_info_only: bool = False + + +@dataclass +class PodcastDownloadResult: + """Result of podcast download operations""" + url: str + title: str + episode_id: str + platform: str + status: OperationStatus = OperationStatus.SUCCESS + file_path: Optional[str] = None + download_time: Optional[float] = None + file_size: Optional[int] = None + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None + + +@dataclass +class SpeakerEmbeddingRequest(BaseRequest): + """Request for speaker embedding operations""" + audio_path: str + speaker_segments: List[Dict[str, Any]] + source_file: str + update_global_speakers: bool = True + + +@dataclass +class SpeakerEmbeddingResult: + """Result of speaker embedding operations""" + audio_path: str + speaker_embeddings: Dict[str, np.ndarray] + speaker_mapping: Dict[str, str] # local_id -> global_id + speakers_created: int + speakers_updated: int + status: OperationStatus = OperationStatus.SUCCESS + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None + + +@dataclass +class FileManagementRequest(BaseRequest): + """Request for file management operations""" + operation: FileOperationType + file_path: Optional[str] = None + directory_path: Optional[str] = None + content: Optional[str] = None + options: Optional[Dict[str, Any]] = None + + +@dataclass +class FileManagementResult: + """Result of file management operations""" + operation: FileOperationType + status: OperationStatus = OperationStatus.SUCCESS + file_path: Optional[str] = None + directory_path: Optional[str] = None + files_processed: int = 0 + total_files: int = 0 + content_length: int = 0 + file_size: int = 0 + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None + + +@dataclass +class ServiceError: + """Service error information""" + service_name: str + error_code: str + error_message: str + error_details: Optional[Dict[str, Any]] = None + timestamp: Optional[str] = None + + +@dataclass +class ServiceHealthCheck: 
+ """Service health check result""" + service_name: str + is_healthy: bool + dependencies: Dict[str, bool] + version: str + uptime: float + last_check: str + status: OperationStatus = OperationStatus.SUCCESS + message: Optional[str] = None + error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None \ No newline at end of file diff --git a/src/models/transcription.py b/src/models/transcription.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7ce1d957511fef75f2d69ba9b1aa1777d281ac --- /dev/null +++ b/src/models/transcription.py @@ -0,0 +1,126 @@ +""" +Transcription models +""" + +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any +from enum import Enum + +from .base import BaseRequest, BaseResponse, OperationStatus + + +class ModelSize(str, Enum): + """Whisper model sizes""" + TINY = "tiny" + BASE = "base" + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + TURBO = "turbo" + + +class OutputFormat(str, Enum): + """Output formats""" + TXT = "txt" + SRT = "srt" + JSON = "json" + + +@dataclass +class TranscriptionRequest(BaseRequest): + """Request model for transcription""" + audio_file_path: str + model_size: ModelSize = ModelSize.TURBO + language: Optional[str] = None + output_format: OutputFormat = OutputFormat.SRT + enable_speaker_diarization: bool = False + + +@dataclass +class TranscriptionSegment: + """Individual transcription segment""" + start: float + end: float + text: str + speaker: Optional[str] = None + confidence: Optional[float] = None + + +@dataclass +class SpeakerInfo: + """Speaker diarization information""" + enabled: bool = False + global_speaker_count: int = 0 + speaker_mapping: Dict[str, str] = field(default_factory=dict) + speaker_summary: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class TranscriptionFiles: + """Generated transcription files""" + txt_file_path: Optional[str] = None + srt_file_path: Optional[str] = None + json_file_path: Optional[str] = None + + @property + def all_files(self) -> List[str]: + """Get all non-None file paths""" + return [f for f in [self.txt_file_path, self.srt_file_path, self.json_file_path] if f] + + +@dataclass +class TranscriptionMetrics: + """Transcription processing metrics""" + audio_duration: float = 0.0 + processing_time: float = 0.0 + segment_count: int = 0 + model_used: str = "" + language_detected: str = "unknown" + + +@dataclass +class TranscriptionResponse(BaseResponse): + """Response model for transcription""" + audio_file: str = "" + files: TranscriptionFiles = field(default_factory=TranscriptionFiles) + segments: List[TranscriptionSegment] = field(default_factory=list) + speaker_info: SpeakerInfo = field(default_factory=SpeakerInfo) + metrics: TranscriptionMetrics = field(default_factory=TranscriptionMetrics) + + @classmethod + def success( + cls, + audio_file: str, + files: TranscriptionFiles, + segments: List[TranscriptionSegment], + metrics: TranscriptionMetrics, + speaker_info: Optional[SpeakerInfo] = None, + message: str = "转录完成" + ) -> "TranscriptionResponse": + """Create successful response""" + return cls( + status=OperationStatus.SUCCESS, + message=message, + audio_file=audio_file, + files=files, + segments=segments, + speaker_info=speaker_info or SpeakerInfo(), + metrics=metrics + ) + + @classmethod + def failed( + cls, + audio_file: str, + error_message: str, + error_code: str = "TRANSCRIPTION_ERROR", + error_details: Optional[Dict[str, Any]] = None + ) -> "TranscriptionResponse": + """Create failed 
response""" + return cls( + status=OperationStatus.FAILED, + message=error_message, + error_code=error_code, + error_details=error_details, + audio_file=audio_file + ) \ No newline at end of file diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcbc918ed182174a4daa711cf36595a30b4812c --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,133 @@ +""" +Services layer for the podcast transcription system +Provides clean abstraction for business logic +""" + +# Core transcription services +from .transcription_service import TranscriptionService +from .distributed_transcription_service import DistributedTranscriptionService + +# Modal-specific services +from .modal_transcription_service import ModalTranscriptionService +# Note: ModalDownloadService removed - downloads now handled locally by PodcastDownloadService + +# Download and media services +from .podcast_download_service import PodcastDownloadService + +# System services +from .health_service import HealthService + +# File and utility services +from .file_management_service import FileManagementService + +# Speaker services (if speaker diarization is available) +try: + from .speaker_embedding_service import SpeakerEmbeddingService + SPEAKER_DIARIZATION_AVAILABLE = True +except ImportError: + SPEAKER_DIARIZATION_AVAILABLE = False + SpeakerEmbeddingService = None + +# Deprecated services (kept for backward compatibility but should not be used) +# These services have been consolidated into other services or replaced +try: + from .audio_processing_service import AudioProcessingService + from .file_service import FileService + DEPRECATED_SERVICES_AVAILABLE = True +except ImportError: + DEPRECATED_SERVICES_AVAILABLE = False + AudioProcessingService = None + FileService = None + +# Export active services +__all__ = [ + # Primary services for active use + "TranscriptionService", + "DistributedTranscriptionService", + "ModalTranscriptionService", + "PodcastDownloadService", + "HealthService", + "FileManagementService", + + # Optional services + "SpeakerEmbeddingService", + + # Availability flags + "SPEAKER_DIARIZATION_AVAILABLE", + "DEPRECATED_SERVICES_AVAILABLE", + + # Deprecated services (for backward compatibility only) + "AudioProcessingService", + "FileService" +] + +# Service registry for dynamic access +SERVICE_REGISTRY = { + "transcription": TranscriptionService, + "distributed_transcription": DistributedTranscriptionService, + "modal_transcription": ModalTranscriptionService, + "podcast_download": PodcastDownloadService, + "health": HealthService, + "file_management": FileManagementService, +} + +if SPEAKER_DIARIZATION_AVAILABLE: + SERVICE_REGISTRY["speaker_embedding"] = SpeakerEmbeddingService + +if DEPRECATED_SERVICES_AVAILABLE: + SERVICE_REGISTRY["audio_processing"] = AudioProcessingService + SERVICE_REGISTRY["file"] = FileService + + +def get_service(service_name: str, *args, **kwargs): + """ + Factory function to get service instances + + Args: + service_name: Name of the service to get + *args: Arguments to pass to service constructor + **kwargs: Keyword arguments to pass to service constructor + + Returns: + Service instance + + Raises: + ValueError: If service name is not found + """ + if service_name not in SERVICE_REGISTRY: + available_services = list(SERVICE_REGISTRY.keys()) + raise ValueError(f"Service '{service_name}' not found. 
Available services: {available_services}") + + service_class = SERVICE_REGISTRY[service_name] + return service_class(*args, **kwargs) + + +def list_available_services() -> dict: + """ + Get list of all available services with their status + + Returns: + Dictionary of service names and their availability status + """ + services = {} + + # Active services + for name in ["transcription", "distributed_transcription", "modal_transcription", + "podcast_download", "health", "file_management"]: + services[name] = {"status": "active", "available": True} + + # Optional services + services["speaker_embedding"] = { + "status": "optional", + "available": SPEAKER_DIARIZATION_AVAILABLE + } + + # Deprecated services + if DEPRECATED_SERVICES_AVAILABLE: + services["audio_processing"] = {"status": "deprecated", "available": True} + services["file"] = {"status": "deprecated", "available": True} + else: + services["audio_processing"] = {"status": "deprecated", "available": False} + services["file"] = {"status": "deprecated", "available": False} + + return services \ No newline at end of file diff --git a/src/services/__pycache__/__init__.cpython-310.pyc b/src/services/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18fc01120c1635cfbccc6b2753cc7c12b92f130c Binary files /dev/null and b/src/services/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/services/__pycache__/audio_processing_service.cpython-310.pyc b/src/services/__pycache__/audio_processing_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3179668e5f5cbf5c4815e5afb235143a0c04114 Binary files /dev/null and b/src/services/__pycache__/audio_processing_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/distributed_transcription_service.cpython-310.pyc b/src/services/__pycache__/distributed_transcription_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db11198ba3792c037d8080bed7a79f4964d9a13c Binary files /dev/null and b/src/services/__pycache__/distributed_transcription_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/file_management_service.cpython-310.pyc b/src/services/__pycache__/file_management_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..511c4c866ee94c5f8e73feccfa6415bb65797c94 Binary files /dev/null and b/src/services/__pycache__/file_management_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/file_service.cpython-310.pyc b/src/services/__pycache__/file_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afd12e1525cdadf9467a27d2ed72146493608e25 Binary files /dev/null and b/src/services/__pycache__/file_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/health_service.cpython-310.pyc b/src/services/__pycache__/health_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58f1b539d32c64ce3042f530350dcf22d8d7c0b4 Binary files /dev/null and b/src/services/__pycache__/health_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/modal_download_service.cpython-310.pyc b/src/services/__pycache__/modal_download_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..070dffe4e6262916d0281bf7dba0fb1ef49cd88b Binary files /dev/null and b/src/services/__pycache__/modal_download_service.cpython-310.pyc differ diff --git 
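The registry-backed factory above keeps call sites decoupled from the concrete service classes. A usage sketch, assuming the package is importable as src.services; the positional cache_dir argument mirrors the TranscriptionService(cache_dir) call seen further below and may not apply to every service:

# Illustrative use of the service registry (not part of the diff).
from src.services import get_service, list_available_services

print(list_available_services())       # e.g. {"transcription": {"status": "active", ...}, ...}

transcriber = get_service("transcription", "/tmp")  # assumed cache_dir argument

try:
    get_service("no_such_service")
except ValueError as err:
    print(err)                         # names the available services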
a/src/services/__pycache__/modal_transcription_service.cpython-310.pyc b/src/services/__pycache__/modal_transcription_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..730f5390626a1bac64c64ecb42d8369422235888 Binary files /dev/null and b/src/services/__pycache__/modal_transcription_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/podcast_download_service.cpython-310.pyc b/src/services/__pycache__/podcast_download_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7599a82f0da17ed5be9a2fbd1c80239d41ed480 Binary files /dev/null and b/src/services/__pycache__/podcast_download_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/speaker_embedding_service.cpython-310.pyc b/src/services/__pycache__/speaker_embedding_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..942402e685acd9645851556e2e4d9194971daca3 Binary files /dev/null and b/src/services/__pycache__/speaker_embedding_service.cpython-310.pyc differ diff --git a/src/services/__pycache__/transcription_service.cpython-310.pyc b/src/services/__pycache__/transcription_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d65f1b783d5df06d71ae885c61fb9aede17cc6f Binary files /dev/null and b/src/services/__pycache__/transcription_service.cpython-310.pyc differ diff --git a/src/services/audio_processing_service.py b/src/services/audio_processing_service.py new file mode 100644 index 0000000000000000000000000000000000000000..09c6ab3174e37c5358fb9a23c65033d31f00ea2b --- /dev/null +++ b/src/services/audio_processing_service.py @@ -0,0 +1,252 @@ +""" +Audio Processing Service - integrates audio segmentation and transcription +""" + +import re +import asyncio +import pathlib +import tempfile +from typing import Dict, Any, List, Optional + +import ffmpeg + +from ..interfaces.audio_processor import IAudioProcessor, AudioSegment +from ..interfaces.transcriber import ITranscriber +from ..interfaces.speaker_manager import ISpeakerIdentificationService +from ..utils.config import AudioProcessingConfig +from ..utils.errors import AudioProcessingError +from ..models.transcription import TranscriptionResponse, TranscriptionSegment + + +class AudioProcessingService(IAudioProcessor): + """High-level audio processing service that coordinates transcription and speaker identification""" + + def __init__( + self, + transcriber: ITranscriber, + speaker_service: Optional[ISpeakerIdentificationService] = None, + config: Optional[AudioProcessingConfig] = None + ): + self.transcriber = transcriber + self.speaker_service = speaker_service + self.config = config or AudioProcessingConfig() + + async def split_audio_by_silence( + self, + audio_path: str, + min_segment_length: float = 30.0, + min_silence_length: float = 1.0 + ) -> List[AudioSegment]: + """ + Intelligently split audio using FFmpeg's silencedetect filter + """ + try: + silence_end_re = re.compile( + r" silence_end: (?P<end>[0-9]+(\.?[0-9]*)) \| silence_duration: (?P<dur>[0-9]+(\.?[0-9]*))" + ) + + # Get audio duration + metadata = ffmpeg.probe(audio_path) + duration = float(metadata["format"]["duration"]) + + # Use silence detection filter + reader = ( + ffmpeg.input(str(audio_path)) + .filter("silencedetect", n="-10dB", d=min_silence_length) + .output("pipe:", format="null") + .run_async(pipe_stderr=True) + ) + + segments = [] + cur_start = 0.0 + + while True: + line = reader.stderr.readline().decode("utf-8") + if not 
line: + break + + match = silence_end_re.search(line) + if match: + silence_end, silence_dur = match.group("end"), match.group("dur") + split_at = float(silence_end) - (float(silence_dur) / 2) + + if (split_at - cur_start) < min_segment_length: + continue + + segments.append(AudioSegment( + start=cur_start, + end=split_at, + file_path=audio_path, + duration=split_at - cur_start + )) + cur_start = split_at + + # Handle the last segment + if duration > cur_start: + segments.append(AudioSegment( + start=cur_start, + end=duration, + file_path=audio_path, + duration=duration - cur_start + )) + + print(f"Audio split into {len(segments)} segments") + return segments + + except Exception as e: + raise AudioProcessingError(f"Audio segmentation failed: {str(e)}") + + async def process_audio_segment( + self, + segment: AudioSegment, + model_name: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False + ) -> Dict[str, Any]: + """ + Process a single audio segment + """ + try: + # Create temporary segment file + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: + temp_path = temp_file.name + + # Extract segment using ffmpeg + ( + ffmpeg.input(segment.file_path, ss=segment.start, t=segment.duration) + .output(temp_path) + .overwrite_output() + .run(quiet=True) + ) + + # Transcribe segment + result = await self.transcriber.transcribe( + audio_file_path=temp_path, + model_size=model_name, + language=language, + enable_speaker_diarization=enable_speaker_diarization + ) + + # Adjust timestamps to match original audio + adjusted_segments = [] + for seg in result.segments: + adjusted_segments.append(TranscriptionSegment( + start=seg.start + segment.start, + end=seg.end + segment.start, + text=seg.text, + speaker=seg.speaker, + confidence=seg.confidence + )) + + # Clean up temp file + pathlib.Path(temp_path).unlink(missing_ok=True) + + return { + "segment_start": segment.start, + "segment_end": segment.end, + "text": result.text, + "segments": [ + { + "start": seg.start, + "end": seg.end, + "text": seg.text, + "speaker": seg.speaker, + "confidence": seg.confidence + } for seg in adjusted_segments + ], + "language_detected": result.language, + "model_used": result.model_used + } + + except Exception as e: + raise AudioProcessingError(f"Segment processing failed: {str(e)}") + + async def process_complete_audio( + self, + audio_path: str, + model_name: str = "turbo", + language: Optional[str] = None, + enable_speaker_diarization: bool = False, + min_segment_length: float = 30.0 + ) -> Dict[str, Any]: + """ + Process complete audio file with intelligent segmentation + """ + try: + print(f"🚀 Starting complete audio processing: {audio_path}") + + # Get audio metadata + metadata = ffmpeg.probe(audio_path) + total_duration = float(metadata["format"]["duration"]) + + # Split audio into segments + segments = await self.split_audio_by_silence( + audio_path=audio_path, + min_segment_length=min_segment_length, + min_silence_length=1.0 + ) + + # Process segments in parallel (with limited concurrency) + semaphore = asyncio.Semaphore(3) # Limit concurrent processing + + async def process_segment_with_semaphore(segment): + async with semaphore: + return await self.process_audio_segment( + segment=segment, + model_name=model_name, + language=language, + enable_speaker_diarization=enable_speaker_diarization + ) + + # Process all segments + segment_results = await asyncio.gather(*[ + process_segment_with_semaphore(segment) for segment in segments + ]) + + # Combine 
results + all_segments = [] + combined_text = [] + + for result in segment_results: + all_segments.extend(result["segments"]) + if result["text"].strip(): + combined_text.append(result["text"].strip()) + + # Apply speaker identification if enabled + if enable_speaker_diarization and self.speaker_service: + try: + speaker_segments = await self.speaker_service.identify_speakers_in_audio( + audio_path=audio_path, + transcription_segments=all_segments + ) + + # Map transcription to speakers + all_segments = await self.speaker_service.map_transcription_to_speakers( + transcription_segments=all_segments, + speaker_segments=speaker_segments + ) + except Exception as e: + print(f"⚠️ Speaker identification failed: {e}") + + return { + "text": " ".join(combined_text), + "segments": all_segments, + "audio_duration": total_duration, + "segment_count": len(all_segments), + "processing_segments": len(segments), + "language_detected": segment_results[0]["language_detected"] if segment_results else "unknown", + "model_used": model_name, + "speaker_diarization_enabled": enable_speaker_diarization, + "processing_status": "success" + } + + except Exception as e: + raise AudioProcessingError(f"Complete audio processing failed: {str(e)}") + + def get_supported_models(self) -> List[str]: + """Get supported transcription models""" + return self.transcriber.get_supported_models() + + def get_supported_languages(self) -> List[str]: + """Get supported languages""" + return self.transcriber.get_supported_languages() \ No newline at end of file diff --git a/src/services/distributed_transcription_service.py b/src/services/distributed_transcription_service.py new file mode 100644 index 0000000000000000000000000000000000000000..df6ad706f6331aafd6a3c913f9f8cee0d4131897 --- /dev/null +++ b/src/services/distributed_transcription_service.py @@ -0,0 +1,993 @@ +""" +Distributed Transcription Service +Handles audio transcription with true distributed processing across multiple Modal containers +Enhanced with intelligent audio segmentation capabilities +""" + +import asyncio +import aiohttp +import base64 +import os +import tempfile +import subprocess +import json +from pathlib import Path +from typing import Dict, Any, List, Tuple +from concurrent.futures import ThreadPoolExecutor +import time +import re + +import ffmpeg +import torch + +from .transcription_service import TranscriptionService + + +class DistributedTranscriptionService: + """Service for handling distributed audio transcription across multiple Modal containers""" + + def __init__(self, cache_dir: str = "/tmp"): + self.cache_dir = cache_dir + self.transcription_service = TranscriptionService(cache_dir) + + def split_audio_by_time(self, audio_file_path: str, chunk_duration: int = 60) -> List[Dict[str, Any]]: + """Split audio into time-based chunks""" + try: + # Get audio duration using ffprobe + duration_cmd = [ + "ffprobe", "-v", "quiet", "-show_entries", "format=duration", + "-of", "csv=p=0", audio_file_path + ] + result = subprocess.run(duration_cmd, capture_output=True, text=True, check=True) + total_duration = float(result.stdout.strip()) + + chunks = [] + start_time = 0.0 + chunk_index = 0 + + while start_time < total_duration: + end_time = min(start_time + chunk_duration, total_duration) + actual_duration = end_time - start_time + + # Skip very short chunks (less than 5 seconds) + if actual_duration < 5.0: + break + + chunk_filename = f"chunk_{chunk_index:03d}.wav" + chunks.append({ + "chunk_index": chunk_index, + "start_time": start_time, + "end_time": 
end_time, + "duration": actual_duration, + "filename": chunk_filename + }) + + start_time = end_time + chunk_index += 1 + + print(f"📊 Split audio into {len(chunks)} time-based chunks") + return chunks + + except Exception as e: + print(f"❌ Error splitting audio by time: {e}") + return [] + + def split_audio_by_silence( + self, + audio_file_path: str, + min_segment_length: float = 30.0, + min_silence_length: float = 1.0, + max_segment_length: float = 120.0 + ) -> List[Dict[str, Any]]: + """ + Intelligently split audio using FFmpeg's silencedetect filter + Enhanced from AudioProcessingService + """ + try: + silence_end_re = re.compile( + r" silence_end: (?P<end>[0-9]+(\.?[0-9]*)) \| silence_duration: (?P<dur>[0-9]+(\.?[0-9]*))" + ) + + # Get audio duration + metadata = ffmpeg.probe(audio_file_path) + total_duration = float(metadata["format"]["duration"]) + + print(f"🎵 Audio duration: {total_duration:.2f}s") + print(f"🔍 Detecting silence with min_silence_length={min_silence_length}s...") + + # Use silence detection filter + cmd = [ + "ffmpeg", "-i", audio_file_path, + "-af", f"silencedetect=noise=-30dB:duration={min_silence_length}", + "-f", "null", "-" + ] + + process = subprocess.Popen( + cmd, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True + ) + + segments = [] + cur_start = 0.0 + chunk_index = 0 + + # Process silence detection output + for line in process.stderr: + match = silence_end_re.search(line) + if match: + silence_end = float(match.group("end")) + silence_dur = float(match.group("dur")) + split_at = silence_end - (silence_dur / 2) + + segment_duration = split_at - cur_start + + # Skip segments that are too short + if segment_duration < min_segment_length: + continue + + # Split long segments + if segment_duration > max_segment_length: + # Split into multiple smaller segments + sub_start = cur_start + while sub_start < split_at: + sub_end = min(sub_start + max_segment_length, split_at) + sub_duration = sub_end - sub_start + + if sub_duration >= min_segment_length: + segments.append({ + "chunk_index": chunk_index, + "start_time": sub_start, + "end_time": sub_end, + "duration": sub_duration, + "filename": f"silence_chunk_{chunk_index:03d}.wav", + "segmentation_type": "silence_based" + }) + chunk_index += 1 + + sub_start = sub_end + else: + segments.append({ + "chunk_index": chunk_index, + "start_time": cur_start, + "end_time": split_at, + "duration": segment_duration, + "filename": f"silence_chunk_{chunk_index:03d}.wav", + "segmentation_type": "silence_based" + }) + chunk_index += 1 + + cur_start = split_at + + process.wait() + + # Handle the last segment + if total_duration > cur_start: + remaining_duration = total_duration - cur_start + if remaining_duration >= min_segment_length: + segments.append({ + "chunk_index": chunk_index, + "start_time": cur_start, + "end_time": total_duration, + "duration": remaining_duration, + "filename": f"silence_chunk_{chunk_index:03d}.wav", + "segmentation_type": "silence_based" + }) + + print(f"🎯 Silence-based segmentation created {len(segments)} segments") + return segments + + except Exception as e: + print(f"⚠️ Silence-based segmentation failed: {e}") + # Fallback to time-based segmentation + print("📋 Falling back to time-based segmentation...") + return self.split_audio_by_time(audio_file_path, chunk_duration=60) + + def choose_segmentation_strategy( + self, + audio_file_path: str, + use_intelligent_segmentation: bool = True, + chunk_duration: int = 60 + ) -> List[Dict[str, Any]]: + """ + Choose the best segmentation strategy based on 
audio characteristics + """ + try: + # Get audio metadata + metadata = ffmpeg.probe(audio_file_path) + duration = float(metadata["format"]["duration"]) + + print(f"🎛️ Choosing segmentation strategy for {duration:.2f}s audio...") + + # For short audio (< 30s), use single processing + if duration < 30: + print("📝 Audio is short, using single chunk") + return [{ + "chunk_index": 0, + "start_time": 0.0, + "end_time": duration, + "duration": duration, + "filename": "single_chunk.wav", + "segmentation_type": "single" + }] + + # For longer audio, choose based on user preference + if use_intelligent_segmentation: + print("🧠 Using intelligent silence-based segmentation") + segments = self.split_audio_by_silence( + audio_file_path, + min_segment_length=30.0, + min_silence_length=1.0, + max_segment_length=120.0 + ) + + # NEW: Check if silence-based segmentation failed for long audio + if duration > 180 and len(segments) == 1: # Audio > 3 minutes with only 1 segment + print(f"⚠️ Silence-based segmentation created only 1 segment for {duration:.2f}s audio") + print("🔄 Falling back to 3-minute time-based segmentation for better processing efficiency") + return self.split_audio_by_time(audio_file_path, chunk_duration=180) # 3-minute chunks + + # If silence-based segmentation didn't work well, fallback to time-based + if len(segments) == 0 or len(segments) > duration / 20: # Too many tiny segments + print("🔄 Silence segmentation not optimal, using time-based") + return self.split_audio_by_time(audio_file_path, chunk_duration) + + return segments + else: + print("⏰ Using time-based segmentation") + return self.split_audio_by_time(audio_file_path, chunk_duration) + + except Exception as e: + print(f"❌ Error in segmentation strategy: {e}") + # Ultimate fallback + return self.split_audio_by_time(audio_file_path, chunk_duration) + + def split_audio_locally( + self, + audio_file_path: str, + chunk_duration: int = 60, + use_intelligent_segmentation: bool = True + ) -> List[Tuple[str, float, float]]: + """ + Split audio file into chunks locally for distributed processing using intelligent segmentation + + Args: + audio_file_path: Path to audio file + chunk_duration: Duration of each chunk in seconds + use_intelligent_segmentation: Whether to use intelligent silence-based segmentation + + Returns: + List of (chunk_file_path, start_time, end_time) tuples + """ + try: + # Choose segmentation strategy + segments = self.choose_segmentation_strategy( + audio_file_path, + use_intelligent_segmentation=use_intelligent_segmentation, + chunk_duration=chunk_duration + ) + + if not segments: + print("❌ No segments generated") + return [] + + print(f"🎵 Processing {len(segments)} segments using {segments[0].get('segmentation_type', 'time_based')} segmentation") + + # Create temporary directory for chunks + temp_dir = tempfile.mkdtemp(prefix="audio_chunks_") + chunks = [] + + for segment in segments: + start_time = segment["start_time"] + end_time = segment["end_time"] + duration = segment["duration"] + + # Create chunk file path + chunk_filename = f"chunk_{segment['chunk_index']:03d}_{start_time:.1f}s-{end_time:.1f}s.wav" + chunk_path = os.path.join(temp_dir, chunk_filename) + + # Extract chunk using ffmpeg-python (no subprocess) + try: + ( + ffmpeg + .input(audio_file_path, ss=start_time, t=duration) + .output( + chunk_path, + acodec='pcm_s16le', + ar=16000, + ac=1 + ) + .overwrite_output() + .run(quiet=True, capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + print(f"❌ FFmpeg error for chunk 
{segment['chunk_index']+1}: {e}") + print(f" stderr: {e.stderr.decode() if e.stderr else 'No stderr'}") + continue + + if os.path.exists(chunk_path) and os.path.getsize(chunk_path) > 0: + chunks.append((chunk_path, start_time, end_time)) + segmentation_type = segment.get('segmentation_type', 'time_based') + print(f"📦 Created {segmentation_type} chunk {segment['chunk_index']+1}: {start_time:.1f}s-{end_time:.1f}s") + else: + print(f"⚠️ Failed to create chunk {segment['chunk_index']+1}") + + return chunks + + except Exception as e: + print(f"❌ Error splitting audio: {e}") + return [] + + async def transcribe_chunk_distributed( + self, + chunk_path: str, + start_time: float, + end_time: float, + model_size: str = "turbo", + language: str = None, + enable_speaker_diarization: bool = False, + chunk_endpoint_url: str = None + ) -> Dict[str, Any]: + """ + Transcribe a single chunk using Modal distributed endpoint + + Args: + chunk_path: Path to audio chunk file + start_time: Start time of chunk in original audio + end_time: End time of chunk in original audio + model_size: Whisper model size + language: Language code + enable_speaker_diarization: Whether to enable speaker diarization + chunk_endpoint_url: URL of chunk transcription endpoint + + Returns: + Transcription result for the chunk + """ + try: + # Read and encode chunk file + with open(chunk_path, "rb") as f: + audio_data = f.read() + + audio_base64 = base64.b64encode(audio_data).decode('utf-8') + + # Prepare request data + request_data = { + "audio_file_data": audio_base64, + "audio_file_name": os.path.basename(chunk_path), + "model_size": model_size, + "language": language, + "output_format": "json", # Use JSON for easier merging + "enable_speaker_diarization": enable_speaker_diarization, + "chunk_start_time": start_time, + "chunk_end_time": end_time + } + + # Send request to Modal chunk endpoint with retry mechanism + max_retries = 3 + for attempt in range(max_retries): + try: + # Adjust timeout based on whether speaker diarization is enabled + if enable_speaker_diarization: + timeout_config = aiohttp.ClientTimeout( + total=720, # 12 minutes total for speaker diarization + connect=45, # 45 seconds connection timeout + sock_read=300 # 5 minutes read timeout for speaker processing + ) + else: + timeout_config = aiohttp.ClientTimeout( + total=480, # 8 minutes total for regular transcription + connect=30, # 30 seconds connection timeout + sock_read=120 # 2 minutes read timeout for regular processing + ) + + async with aiohttp.ClientSession(timeout=timeout_config) as session: + async with session.post( + chunk_endpoint_url, + json=request_data + ) as response: + if response.status == 200: + result = await response.json() + result["chunk_start_time"] = start_time + result["chunk_end_time"] = end_time + result["chunk_file"] = chunk_path + return result + else: + error_text = await response.text() + if attempt < max_retries - 1: + print(f"⚠️ HTTP {response.status} on attempt {attempt + 1}, retrying...") + await asyncio.sleep(2 ** attempt) # Exponential backoff + continue + else: + return { + "processing_status": "failed", + "error_message": f"HTTP {response.status} after {max_retries} attempts: {error_text}", + "chunk_start_time": start_time, + "chunk_end_time": end_time, + "chunk_file": chunk_path + } + + except (asyncio.TimeoutError, aiohttp.ClientError) as e: + if attempt < max_retries - 1: + print(f"⚠️ Network error on attempt {attempt + 1}: {e}, retrying...") + await asyncio.sleep(2 ** attempt) # Exponential backoff + continue + else: + 
return { + "processing_status": "failed", + "error_message": f"Network error after {max_retries} attempts: {e}", + "chunk_start_time": start_time, + "chunk_end_time": end_time, + "chunk_file": chunk_path + } + + except Exception as e: + return { + "processing_status": "failed", + "error_message": str(e), + "chunk_start_time": start_time, + "chunk_end_time": end_time, + "chunk_file": chunk_path + } + + async def merge_chunk_results( + self, + chunk_results: List[Dict[str, Any]], + output_format: str = "srt", + enable_speaker_diarization: bool = False, + audio_file_path: str = None + ) -> Dict[str, Any]: + """ + Merge transcription results from multiple chunks + + Args: + chunk_results: List of chunk transcription results + output_format: Output format (srt, txt, json) + enable_speaker_diarization: Whether speaker diarization was enabled + audio_file_path: Path to original audio file (needed for speaker embedding) + + Returns: + Merged transcription result + """ + try: + print(f"🔗 Starting merge_chunk_results: {len(chunk_results)} chunks to process") + + # Filter successful chunks + successful_chunks = [ + chunk for chunk in chunk_results + if chunk.get("processing_status") == "success" + ] + + failed_chunks = [ + chunk for chunk in chunk_results + if chunk.get("processing_status") != "success" + ] + + print(f"📊 Chunk processing results: {len(successful_chunks)} successful, {len(failed_chunks)} failed") + + if not successful_chunks: + print("❌ All chunks failed - returning failure result") + return { + "processing_status": "failed", + "error_message": "All chunks failed to process", + "chunks_processed": 0, + "chunks_failed": len(failed_chunks) + } + + # Sort chunks by start time + successful_chunks.sort(key=lambda x: x.get("chunk_start_time", 0)) + print(f"📈 Sorted {len(successful_chunks)} successful chunks by start time") + + # Apply speaker embedding unification if speaker diarization is enabled + speaker_mapping = {} + if enable_speaker_diarization and audio_file_path: + print(f"🎤 Speaker diarization enabled, attempting speaker unification...") + try: + from .speaker_embedding_service import SpeakerIdentificationService, SpeakerEmbeddingService + from ..utils.config import AudioProcessingConfig + + print(f"✅ Successfully imported speaker embedding services") + + # Initialize speaker services + embedding_manager = SpeakerEmbeddingService() + speaker_service = SpeakerIdentificationService(embedding_manager) + + print(f"✅ Speaker services initialized") + + # Unify speakers across chunks using embedding similarity + print("🎤 Unifying speakers across chunks using embedding similarity...") + speaker_mapping = await speaker_service.unify_distributed_speakers( + successful_chunks, audio_file_path + ) + + print(f"✅ Speaker unification returned mapping with {len(speaker_mapping)} entries") + + if speaker_mapping: + print(f"✅ Speaker unification completed: {len(set(speaker_mapping.values()))} unified speakers") + else: + print("⚠️ Speaker unification returned empty mapping") + + except Exception as e: + print(f"⚠️ Speaker unification failed: {e}") + print(f" Exception type: {type(e).__name__}") + import traceback + print(f" Traceback: {traceback.format_exc()}") + print("📋 Continuing with original speaker labels...") + speaker_mapping = {} + else: + if enable_speaker_diarization: + print("⚠️ Speaker diarization enabled but no audio_file_path provided") + if audio_file_path: + print("ℹ️ Audio file path provided but speaker diarization disabled") + + # Merge segments + all_segments = [] + 
total_duration = 0 + segment_count = 0 + + # First pass: collect all segments and mark missing speakers as UNKNOWN + print("📝 First pass: collecting segments and marking unknown speakers...") + for chunk_idx, chunk in enumerate(successful_chunks): + chunk_start = chunk.get("chunk_start_time", 0) + chunk_segments = chunk.get("segments", []) + + for segment in chunk_segments: + # Adjust segment timestamps to global timeline + adjusted_segment = segment.copy() + adjusted_segment["start"] = segment["start"] + chunk_start + adjusted_segment["end"] = segment["end"] + chunk_start + + # Mark segments without speaker as UNKNOWN + if "speaker" not in segment or not segment["speaker"]: + adjusted_segment["speaker"] = "UNKNOWN" + adjusted_segment["chunk_id"] = chunk_idx + else: + # Preserve original speaker for embedding-based reassignment + adjusted_segment["original_speaker"] = segment["speaker"] + adjusted_segment["chunk_id"] = chunk_idx + # Temporarily use chunk-local speaker ID for embedding processing + adjusted_segment["speaker"] = f"chunk_{chunk_idx}_{segment['speaker']}" + + all_segments.append(adjusted_segment) + + segment_count += len(chunk_segments) + chunk_duration = chunk.get("audio_duration", 0) + if chunk_duration > 0: + total_duration = max(total_duration, chunk_start + chunk_duration) + + print(f"📊 Collected {len(all_segments)} segments from {len(successful_chunks)} chunks") + + # Second pass: Apply embedding-based speaker unification if enabled + final_speaker_mapping = {} + if enable_speaker_diarization and audio_file_path and speaker_mapping: + print("🎤 Second pass: applying embedding-based speaker unification...") + + # Create final speaker mapping based on embedding results + for mapping_key, unified_speaker_id in speaker_mapping.items(): + final_speaker_mapping[mapping_key] = unified_speaker_id + + # Apply the unified speaker mapping to segments + for segment in all_segments: + if segment["speaker"] != "UNKNOWN": + chunk_id = segment["chunk_id"] + original_speaker = segment.get("original_speaker", "") + mapping_key = f"chunk_{chunk_id}_{original_speaker}" + + if mapping_key in final_speaker_mapping: + segment["speaker"] = final_speaker_mapping[mapping_key] + print(f"🎯 Mapped chunk_{chunk_id}_{original_speaker} -> {segment['speaker']}") + else: + # Fallback: create a new speaker ID if not found in mapping + segment["speaker"] = f"SPEAKER_UNMATCHED_{chunk_id}_{original_speaker}" + print(f"⚠️ No mapping found for {mapping_key}, using fallback ID") + + print(f"✅ Applied speaker unification to segments") + else: + print("ℹ️ Speaker diarization disabled or no speaker mapping available") + # For segments with speakers but no diarization, use chunk-local naming + for segment in all_segments: + if segment["speaker"] != "UNKNOWN" and segment["speaker"].startswith("chunk_"): + chunk_id = segment["chunk_id"] + original_speaker = segment.get("original_speaker", "") + segment["speaker"] = f"SPEAKER_CHUNK_{chunk_id}_{original_speaker}" + + # Third pass: Filter and generate output files + print("📄 Third pass: generating output files...") + + # Separate segments by speaker type + known_speaker_segments = [seg for seg in all_segments if seg["speaker"] != "UNKNOWN"] + unknown_speaker_segments = [seg for seg in all_segments if seg["speaker"] == "UNKNOWN"] + + print(f"📊 Segment distribution:") + print(f" Known speakers: {len(known_speaker_segments)} segments") + print(f" Unknown speakers: {len(unknown_speaker_segments)} segments (will be filtered)") + + # Generate output files (excluding UNKNOWN 
speakers) + output_files = self._generate_output_files( + known_speaker_segments, # Only include segments with known speakers + output_format, + enable_speaker_diarization + ) + + # Collect speaker information based on filtered segments + speaker_info = self._collect_speaker_information_from_segments( + known_speaker_segments, enable_speaker_diarization + ) + + # Determine language (use most common language from chunks) + languages = [chunk.get("language_detected", "unknown") for chunk in successful_chunks] + most_common_language = max(set(languages), key=languages.count) if languages else "unknown" + + # Combine text from known speaker segments only + full_text = " ".join([seg.get("text", "").strip() for seg in known_speaker_segments if seg.get("text", "").strip()]) + + print(f"🔗 merge_chunk_results completion summary:") + print(f" Total segments collected: {len(all_segments)}") + print(f" Known speaker segments: {len(known_speaker_segments)}") + print(f" Unknown speaker segments filtered: {len(unknown_speaker_segments)}") + print(f" Final text length: {len(full_text)} characters") + print(f" Language detected: {most_common_language}") + print(f" Distributed processing flag: True") + + return { + "processing_status": "success", + "txt_file_path": output_files.get("txt_file_path"), + "srt_file_path": output_files.get("srt_file_path"), + "audio_duration": total_duration, + "segment_count": len(known_speaker_segments), # Count only known speaker segments + "total_segments_collected": len(all_segments), # Total including UNKNOWN + "unknown_segments_filtered": len(unknown_speaker_segments), # UNKNOWN segments count + "language_detected": most_common_language, + "model_used": successful_chunks[0].get("model_used", "turbo") if successful_chunks else "turbo", + "distributed_processing": True, + "chunks_processed": len(successful_chunks), + "chunks_failed": len(failed_chunks), + "speaker_diarization_enabled": enable_speaker_diarization, + "speaker_embedding_unified": len(speaker_mapping) > 0 if speaker_mapping else False, + "text": full_text, # Add full text for client-side file saving (filtered) + "segments": known_speaker_segments, # Add segments for client-side file saving (filtered) + **speaker_info + } + + except Exception as e: + print(f"❌ Error in merge_chunk_results: {e}") + print(f" Exception type: {type(e).__name__}") + import traceback + print(f" Traceback: {traceback.format_exc()}") + return { + "processing_status": "failed", + "error_message": f"Error merging chunk results: {e}", + "chunks_processed": len(successful_chunks) if 'successful_chunks' in locals() else 0, + "chunks_failed": len(failed_chunks) if 'failed_chunks' in locals() else len(chunk_results) + } + + def _generate_output_files( + self, + segments: List[Dict], + output_format: str, + enable_speaker_diarization: bool + ) -> Dict[str, str]: + """Generate output files from merged segments (excluding UNKNOWN speakers)""" + try: + # Create output directory + output_dir = Path(self.cache_dir) / "transcribe" + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate timestamp for unique filenames + timestamp = int(time.time()) + base_filename = f"distributed_transcription_{timestamp}" + + output_files = {} + + # Filter segments: only include segments with actual text content + valid_segments = [] + for segment in segments: + text = segment.get("text", "").strip() + speaker = segment.get("speaker", "UNKNOWN") + + # Skip segments with no text or UNKNOWN speaker + if text and speaker != "UNKNOWN": + 
valid_segments.append(segment) + + print(f"📝 Generating output files with {len(valid_segments)} valid segments (filtered from {len(segments)} total)") + + # Generate TXT file + txt_path = output_dir / f"{base_filename}.txt" + with open(txt_path, "w", encoding="utf-8") as f: + for segment in valid_segments: + text = segment.get("text", "").strip() + if enable_speaker_diarization and "speaker" in segment: + f.write(f"[{segment['speaker']}] {text}\n") + else: + f.write(f"{text}\n") + output_files["txt_file_path"] = str(txt_path) + + # Generate SRT file if requested + if output_format in ["srt", "both"]: + srt_path = output_dir / f"{base_filename}.srt" + with open(srt_path, "w", encoding="utf-8") as f: + srt_index = 1 + for segment in valid_segments: + start_time = self._format_srt_time(segment.get("start", 0)) + end_time = self._format_srt_time(segment.get("end", 0)) + text = segment.get("text", "").strip() + + if enable_speaker_diarization and "speaker" in segment: + text = f"[{segment['speaker']}] {text}" + + f.write(f"{srt_index}\n") + f.write(f"{start_time} --> {end_time}\n") + f.write(f"{text}\n\n") + srt_index += 1 + + output_files["srt_file_path"] = str(srt_path) + + print(f"✅ Generated output files: {list(output_files.keys())}") + return output_files + + except Exception as e: + print(f"❌ Error generating output files: {e}") + return {} + + def _format_srt_time(self, seconds: float) -> str: + """Format seconds to SRT time format""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + millisecs = int((seconds % 1) * 1000) + return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}" + + def _collect_speaker_information_from_segments( + self, + segments: List[Dict], + enable_speaker_diarization: bool + ) -> Dict[str, Any]: + """Collect and merge speaker information from segments""" + if not enable_speaker_diarization: + return {} + + try: + # Collect all speakers from segments + all_speakers = set() + speaker_summary = {} + + for segment in segments: + speaker = segment.get("speaker", "UNKNOWN") + if speaker != "UNKNOWN": + all_speakers.add(speaker) + + if speaker not in speaker_summary: + speaker_summary[speaker] = { + "total_duration": 0, + "segment_count": 0 + } + + # Calculate segment duration from start and end times + segment_duration = segment.get("end", 0) - segment.get("start", 0) + speaker_summary[speaker]["total_duration"] += segment_duration + speaker_summary[speaker]["segment_count"] += 1 + + return { + "global_speaker_count": len(all_speakers), + "speakers_detected": list(all_speakers), + "speaker_summary": speaker_summary + } + + except Exception as e: + print(f"⚠️ Error collecting speaker information: {e}") + print(f" Segment data types: {[type(seg.get('duration', 0)) for seg in segments]}") + return { + "global_speaker_count": 0, + "speakers_detected": [], + "speaker_summary": {} + } + + async def transcribe_audio_distributed( + self, + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False, + chunk_duration: int = 60, + use_intelligent_segmentation: bool = True, + chunk_endpoint_url: str = None + ) -> Dict[str, Any]: + """ + Transcribe audio using distributed processing across multiple Modal containers + + Args: + audio_file_path: Path to audio file + model_size: Whisper model size + language: Language code + output_format: Output format + enable_speaker_diarization: Whether to enable speaker diarization + chunk_duration: 
Duration of each chunk in seconds + use_intelligent_segmentation: Whether to use intelligent segmentation + chunk_endpoint_url: URL of chunk transcription endpoint + + Returns: + Transcription result dictionary + """ + temp_files = [] + + try: + print(f"🚀 Starting distributed transcription for: {audio_file_path}") + print(f"🚀 Using model: {model_size}") + print(f"⚡ Chunk duration: {chunk_duration}s") + + # Step 1: Split audio locally into chunks + chunks = self.split_audio_locally( + audio_file_path, + chunk_duration, + use_intelligent_segmentation + ) + + if not chunks: + return { + "processing_status": "failed", + "error_message": "Failed to split audio into chunks" + } + + temp_files.extend([chunk[0] for chunk in chunks]) + + # Step 2: Process all chunks concurrently (no batching) + print(f"🔄 Processing {len(chunks)} chunks concurrently across multiple containers...") + + # Set default chunk endpoint URL if not provided + if not chunk_endpoint_url: + chunk_endpoint_url = "https://richardsucran--transcribe-audio-chunk-endpoint.modal.run" + + # Create all tasks simultaneously for maximum concurrency + all_tasks = [] + for chunk_idx, (chunk_path, start_time, end_time) in enumerate(chunks): + # Create a coroutine first + coro = self.transcribe_chunk_distributed( + chunk_path=chunk_path, + start_time=start_time, + end_time=end_time, + model_size=model_size, + language=language, + enable_speaker_diarization=enable_speaker_diarization, + chunk_endpoint_url=chunk_endpoint_url + ) + # Convert coroutine to Task explicitly + task = asyncio.create_task(coro) + all_tasks.append((chunk_idx, task)) + + print(f"📤 Launched {len(all_tasks)} concurrent transcription tasks") + + # Process results as they complete (optimal resource utilization) + chunk_results = [None] * len(chunks) # Pre-allocate results array + completed_count = 0 + failed_count = 0 + + # Set timeout based on speaker diarization + total_timeout = 1800 if enable_speaker_diarization else 1200 # 30min vs 20min total + print(f"⏰ Total processing timeout: {total_timeout//60} minutes") + + try: + # Use asyncio.wait with return_when=FIRST_COMPLETED for real-time progress + pending_tasks = {task: chunk_idx for chunk_idx, task in all_tasks} + + start_time = asyncio.get_event_loop().time() + + while pending_tasks: + # Check for timeout + elapsed = asyncio.get_event_loop().time() - start_time + if elapsed > total_timeout: + print(f"⏰ Total timeout reached ({total_timeout//60} minutes), cancelling remaining tasks...") + for task in pending_tasks.keys(): + task.cancel() + break + + # Wait for at least one task to complete + remaining_timeout = total_timeout - elapsed + done, pending = await asyncio.wait( + pending_tasks.keys(), + return_when=asyncio.FIRST_COMPLETED, + timeout=min(60, remaining_timeout) # Check every minute + ) + + # Process completed tasks + for task in done: + chunk_idx = pending_tasks.pop(task) + try: + result = await task + chunk_results[chunk_idx] = result + + if result.get("processing_status") == "success": + completed_count += 1 + print(f"✅ Chunk {chunk_idx + 1}/{len(chunks)} completed successfully") + else: + failed_count += 1 + error_msg = result.get("error_message", "Unknown error") + print(f"❌ Chunk {chunk_idx + 1}/{len(chunks)} failed: {error_msg}") + + except Exception as e: + failed_count += 1 + chunk_results[chunk_idx] = { + "processing_status": "failed", + "error_message": str(e), + "chunk_start_time": chunks[chunk_idx][1], + "chunk_end_time": chunks[chunk_idx][2], + "chunk_file": chunks[chunk_idx][0] + } + print(f"❌ 
Chunk {chunk_idx + 1}/{len(chunks)} exception: {e}") + + # Show progress + total_processed = completed_count + failed_count + if total_processed > 0: + print(f"📊 Progress: {total_processed}/{len(chunks)} chunks processed " + f"({completed_count} ✅, {failed_count} ❌)") + + # Handle any remaining cancelled tasks + for task, chunk_idx in pending_tasks.items(): + if chunk_results[chunk_idx] is None: + chunk_results[chunk_idx] = { + "processing_status": "failed", + "error_message": "Task cancelled due to timeout", + "chunk_start_time": chunks[chunk_idx][1], + "chunk_end_time": chunks[chunk_idx][2], + "chunk_file": chunks[chunk_idx][0] + } + failed_count += 1 + + except Exception as e: + print(f"❌ Error during concurrent processing: {e}") + # Fill in any missing results + for i, result in enumerate(chunk_results): + if result is None: + chunk_results[i] = { + "processing_status": "failed", + "error_message": f"Processing error: {e}", + "chunk_start_time": chunks[i][1], + "chunk_end_time": chunks[i][2], + "chunk_file": chunks[i][0] + } + + print(f"🏁 Concurrent processing completed: {completed_count} successful, {failed_count} failed") + + # Step 3: Merge results from all chunks + print("🔗 Merging results from all chunks...") + final_result = await self.merge_chunk_results( + chunk_results, + output_format, + enable_speaker_diarization, + audio_file_path + ) + + print(f"✅ Distributed transcription completed successfully") + print(f" Chunks processed: {final_result.get('chunks_processed', 0)}") + print(f" Chunks failed: {final_result.get('chunks_failed', 0)}") + print(f" Total segments: {final_result.get('segment_count', 0)}") + print(f" Duration: {final_result.get('audio_duration', 0):.2f}s") + + return final_result + + except Exception as e: + return { + "processing_status": "failed", + "error_message": f"Distributed transcription failed: {e}", + "chunks_processed": 0, + "chunks_failed": len(chunks) if 'chunks' in locals() else 0 + } + + finally: + # Clean up temporary files + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + except Exception as e: + print(f"⚠️ Failed to clean up temp file {temp_file}: {e}") + + # Clean up temporary directories + for chunk_path, _, _ in chunks if 'chunks' in locals() else []: + try: + temp_dir = os.path.dirname(chunk_path) + if temp_dir.startswith("/tmp/audio_chunks_"): + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + except Exception as e: + print(f"⚠️ Failed to clean up temp directory: {e}") \ No newline at end of file diff --git a/src/services/file_management_service.py b/src/services/file_management_service.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0965932ddad8ef8b83ee90a3ddde89bea325fc --- /dev/null +++ b/src/services/file_management_service.py @@ -0,0 +1,452 @@ +""" +File Management Service - handles audio file and text file operations +""" + +import os +import asyncio +from pathlib import Path +from typing import Dict, Any, List, Optional +from datetime import datetime + +from ..utils.errors import FileProcessingError + + +class FileManagementService: + """Service for managing audio files and text files""" + + def __init__(self, base_directory: str = "."): + self.base_directory = Path(base_directory) + + # ==================== MP3/Audio File Management ==================== + + async def scan_mp3_files(self, directory: str) -> Dict[str, Any]: + """ + Scan directory for MP3 files and return detailed information + """ + try: + scan_path = Path(directory) + if not 
scan_path.exists(): + raise FileProcessingError(f"Directory does not exist: {directory}") + + if not scan_path.is_dir(): + raise FileProcessingError(f"Path is not a directory: {directory}") + + mp3_files = [] + + # Scan for MP3 files + for file_path in scan_path.rglob("*.mp3"): + try: + stat = file_path.stat() + file_info = { + "filename": file_path.name, + "full_path": str(file_path.absolute()), + "file_size": stat.st_size, + "file_size_mb": round(stat.st_size / (1024 * 1024), 2), + "created_time": datetime.fromtimestamp(stat.st_ctime).strftime("%Y-%m-%d %H:%M:%S"), + "modified_time": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S") + } + mp3_files.append(file_info) + except Exception as e: + print(f"⚠️ Error processing file {file_path}: {e}") + continue + + # Sort by modification time (newest first) + mp3_files.sort(key=lambda x: x["modified_time"], reverse=True) + + return { + "total_files": len(mp3_files), + "scanned_directory": str(scan_path.absolute()), + "file_list": mp3_files + } + + except Exception as e: + return { + "total_files": 0, + "scanned_directory": directory, + "file_list": [], + "error_message": str(e) + } + + async def get_file_info(self, file_path: str) -> Dict[str, Any]: + """ + Get detailed information about a specific file + """ + try: + path = Path(file_path) + + if not path.exists(): + return { + "status": "failed", + "file_path": file_path, + "file_exists": False, + "error_message": "File does not exist" + } + + stat = path.stat() + + return { + "status": "success", + "file_path": file_path, + "file_exists": True, + "filename": path.name, + "file_size": stat.st_size, + "file_size_mb": round(stat.st_size / (1024 * 1024), 2), + "created_time": stat.st_ctime, + "modified_time": stat.st_mtime, + "is_file": path.is_file(), + "file_extension": path.suffix + } + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "file_exists": False, + "error_message": str(e) + } + + async def organize_audio_files( + self, + source_directory: str, + target_directory: str = None, + organize_by: str = "date" # "date", "size", "name" + ) -> Dict[str, Any]: + """ + Organize audio files in a directory structure + """ + try: + source_path = Path(source_directory) + target_path = Path(target_directory) if target_directory else source_path / "organized" + + if not source_path.exists(): + raise FileProcessingError(f"Source directory does not exist: {source_directory}") + + # Scan for audio files + audio_extensions = {".mp3", ".wav", ".m4a", ".aac", ".ogg", ".flac"} + audio_files = [] + + for ext in audio_extensions: + audio_files.extend(source_path.rglob(f"*{ext}")) + + organized_count = 0 + + for audio_file in audio_files: + try: + # Determine target subdirectory based on organization method + if organize_by == "date": + stat = audio_file.stat() + date_folder = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m") + target_subdir = target_path / date_folder + elif organize_by == "size": + stat = audio_file.stat() + size_mb = stat.st_size / (1024 * 1024) + if size_mb < 10: + size_folder = "small" + elif size_mb < 100: + size_folder = "medium" + else: + size_folder = "large" + target_subdir = target_path / size_folder + else: # organize by name + first_letter = audio_file.name[0].upper() + target_subdir = target_path / first_letter + + # Create target directory + target_subdir.mkdir(parents=True, exist_ok=True) + + # Move file + target_file = target_subdir / audio_file.name + if not target_file.exists(): + audio_file.rename(target_file) + 
organized_count += 1 + + except Exception as e: + print(f"⚠️ Error organizing file {audio_file}: {e}") + continue + + return { + "status": "success", + "total_files_found": len(audio_files), + "files_organized": organized_count, + "target_directory": str(target_path), + "organization_method": organize_by + } + + except Exception as e: + return { + "status": "failed", + "error_message": str(e) + } + + # ==================== Text File Management ==================== + + async def read_text_file_segments( + self, + file_path: str, + chunk_size: int = 65536, # 64KB + start_position: int = 0 + ) -> Dict[str, Any]: + """ + Read text file content in segments with intelligent boundary detection + """ + try: + path = Path(file_path) + + if not path.exists(): + return { + "status": "failed", + "file_path": file_path, + "error_message": "File does not exist" + } + + file_size = path.stat().st_size + + if start_position >= file_size: + return { + "status": "success", + "file_path": file_path, + "content": "", + "current_position": file_size, + "file_size": file_size, + "end_of_file_reached": True, + "bytes_read": 0, + "content_length": 0, + "progress_percentage": 100.0, + "actual_boundary": "end_of_file" + } + + with open(path, 'r', encoding='utf-8') as f: + f.seek(start_position) + + # Read the chunk + raw_content = f.read(chunk_size) + + if not raw_content: + return { + "status": "success", + "file_path": file_path, + "content": "", + "current_position": file_size, + "file_size": file_size, + "end_of_file_reached": True, + "bytes_read": 0, + "content_length": 0, + "progress_percentage": 100.0, + "actual_boundary": "end_of_file" + } + + # Find intelligent boundary + boundary_type = "chunk_boundary" + actual_content = raw_content + bytes_read = len(raw_content.encode('utf-8')) + + if len(raw_content) == chunk_size: + # Look for newline boundary + last_newline = raw_content.rfind('\n') + if last_newline > chunk_size * 0.5: # At least half the chunk + actual_content = raw_content[:last_newline + 1] + boundary_type = "newline_boundary" + else: + # Look for space boundary + last_space = raw_content.rfind(' ') + if last_space > chunk_size * 0.7: # At least 70% of chunk + actual_content = raw_content[:last_space + 1] + boundary_type = "space_boundary" + + # Calculate actual position + actual_bytes_read = len(actual_content.encode('utf-8')) + current_position = start_position + actual_bytes_read + + # Check if end of file reached + end_of_file_reached = current_position >= file_size + + return { + "status": "success", + "file_path": file_path, + "content": actual_content, + "current_position": current_position, + "file_size": file_size, + "end_of_file_reached": end_of_file_reached, + "bytes_read": actual_bytes_read, + "content_length": len(actual_content), + "progress_percentage": round((current_position / file_size) * 100, 2), + "actual_boundary": boundary_type + } + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "error_message": str(e) + } + + async def read_complete_text_file(self, file_path: str) -> Dict[str, Any]: + """ + Read complete text file content (use with caution for large files) + """ + try: + path = Path(file_path) + + if not path.exists(): + return { + "status": "failed", + "file_path": file_path, + "error_message": "File does not exist" + } + + # Check file size + file_size = path.stat().st_size + if file_size > 10 * 1024 * 1024: # 10MB + return { + "status": "failed", + "file_path": file_path, + "error_message": "File too large (>10MB). 
Use read_text_file_segments instead." + } + + with open(path, 'r', encoding='utf-8') as f: + content = f.read() + + return { + "status": "success", + "file_path": file_path, + "content": content, + "file_size": file_size, + "content_length": len(content) + } + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "error_message": str(e) + } + + async def write_text_file( + self, + file_path: str, + content: str, + mode: str = "w", # "w" for write, "a" for append + encoding: str = "utf-8" + ) -> Dict[str, Any]: + """ + Write content to text file + """ + try: + path = Path(file_path) + + # Create directory if it doesn't exist + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, mode, encoding=encoding) as f: + f.write(content) + + # Get file info after writing + stat = path.stat() + + return { + "status": "success", + "file_path": file_path, + "content_length": len(content), + "file_size": stat.st_size, + "mode": mode, + "encoding": encoding + } + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "error_message": str(e) + } + + async def convert_text_format( + self, + input_file: str, + output_file: str, + input_format: str = "txt", + output_format: str = "srt" + ) -> Dict[str, Any]: + """ + Convert text files between different formats (e.g., txt to srt) + """ + try: + # Read input file + read_result = await self.read_complete_text_file(input_file) + if read_result["status"] != "success": + return read_result + + content = read_result["content"] + + # Convert based on formats + if input_format == "txt" and output_format == "srt": + converted_content = self._convert_txt_to_srt(content) + elif input_format == "srt" and output_format == "txt": + converted_content = self._convert_srt_to_txt(content) + else: + return { + "status": "failed", + "error_message": f"Conversion from {input_format} to {output_format} not supported" + } + + # Write output file + write_result = await self.write_text_file(output_file, converted_content) + + if write_result["status"] == "success": + write_result.update({ + "input_file": input_file, + "input_format": input_format, + "output_format": output_format, + "conversion": "success" + }) + + return write_result + + except Exception as e: + return { + "status": "failed", + "error_message": str(e) + } + + def _convert_txt_to_srt(self, content: str) -> str: + """Convert plain text to SRT format (basic implementation)""" + lines = content.strip().split('\n') + srt_content = [] + + for i, line in enumerate(lines, 1): + if line.strip(): + # Create basic timestamps (assuming 3 seconds per line) + start_time = (i - 1) * 3 + end_time = i * 3 + + start_srt = self._seconds_to_srt_time(start_time) + end_srt = self._seconds_to_srt_time(end_time) + + srt_content.extend([ + str(i), + f"{start_srt} --> {end_srt}", + line.strip(), + "" + ]) + + return '\n'.join(srt_content) + + def _convert_srt_to_txt(self, content: str) -> str: + """Convert SRT to plain text""" + lines = content.strip().split('\n') + text_lines = [] + + for line in lines: + # Skip sequence numbers and timestamps + if line.strip() and not line.strip().isdigit() and '-->' not in line: + text_lines.append(line.strip()) + + return '\n'.join(text_lines) + + def _seconds_to_srt_time(self, seconds: float) -> str: + """Convert seconds to SRT time format (HH:MM:SS,mmm)""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + millisecs = int((seconds % 1) * 1000) + + return 
f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}" \ No newline at end of file diff --git a/src/services/file_service.py b/src/services/file_service.py new file mode 100644 index 0000000000000000000000000000000000000000..83e828b225868534b266f5dbbb2989259ec032a9 --- /dev/null +++ b/src/services/file_service.py @@ -0,0 +1,42 @@ +""" +File service for handling file operations +""" + +import aiofiles +from pathlib import Path +from typing import Optional + + +class FileService: + """Service for file operations""" + + async def write_text_file(self, file_path: str, content: str, encoding: str = "utf-8"): + """Write text content to file asynchronously""" + async with aiofiles.open(file_path, 'w', encoding=encoding) as f: + await f.write(content) + + async def read_text_file(self, file_path: str, encoding: str = "utf-8") -> str: + """Read text content from file asynchronously""" + async with aiofiles.open(file_path, 'r', encoding=encoding) as f: + return await f.read() + + def ensure_directory(self, directory_path: str): + """Ensure directory exists""" + Path(directory_path).mkdir(parents=True, exist_ok=True) + + def file_exists(self, file_path: str) -> bool: + """Check if file exists""" + return Path(file_path).exists() + + def get_file_size(self, file_path: str) -> int: + """Get file size in bytes""" + return Path(file_path).stat().st_size + + def get_file_extension(self, file_path: str) -> str: + """Get file extension""" + return Path(file_path).suffix.lower() + + def is_audio_file(self, file_path: str) -> bool: + """Check if file is an audio file""" + audio_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg', '.wma'} + return self.get_file_extension(file_path) in audio_extensions \ No newline at end of file diff --git a/src/services/health_service.py b/src/services/health_service.py new file mode 100644 index 0000000000000000000000000000000000000000..69b71f4a744e257964eceef2b8a77b5d380c7879 --- /dev/null +++ b/src/services/health_service.py @@ -0,0 +1,214 @@ +""" +Health Service +Provides health check functionality for the transcription service +""" + +import os +import whisper +from pathlib import Path +from typing import Dict, Any + + +class HealthService: + """Service for health checks and status monitoring""" + + def get_health_status(self) -> Dict[str, Any]: + """Get comprehensive health status of the service""" + + # Check Whisper models + whisper_status = self._check_whisper_models() + + # Check speaker diarization + speaker_status = self._check_speaker_diarization() + + # Overall health + overall_health = "healthy" if ( + whisper_status["status"] == "healthy" and + speaker_status["status"] in ["healthy", "partial"] # Speaker diarization is optional + ) else "unhealthy" + + return { + "status": overall_health, + "timestamp": self._get_current_timestamp(), + "whisper": whisper_status, + "speaker_diarization": speaker_status, + "version": "1.0.0" + } + + def _check_whisper_models(self) -> Dict[str, Any]: + """Check Whisper model availability""" + try: + # Check available models + available_models = whisper.available_models() + + # Check if turbo model is available + default_model = "turbo" + + # Check model cache directory + model_cache_dir = "/model" + cache_exists = os.path.exists(model_cache_dir) + + # Try to load the default model + try: + if cache_exists: + model = whisper.load_model(default_model, download_root=model_cache_dir) + model_loaded = True + load_source = "cache" + else: + model = whisper.load_model(default_model) + model_loaded = True + load_source = 
"download" + except Exception as e: + model_loaded = False + load_source = f"failed: {e}" + + return { + "status": "healthy" if model_loaded else "unhealthy", + "default_model": default_model, + "available_models": available_models, + "model_cache_exists": cache_exists, + "model_cache_directory": model_cache_dir if cache_exists else None, + "model_loaded": model_loaded, + "load_source": load_source, + "whisper_version": getattr(whisper, '__version__', 'unknown') + } + + except Exception as e: + return { + "status": "unhealthy", + "error": str(e), + "default_model": "turbo", + "available_models": [], + "model_cache_exists": False, + "model_loaded": False + } + + def _check_speaker_diarization(self) -> Dict[str, Any]: + """Check speaker diarization functionality""" + try: + # Check if HF token is available + hf_token = os.environ.get("HF_TOKEN") + hf_token_available = hf_token is not None + + # Check speaker model cache + speaker_cache_dir = "/model/speaker-diarization" + cache_exists = os.path.exists(speaker_cache_dir) + + # Check config file + config_file = os.path.join(speaker_cache_dir, "config.json") + config_exists = os.path.exists(config_file) + + # Try to load speaker diarization pipeline + pipeline_loaded = False + pipeline_error = None + + if hf_token_available: + try: + from pyannote.audio import Pipeline + + # Try to load pipeline + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=hf_token + ) + pipeline_loaded = True + + except Exception as e: + pipeline_error = str(e) + else: + pipeline_error = "HF_TOKEN not available" + + # Determine status + if pipeline_loaded: + status = "healthy" + elif hf_token_available: + status = "partial" # Token available but pipeline failed + else: + status = "disabled" # No token, feature disabled + + return { + "status": status, + "hf_token_available": hf_token_available, + "speaker_cache_exists": cache_exists, + "speaker_cache_directory": speaker_cache_dir if cache_exists else None, + "config_exists": config_exists, + "pipeline_loaded": pipeline_loaded, + "pipeline_error": pipeline_error, + "model_name": "pyannote/speaker-diarization-3.1" + } + + except Exception as e: + return { + "status": "unhealthy", + "error": str(e), + "hf_token_available": False, + "speaker_cache_exists": False, + "pipeline_loaded": False + } + + def test_speaker_diarization(self, test_audio_path: str = None) -> Dict[str, Any]: + """Test speaker diarization functionality with actual audio""" + try: + # Check if HF token is available + hf_token = os.environ.get("HF_TOKEN") + if not hf_token: + return { + "status": "skipped", + "reason": "HF_TOKEN not available" + } + + # Load speaker diarization pipeline + from pyannote.audio import Pipeline + + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=hf_token + ) + + # If no test audio provided, return pipeline load success + if not test_audio_path: + return { + "status": "pipeline_loaded", + "message": "Speaker diarization pipeline loaded successfully" + } + + # Test with actual audio file + if not os.path.exists(test_audio_path): + return { + "status": "failed", + "reason": f"Test audio file not found: {test_audio_path}" + } + + # Run speaker diarization + diarization_result = pipeline(test_audio_path) + + # Process results + speakers = set() + segments_count = 0 + total_speech_duration = 0 + + for turn, _, speaker in diarization_result.itertracks(yield_label=True): + speakers.add(speaker) + segments_count += 1 + total_speech_duration += 
turn.end - turn.start + + return { + "status": "success", + "speakers_detected": len(speakers), + "segments_count": segments_count, + "total_speech_duration": total_speech_duration, + "test_audio_path": test_audio_path, + "speakers": list(speakers) + } + + except Exception as e: + return { + "status": "failed", + "error": str(e), + "test_audio_path": test_audio_path + } + + def _get_current_timestamp(self) -> str: + """Get current timestamp in ISO format""" + from datetime import datetime + return datetime.utcnow().isoformat() + "Z" \ No newline at end of file diff --git a/src/services/modal_transcription_service.py b/src/services/modal_transcription_service.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e09f54bb49463d360b974a9d956662088e70c0 --- /dev/null +++ b/src/services/modal_transcription_service.py @@ -0,0 +1,545 @@ +""" +Modal Transcription Service - handles transcription via Modal endpoints +Enhanced to replace transcription_tools.py functions with proper service architecture +""" + +import asyncio +import aiohttp +import base64 +import os +from typing import Dict, Any +from pathlib import Path + + +class ModalTranscriptionService: + """Service for audio transcription via Modal endpoints""" + + def __init__(self, endpoint_urls: Dict[str, str] = None, cache_dir: str = None, use_direct_modal_calls: bool = True): + """ + Initialize Modal transcription service + + Args: + endpoint_urls: Dictionary of endpoint URLs (used when use_direct_modal_calls=False) + cache_dir: Cache directory path + use_direct_modal_calls: Whether to use direct Modal function calls or HTTP endpoints + """ + self.use_direct_modal_calls = use_direct_modal_calls + self.endpoint_urls = endpoint_urls or { + "transcribe_audio": "https://richardsucran--transcribe-audio-endpoint.modal.run", + "transcribe_chunk": "https://richardsucran--transcribe-audio-chunk-endpoint.modal.run", + "health_check": "https://richardsucran--health-check-endpoint.modal.run" + } + self.cache_dir = cache_dir or "/tmp" + + # Determine if we're running in Modal environment + if self.use_direct_modal_calls: + print("✅ Using direct function calls (no HTTP endpoints)") + + async def transcribe_audio_file( + self, + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False, + use_parallel_processing: bool = True, + chunk_duration: int = 60, + use_intelligent_segmentation: bool = True + ) -> Dict[str, Any]: + """ + Transcribe audio file using Modal endpoints with intelligent processing + + Args: + audio_file_path: Path to audio file + model_size: Whisper model size + language: Language code (None for auto-detect) + output_format: Output format (srt, txt, json) + enable_speaker_diarization: Whether to enable speaker diarization + use_parallel_processing: Whether to use distributed processing + chunk_duration: Duration of chunks for parallel processing + use_intelligent_segmentation: Whether to use intelligent segmentation + + Returns: + Transcription result dictionary + """ + try: + # Validate input file + if not os.path.exists(audio_file_path): + return { + "processing_status": "failed", + "error_message": f"Audio file not found: {audio_file_path}" + } + + # Read and encode audio file + with open(audio_file_path, "rb") as f: + audio_data = f.read() + + audio_base64 = base64.b64encode(audio_data).decode('utf-8') + + # Prepare request data + request_data = { + "audio_file_data": audio_base64, + "audio_file_name": 
os.path.basename(audio_file_path), + "model_size": model_size, + "language": language, + "output_format": output_format, + "enable_speaker_diarization": enable_speaker_diarization, + "use_parallel_processing": use_parallel_processing, + "chunk_duration": chunk_duration, + "use_intelligent_segmentation": use_intelligent_segmentation + } + + endpoint_url = self.endpoint_urls["transcribe_audio"] + + print(f"🎤 Starting transcription via Modal {'function call' if self.use_direct_modal_calls else 'endpoint'}...") + print(f" File: {audio_file_path}") + print(f" Size: {len(audio_data) / (1024*1024):.2f} MB") + print(f" Model: {model_size}") + print(f" Parallel processing: {use_parallel_processing}") + print(f" Intelligent segmentation: {use_intelligent_segmentation}") + print(f" Speaker diarization: {enable_speaker_diarization}") + + # Choose between direct function call or HTTP endpoint + if self.use_direct_modal_calls: + # Direct function call (when running inside Modal environment) + try: + # Call the process_transcription_request method directly + result = await self.process_transcription_request(request_data) + except Exception as e: + print(f"⚠️ Direct Modal call failed, falling back to HTTP: {e}") + self.use_direct_modal_calls = False + # Fall through to HTTP endpoint call + else: + print(f"✅ Transcription completed successfully via direct function call") + self._log_transcription_results(result, enable_speaker_diarization) + return result + + if not self.use_direct_modal_calls: + # HTTP endpoint call (fallback) + endpoint_url = self.endpoint_urls["transcribe_audio"] + async with aiohttp.ClientSession() as session: + async with session.post( + endpoint_url, + json=request_data, + timeout=aiohttp.ClientTimeout(total=3600) # 1 hour timeout + ) as response: + if response.status == 200: + result = await response.json() + print(f"✅ Transcription completed successfully via HTTP endpoint") + self._log_transcription_results(result, enable_speaker_diarization) + return result + else: + error_text = await response.text() + return { + "processing_status": "failed", + "error_message": f"HTTP {response.status}: {error_text}" + } + + except Exception as e: + return { + "processing_status": "failed", + "error_message": f"Transcription request failed: {e}" + } + + async def transcribe_chunk( + self, + chunk_path: str, + start_time: float, + end_time: float, + model_size: str = "turbo", + language: str = None, + enable_speaker_diarization: bool = False + ) -> Dict[str, Any]: + """ + Transcribe a single audio chunk using Modal chunk endpoint + + Args: + chunk_path: Path to audio chunk file + start_time: Start time of chunk in original audio + end_time: End time of chunk in original audio + model_size: Whisper model size + language: Language code + enable_speaker_diarization: Whether to enable speaker diarization + + Returns: + Transcription result for the chunk + """ + try: + # Read and encode chunk file + with open(chunk_path, "rb") as f: + audio_data = f.read() + + audio_base64 = base64.b64encode(audio_data).decode('utf-8') + + # Prepare request data + request_data = { + "audio_file_data": audio_base64, + "audio_file_name": os.path.basename(chunk_path), + "model_size": model_size, + "language": language, + "output_format": "json", # Use JSON for easier merging + "enable_speaker_diarization": enable_speaker_diarization, + "chunk_start_time": start_time, + "chunk_end_time": end_time + } + + # Choose between direct function call or HTTP endpoint + if self.use_direct_modal_calls: + # Direct function call + 
try: + result = self.process_chunk_request(request_data) + result["chunk_start_time"] = start_time + result["chunk_end_time"] = end_time + result["chunk_file"] = chunk_path + return result + except Exception as e: + print(f"⚠️ Direct chunk call failed, falling back to HTTP: {e}") + self.use_direct_modal_calls = False + # Fall through to HTTP endpoint call + + if not self.use_direct_modal_calls: + # HTTP endpoint call (fallback) + endpoint_url = self.endpoint_urls["transcribe_chunk"] + # Configure timeout with more granular controls + # Adjust timeout based on speaker diarization + if enable_speaker_diarization: + timeout_config = aiohttp.ClientTimeout( + total=720, # 12 minutes total for speaker diarization + connect=45, # 45 seconds connection timeout + sock_read=300 # 5 minutes read timeout for speaker processing + ) + else: + timeout_config = aiohttp.ClientTimeout( + total=480, # 8 minutes total for regular transcription + connect=30, # 30 seconds connection timeout + sock_read=120 # 2 minutes read timeout for regular processing + ) + + async with aiohttp.ClientSession(timeout=timeout_config) as session: + async with session.post( + endpoint_url, + json=request_data + ) as response: + if response.status == 200: + result = await response.json() + result["chunk_start_time"] = start_time + result["chunk_end_time"] = end_time + result["chunk_file"] = chunk_path + return result + else: + error_text = await response.text() + return { + "processing_status": "failed", + "error_message": f"HTTP {response.status}: {error_text}", + "chunk_start_time": start_time, + "chunk_end_time": end_time, + "chunk_file": chunk_path + } + + except Exception as e: + return { + "processing_status": "failed", + "error_message": str(e), + "chunk_start_time": start_time, + "chunk_end_time": end_time, + "chunk_file": chunk_path + } + + async def check_endpoints_health(self) -> Dict[str, Any]: + """ + Check the health status of all Modal endpoints + + Returns: + Health status dictionary for all endpoints + """ + health_status = {} + + async with aiohttp.ClientSession() as session: + for endpoint_name, endpoint_url in self.endpoint_urls.items(): + try: + if endpoint_name == "health_check": + # Health check endpoint supports GET + async with session.get( + endpoint_url, + timeout=aiohttp.ClientTimeout(total=30) + ) as response: + if response.status == 200: + response_data = await response.json() + health_status[endpoint_name] = { + "status": "healthy", + "response": response_data, + "url": endpoint_url + } + else: + health_status[endpoint_name] = { + "status": "unhealthy", + "error": f"HTTP {response.status}", + "url": endpoint_url + } + else: + # Other endpoints are POST-only, just check if they're accessible + async with session.get( + endpoint_url, + timeout=aiohttp.ClientTimeout(total=10) + ) as response: + # 405 Method Not Allowed is expected for POST-only endpoints + if response.status == 405: + health_status[endpoint_name] = { + "status": "healthy", + "response": "Endpoint accessible (POST-only)", + "url": endpoint_url + } + else: + health_status[endpoint_name] = { + "status": "unknown", + "response": f"HTTP {response.status}", + "url": endpoint_url + } + + except Exception as e: + health_status[endpoint_name] = { + "status": "error", + "error": str(e), + "url": endpoint_url + } + + return health_status + + async def get_system_status(self) -> Dict[str, Any]: + """ + Get comprehensive system status including health checks + + Returns: + System status dictionary + """ + try: + endpoint_url = 
self.endpoint_urls["health_check"] + + async with aiohttp.ClientSession() as session: + async with session.get( + endpoint_url, + timeout=aiohttp.ClientTimeout(total=30) + ) as response: + if response.status == 200: + return await response.json() + else: + error_text = await response.text() + return { + "status": "failed", + "error_message": f"HTTP {response.status}: {error_text}" + } + + except Exception as e: + return { + "status": "failed", + "error_message": f"Health check failed: {e}" + } + + def get_endpoint_url(self, endpoint_name: str) -> str: + """ + Get URL for specific endpoint + + Args: + endpoint_name: Name of the endpoint + + Returns: + Endpoint URL + """ + return self.endpoint_urls.get(endpoint_name, f"https://richardsucran--{endpoint_name}.modal.run") + + # ==================== Modal Server-Side Methods ==================== + # These methods are used by Modal endpoints running on the server + + async def process_transcription_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Process transcription request on Modal server + This method runs on the Modal server, not the client + """ + try: + # Import services that are available on Modal server + import sys + import tempfile + import base64 + from pathlib import Path + + # Import local services (available on Modal server) + from src.services.distributed_transcription_service import DistributedTranscriptionService + from src.services.transcription_service import TranscriptionService + + # Extract request parameters + audio_file_data = request_data.get("audio_file_data") + audio_file_name = request_data.get("audio_file_name", "audio.mp3") + model_size = request_data.get("model_size", "turbo") + language = request_data.get("language") + output_format = request_data.get("output_format", "srt") + enable_speaker_diarization = request_data.get("enable_speaker_diarization", False) + use_parallel_processing = request_data.get("use_parallel_processing", True) + chunk_duration = request_data.get("chunk_duration", 60) + use_intelligent_segmentation = request_data.get("use_intelligent_segmentation", True) + + if not audio_file_data: + return { + "processing_status": "failed", + "error_message": "No audio data provided" + } + + # Decode audio data and save to temporary file + audio_bytes = base64.b64decode(audio_file_data) + temp_dir = Path(self.cache_dir) + temp_dir.mkdir(exist_ok=True) + + temp_audio_path = temp_dir / audio_file_name + with open(temp_audio_path, "wb") as f: + f.write(audio_bytes) + + print(f"🎤 Processing audio on Modal server: {audio_file_name}") + print(f" Size: {len(audio_bytes) / (1024*1024):.2f} MB") + print(f" Model: {model_size}") + print(f" Parallel processing: {use_parallel_processing}") + print(f" Intelligent segmentation: {use_intelligent_segmentation}") + + # Choose processing strategy based on file size and settings + file_size_mb = len(audio_bytes) / (1024 * 1024) + + if use_parallel_processing and file_size_mb > 10: # Use distributed for files > 10MB + print("🔄 Using distributed transcription service") + service = DistributedTranscriptionService() + + result = await service.transcribe_audio_distributed( + audio_file_path=str(temp_audio_path), + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization, + chunk_duration=chunk_duration, + use_intelligent_segmentation=use_intelligent_segmentation + ) + else: + print("🎯 Using single transcription service") + service = TranscriptionService() + + result = 
service.transcribe_audio( + audio_file_path=str(temp_audio_path), + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization + ) + + # Clean up temporary file + try: + temp_audio_path.unlink() + except: + pass + + print(f"✅ Transcription completed on Modal server") + return result + + except Exception as e: + print(f"❌ Error processing transcription request: {e}") + return { + "processing_status": "failed", + "error_message": f"Server processing error: {str(e)}" + } + + def process_chunk_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Process chunk transcription request on Modal server + This method runs on the Modal server, not the client + """ + try: + # Import services that are available on Modal server + import base64 + import tempfile + from pathlib import Path + + # Import local services (available on Modal server) + from src.services.transcription_service import TranscriptionService + + # Extract request parameters + audio_file_data = request_data.get("audio_file_data") + audio_file_name = request_data.get("audio_file_name", "chunk.mp3") + model_size = request_data.get("model_size", "turbo") + language = request_data.get("language") + enable_speaker_diarization = request_data.get("enable_speaker_diarization", False) + chunk_start_time = request_data.get("chunk_start_time", 0) + chunk_end_time = request_data.get("chunk_end_time", 0) + + if not audio_file_data: + return { + "processing_status": "failed", + "error_message": "No audio data provided", + "chunk_start_time": chunk_start_time, + "chunk_end_time": chunk_end_time + } + + # Decode audio data and save to temporary file + audio_bytes = base64.b64decode(audio_file_data) + temp_dir = Path(self.cache_dir) + temp_dir.mkdir(exist_ok=True) + + temp_audio_path = temp_dir / audio_file_name + with open(temp_audio_path, "wb") as f: + f.write(audio_bytes) + + print(f"🎤 Processing chunk on Modal server: {audio_file_name}") + print(f" Time range: {chunk_start_time:.2f}s - {chunk_end_time:.2f}s") + print(f" Size: {len(audio_bytes) / (1024*1024):.2f} MB") + + # Use single transcription service for chunks + service = TranscriptionService() + + result = service.transcribe_audio( + audio_file_path=str(temp_audio_path), + model_size=model_size, + language=language, + output_format="json", # Always use JSON for chunks + enable_speaker_diarization=enable_speaker_diarization + ) + + # Add chunk timing information + if result.get("processing_status") == "success": + result["chunk_start_time"] = chunk_start_time + result["chunk_end_time"] = chunk_end_time + result["chunk_file"] = audio_file_name + + # Clean up temporary file + try: + temp_audio_path.unlink() + except: + pass + + print(f"✅ Chunk transcription completed on Modal server") + return result + + except Exception as e: + print(f"❌ Error processing chunk request: {e}") + return { + "processing_status": "failed", + "error_message": f"Server chunk processing error: {str(e)}", + "chunk_start_time": request_data.get("chunk_start_time", 0), + "chunk_end_time": request_data.get("chunk_end_time", 0) + } + + def _log_transcription_results(self, result: Dict[str, Any], enable_speaker_diarization: bool = False): + """ + Log transcription results in a consistent format + + Args: + result: Transcription result dictionary + enable_speaker_diarization: Whether speaker diarization was enabled + """ + print(f" Processing type: {'Distributed' if result.get('distributed_processing', False) else 'Single'}") + print(f" 
Segments: {result.get('segment_count', 0)}") + print(f" Duration: {result.get('audio_duration', 0):.2f}s") + print(f" Language: {result.get('language_detected', 'unknown')}") + + if result.get("distributed_processing", False): + print(f" Chunks processed: {result.get('chunks_processed', 0)}") + print(f" Chunks failed: {result.get('chunks_failed', 0)}") + segmentation_type = result.get("segmentation_type", "time_based") + print(f" Segmentation: {segmentation_type}") + + if enable_speaker_diarization: + speaker_count = result.get("global_speaker_count", 0) + print(f" Speakers detected: {speaker_count}") \ No newline at end of file diff --git a/src/services/podcast_download_service.py b/src/services/podcast_download_service.py new file mode 100644 index 0000000000000000000000000000000000000000..997d58ccf516d6109542bd65336f7a9bbd72a7aa --- /dev/null +++ b/src/services/podcast_download_service.py @@ -0,0 +1,526 @@ +""" +Podcast Download Service - unified download functionality for multiple platforms +""" + +import os +import re +import asyncio +import subprocess +import pathlib +from typing import Dict, Any, Optional +from urllib.parse import urlparse + +import requests +from bs4 import BeautifulSoup + +from ..interfaces.podcast_downloader import ( + IPodcastDownloader, + PodcastInfo, + DownloadResult, + PodcastPlatform +) +from ..utils.errors import FileProcessingError, ConfigurationError +from ..models.download import DownloadResponse + + +class PodcastDownloadService(IPodcastDownloader): + """Unified podcast download service supporting multiple platforms""" + + def __init__(self, default_output_folder: str = None): + # Use storage config if no folder specified + if default_output_folder is None: + from ..utils.storage_config import get_storage_config + storage_config = get_storage_config() + self.default_output_folder = str(storage_config.downloads_dir) + else: + self.default_output_folder = default_output_folder + + self.supported_platforms = { + PodcastPlatform.APPLE: self._handle_apple_podcast, + PodcastPlatform.XIAOYUZHOU: self._handle_xiaoyuzhou_podcast + } + + async def extract_podcast_info(self, url: str) -> PodcastInfo: + """Extract podcast information from URL""" + + platform = self._detect_platform(url) + + if platform == PodcastPlatform.APPLE: + return await self._extract_apple_info(url) + elif platform == PodcastPlatform.XIAOYUZHOU: + return await self._extract_xiaoyuzhou_info(url) + else: + raise ConfigurationError(f"Unsupported platform for URL: {url}") + + async def download_podcast( + self, + url: str, + output_folder: str = None, + convert_to_mp3: bool = False, + keep_original: bool = False + ) -> DownloadResult: + """Download podcast from URL""" + + output_folder = output_folder or self.default_output_folder + + try: + # Ensure output folder exists + pathlib.Path(output_folder).mkdir(parents=True, exist_ok=True) + + # Detect platform and use appropriate handler + platform = self._detect_platform(url) + handler = self.supported_platforms.get(platform) + + if not handler: + return DownloadResult( + success=False, + file_path=None, + podcast_info=None, + error_message=f"Unsupported platform for URL: {url}" + ) + + # Call platform-specific handler + result = await handler(url, output_folder, convert_to_mp3, keep_original) + return result + + except Exception as e: + return DownloadResult( + success=False, + file_path=None, + podcast_info=None, + error_message=f"Download failed: {str(e)}" + ) + + def get_supported_platforms(self) -> list[PodcastPlatform]: + """Get list of 
supported platforms""" + return list(self.supported_platforms.keys()) + + def can_handle_url(self, url: str) -> bool: + """Check if this downloader can handle the given URL""" + try: + platform = self._detect_platform(url) + return platform in self.supported_platforms + except: + return False + + def _detect_platform(self, url: str) -> PodcastPlatform: + """Detect platform from URL""" + + if "podcasts.apple.com" in url: + return PodcastPlatform.APPLE + elif "xiaoyuzhoufm.com" in url: + return PodcastPlatform.XIAOYUZHOU + else: + return PodcastPlatform.GENERIC + + async def _extract_apple_info(self, url: str) -> PodcastInfo: + """Extract Apple Podcast information""" + + loop = asyncio.get_event_loop() + response = await loop.run_in_executor(None, requests.get, url) + + if response.status_code != 200: + raise FileProcessingError(f"Failed to fetch podcast page: {response.status_code}") + + soup = BeautifulSoup(response.text, 'html.parser') + + # Find audio URL + audio_url = self._find_audio_url_in_html(response.text) + if not audio_url: + raise FileProcessingError("Unable to find podcast audio URL") + + # Find title + title = self._extract_apple_title(soup) + if not title: + raise FileProcessingError("Unable to find podcast title") + + # Extract episode ID + episode_id = self._extract_apple_episode_id(url) + + return PodcastInfo( + title=title, + audio_url=audio_url, + episode_id=episode_id, + platform=PodcastPlatform.APPLE + ) + + async def _extract_xiaoyuzhou_info(self, url: str) -> PodcastInfo: + """Extract XiaoYuZhou Podcast information""" + + loop = asyncio.get_event_loop() + + # Use similar extraction logic as original + # This would need to be implemented based on XiaoYuZhou's page structure + try: + response = await loop.run_in_executor(None, requests.get, url) + + if response.status_code != 200: + raise FileProcessingError(f"Failed to fetch XYZ podcast page: {response.status_code}") + + # Extract info from response (implementation depends on XYZ structure) + # For now, return a basic structure + episode_id = self._extract_xyz_episode_id(url) + + return PodcastInfo( + title=f"XYZ Episode {episode_id}", + audio_url="", # Would be extracted from page + episode_id=episode_id, + platform=PodcastPlatform.XIAOYUZHOU + ) + except Exception as e: + raise FileProcessingError(f"XiaoYuZhou info extraction failed: {str(e)}") + + async def _handle_apple_podcast( + self, + url: str, + output_folder: str, + convert_to_mp3: bool, + keep_original: bool + ) -> DownloadResult: + """Handle Apple Podcast download""" + + try: + # Extract podcast info + podcast_info = await self._extract_apple_info(url) + + # Generate output filename + output_file = self._generate_filename( + podcast_info.episode_id, + podcast_info.audio_url, + output_folder, + convert_to_mp3 + ) + + # Download audio file + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, + self._download_file, + podcast_info.audio_url, + output_file + ) + + # Convert to MP3 if requested + if convert_to_mp3: + output_file = await self._convert_to_mp3(output_file, keep_original) + + return DownloadResult( + success=True, + file_path=output_file, + podcast_info=podcast_info + ) + + except Exception as e: + return DownloadResult( + success=False, + file_path=None, + podcast_info=None, + error_message=str(e) + ) + + async def _handle_xiaoyuzhou_podcast( + self, + url: str, + output_folder: str, + convert_to_mp3: bool, + keep_original: bool + ) -> DownloadResult: + """Handle XiaoYuZhou Podcast download""" + + try: + # Use our internal 
implementation instead of importing from methods + audio_path, title = await self._download_xiaoyuzhou_episode(url, output_folder) + + # Extract episode ID + episode_id = self._extract_xyz_episode_id(url) + + podcast_info = PodcastInfo( + title=title or f"XYZ Episode {episode_id}", + audio_url="", # Not available from the function + episode_id=episode_id, + platform=PodcastPlatform.XIAOYUZHOU + ) + + return DownloadResult( + success=True, + file_path=audio_path, + podcast_info=podcast_info + ) + + except Exception as e: + return DownloadResult( + success=False, + file_path=None, + podcast_info=None, + error_message=str(e) + ) + + async def _download_xiaoyuzhou_episode(self, url: str, output_folder: str) -> tuple[str, str]: + """Download XiaoYuZhou episode using Selenium""" + + from selenium import webdriver + from selenium.webdriver.chrome.options import Options + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + from bs4 import BeautifulSoup + import requests + import json + import os + import time + + # Setup Chrome options + chrome_options = Options() + chrome_options.add_argument("--headless") + chrome_options.add_argument("--no-sandbox") + chrome_options.add_argument("--disable-dev-shm-usage") + chrome_options.add_argument("--disable-gpu") + chrome_options.add_argument("--window-size=1920,1080") + chrome_options.add_argument("--user-agent=Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36") + + # Initialize driver + driver = None + try: + driver = webdriver.Chrome(options=chrome_options) + driver.get(url) + + # Wait for page to load + WebDriverWait(driver, 10).until( + EC.presence_of_element_located((By.TAG_NAME, "body")) + ) + + # Wait a bit more for JavaScript to execute + time.sleep(3) + + # Get page source + page_source = driver.page_source + soup = BeautifulSoup(page_source, 'html.parser') + + # Extract episode title + title = "Unknown Episode" + title_selectors = [ + 'h1[data-v-]', + '.episode-title', + 'h1', + '.title' + ] + + for selector in title_selectors: + title_elem = soup.select_one(selector) + if title_elem and title_elem.text.strip(): + title = title_elem.text.strip() + break + + # Try to find audio URL in the page source or network requests + audio_url = None + + # Look for audio URLs in script tags + scripts = soup.find_all('script') + for script in scripts: + if script.string: + # Look for audio file URLs + import re + audio_matches = re.findall(r'https://[^\s"\']+\.mp3[^\s"\']*', script.string) + if audio_matches: + audio_url = audio_matches[0] + break + + if not audio_url: + # Alternative: try to find audio element or data attributes + audio_elements = soup.find_all(['audio', 'source']) + for audio_elem in audio_elements: + src = audio_elem.get('src') or audio_elem.get('data-src') + if src and ('.mp3' in src or '.m4a' in src): + audio_url = src + break + + if not audio_url: + raise Exception("Could not find audio URL in the page") + + # Extract episode ID for filename + episode_id = self._extract_xyz_episode_id(url) + filename = f"{episode_id}_xiaoyuzhou_episode.mp3" + output_path = os.path.join(output_folder, filename) + + # Ensure output directory exists + os.makedirs(output_folder, exist_ok=True) + + # Download the audio file + await self._download_file_async(audio_url, output_path) + + return output_path, title + + except Exception as e: + raise Exception(f"XiaoYuZhou 
download failed: {str(e)}") + + finally: + if driver: + driver.quit() + + async def _download_file_async(self, url: str, output_path: str) -> None: + """Download file from URL asynchronously""" + + import aiohttp + import asyncio + + headers = { + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' + } + + timeout = aiohttp.ClientTimeout(total=300) # 5 minutes timeout + + async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session: + async with session.get(url) as response: + response.raise_for_status() + + with open(output_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) + + def _find_audio_url_in_html(self, html: str) -> Optional[str]: + """Find audio URL in HTML content""" + + # Find all .mp3 and .m4a URLs + audio_urls = re.findall(r'https://[^\s^"]+(?:\.mp3|\.m4a)', html) + + if audio_urls: + pattern = r'=https?://[^\s^"]+(?:\.mp3|\.m4a)' + result = re.findall(pattern, audio_urls[-1]) + if result: + return result[-1][1:] + else: + return audio_urls[-1] + + return None + + def _extract_apple_title(self, soup: BeautifulSoup) -> Optional[str]: + """Extract title from Apple Podcast page""" + + title_selectors = [ + 'span.product-header__title', + 'h1.product-header__title', + '.product-header__title', + 'h1[data-testid="product-header-title"]', + '.headings__title', + 'h1.headings__title', + '.episode-title', + 'h1' + ] + + for selector in title_selectors: + title_elem = soup.select_one(selector) + if title_elem: + return title_elem.text.strip().replace('/', '-') + + # Fallback to page title + page_title = soup.find('title') + if page_title: + return page_title.text.strip().replace('/', '-').replace(' on Apple Podcasts', '') + + return None + + def _extract_apple_episode_id(self, url: str) -> str: + """Extract episode ID from Apple Podcast URL""" + + # Try to extract episode ID from URL + episode_match = re.search(r'[?&]i=(\d+)', url) + if episode_match: + return episode_match.group(1) + + # Try podcast ID + podcast_match = re.search(r'/id(\d+)', url) + if podcast_match: + return podcast_match.group(1) + + # Fallback to timestamp + import time + return str(int(time.time())) + + def _extract_xyz_episode_id(self, url: str) -> str: + """Extract episode ID from XiaoYuZhou URL""" + + episode_match = re.search(r'/episode/([^/?]+)', url) + if episode_match: + return episode_match.group(1) + + # Fallback + import time + return str(int(time.time())) + + def _generate_filename( + self, + episode_id: str, + audio_url: str, + output_folder: str, + convert_to_mp3: bool + ) -> str: + """Generate output filename""" + + if convert_to_mp3: + extension = ".mp3" + else: + # Extract extension from URL + parsed_url = urlparse(audio_url) + _, extension = os.path.splitext(parsed_url.path) + if not extension: + extension = ".mp3" + + filename = f"{episode_id}_episode_audio{extension}" + return os.path.join(output_folder, filename) + + def _download_file(self, url: str, output_path: str) -> None: + """Download file from URL""" + + with requests.get(url, stream=True) as response: + response.raise_for_status() + with open(output_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + async def _convert_to_mp3( + self, + input_file: str, + keep_original: bool = False + ) -> str: + """Convert audio file to MP3 format""" + + base_name = os.path.splitext(input_file)[0] + output_file = f"{base_name}.mp3" + + if input_file == output_file: + 
return input_file  # Already MP3
+
+        try:
+            cmd = [
+                'ffmpeg',
+                '-i', input_file,
+                '-codec:a', 'libmp3lame',
+                '-b:a', '128k',
+                '-y',
+                output_file
+            ]
+
+            loop = asyncio.get_event_loop()
+            # subprocess.run keyword options cannot be passed positionally through
+            # run_in_executor, so bind them in a small callable instead
+            result = await loop.run_in_executor(
+                None,
+                lambda: subprocess.run(cmd, capture_output=True, text=True)
+            )
+
+            if result.returncode == 0:
+                print(f"Successfully converted to: {output_file}")
+
+                if not keep_original:
+                    os.remove(input_file)
+                    print(f"Removed original file: {input_file}")
+
+                return output_file
+            else:
+                print(f"Error converting file: {result.stderr}")
+                return input_file
+
+        except Exception as e:
+            print(f"Error during conversion: {str(e)}")
+            return input_file
\ No newline at end of file
diff --git a/src/services/speaker_embedding_service.py b/src/services/speaker_embedding_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fa2c51ee17fd852f52e44fa458e5837404a1c88
--- /dev/null
+++ b/src/services/speaker_embedding_service.py
@@ -0,0 +1,616 @@
+"""
+Speaker Embedding Service - manages global speaker embeddings and identification
+
+This module provides advanced speaker identification and unification across distributed audio chunks
+using pyannote.audio embedding models and cosine similarity calculations.
+
+Key Features:
+1. Global Speaker Management: Maintains a persistent database of speaker embeddings
+2. Embedding Extraction: Uses pyannote.audio models to extract speaker embeddings from audio segments
+3. Speaker Unification: Identifies when speakers in different chunks are the same person
+4. Distributed Processing Support: Unifies speakers across multiple transcription chunks
+
+Usage in Modal Configuration:
+- Speaker diarization models are preloaded in modal_config.py download_models() function
+- Models include both diarization pipeline and embedding extraction models
+- GPU acceleration is used for optimal performance
+
+Usage in Distributed Transcription:
+- DistributedTranscriptionService.merge_chunk_results() calls speaker unification
+- Speaker embeddings are extracted for each speaker segment using inference.crop()
+- Cosine distance calculations determine if speakers are the same across chunks
+- Speaker IDs are unified to prevent duplicate speaker labeling
+
+Example workflow:
+1. Audio is split into chunks for distributed processing
+2. Each chunk performs speaker diarization independently (e.g., SPEAKER_00, SPEAKER_01)
+3. After all chunks complete, speaker embeddings are extracted for unification
+4. Cosine similarity comparison identifies matching speakers across chunks
+5. Local speaker IDs are mapped to global unified IDs (e.g., SPEAKER_GLOBAL_001)
+6. 
Final transcription uses consistent speaker labels throughout + +Technical Details: +- Uses pyannote/embedding model for feature extraction +- Cosine distance threshold of 0.3 for speaker matching (configurable) +- Supports both single-file and distributed transcription workflows +- Thread-safe speaker database operations +- Persistent storage in JSON format for speaker history +""" + +import asyncio +import json +import pickle +import threading +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional, List +from dataclasses import asdict + +import numpy as np +import torch +from scipy.spatial.distance import cosine + +from ..interfaces.speaker_manager import ( + ISpeakerEmbeddingManager, + ISpeakerIdentificationService, + SpeakerEmbedding, + SpeakerSegment +) +from ..utils.errors import SpeakerDiarizationError, ModelLoadError +from ..utils.config import AudioProcessingConfig + + +class SpeakerEmbeddingService(ISpeakerEmbeddingManager): + """Global speaker embedding management service""" + + def __init__( + self, + storage_path: str = "global_speakers.json", + similarity_threshold: float = 0.3 + ): + self.storage_path = Path(storage_path) + self.similarity_threshold = similarity_threshold + self.speakers: Dict[str, SpeakerEmbedding] = {} + self.speaker_counter = 0 + self.lock = threading.Lock() + self._loaded = False + + # Don't load speakers in __init__ to avoid async issues + # Loading will happen on first use via _ensure_loaded() + + async def _ensure_loaded(self) -> None: + """Ensure speakers are loaded (called on first use)""" + if not self._loaded: + await self.load_speakers() + self._loaded = True + + async def load_speakers(self) -> None: + """Load speaker data from storage file""" + + if not self.storage_path.exists(): + return + + try: + loop = asyncio.get_event_loop() + data = await loop.run_in_executor(None, self._read_speakers_file) + + self.speakers = { + speaker_id: SpeakerEmbedding( + speaker_id=speaker_data["speaker_id"], + embedding=np.array(speaker_data["embedding"]), + confidence=speaker_data["confidence"], + source_files=speaker_data["source_files"], + sample_count=speaker_data["sample_count"], + created_at=speaker_data["created_at"], + updated_at=speaker_data["updated_at"] + ) + for speaker_id, speaker_data in data.get("speakers", {}).items() + } + self.speaker_counter = data.get("speaker_counter", 0) + + print(f"✅ Loaded {len(self.speakers)} known speakers") + + except Exception as e: + print(f"⚠️ Failed to load speaker data: {e}") + self.speakers = {} + self.speaker_counter = 0 + + async def save_speakers(self) -> None: + """Save speaker data to storage file""" + + try: + data = { + "speakers": { + speaker_id: { + "speaker_id": speaker.speaker_id, + "embedding": speaker.embedding.tolist(), + "confidence": speaker.confidence, + "source_files": speaker.source_files, + "sample_count": speaker.sample_count, + "created_at": speaker.created_at, + "updated_at": speaker.updated_at + } + for speaker_id, speaker in self.speakers.items() + }, + "speaker_counter": self.speaker_counter, + "updated_at": datetime.now().isoformat() + } + + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._write_speakers_file, data) + + print(f"💾 Speaker data saved: {len(self.speakers)} speakers") + + except Exception as e: + print(f"❌ Failed to save speaker data: {e}") + + async def find_matching_speaker( + self, + embedding: np.ndarray, + source_file: str + ) -> Optional[str]: + """Find matching speaker from existing embeddings""" + + 
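# --- Editorial illustration (not part of the committed file) ---------------
# The matching rule applied in this method reduces to one cosine-distance
# comparison against the configured similarity_threshold (0.3 by default).
# A minimal self-contained sketch of that rule:

import numpy as np
from scipy.spatial.distance import cosine

def is_same_speaker(a: np.ndarray, b: np.ndarray, threshold: float = 0.3) -> bool:
    """True when the cosine distance between two embeddings is within the threshold."""
    return cosine(a.flatten(), b.flatten()) <= threshold

# Identical directions give distance 0.0 (match); orthogonal vectors give 1.0 (no match).
assert is_same_speaker(np.array([1.0, 0.0]), np.array([2.0, 0.0]))
assert not is_same_speaker(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
# ---------------------------------------------------------------------------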
await self._ensure_loaded() + + if not self.speakers: + return None + + best_match_id = None + best_similarity = float('inf') + + for speaker_id, speaker in self.speakers.items(): + # Calculate cosine distance + distance = cosine(embedding, speaker.embedding) + + if distance < best_similarity: + best_similarity = distance + best_match_id = speaker_id + + # Check if similarity threshold is met + if best_similarity <= self.similarity_threshold: + print(f"🎯 Found matching speaker: {best_match_id} (distance: {best_similarity:.3f})") + return best_match_id + + print(f"🆕 No matching speaker found (best distance: {best_similarity:.3f} > {self.similarity_threshold})") + return None + + async def add_or_update_speaker( + self, + embedding: np.ndarray, + source_file: str, + confidence: float = 1.0, + original_label: Optional[str] = None + ) -> str: + """Add new speaker or update existing speaker""" + + await self._ensure_loaded() + + with self.lock: + # Find matching speaker + matching_speaker_id = await self.find_matching_speaker(embedding, source_file) + + if matching_speaker_id: + # Update existing speaker + speaker = self.speakers[matching_speaker_id] + + # Update embedding vector using weighted average + weight = 1.0 / (speaker.sample_count + 1) + speaker.embedding = (speaker.embedding * (1 - weight) + embedding * weight) + + # Update other information + if source_file not in speaker.source_files: + speaker.source_files.append(source_file) + speaker.sample_count += 1 + speaker.confidence = max(speaker.confidence, confidence) + speaker.updated_at = datetime.now().isoformat() + + print(f"🔄 Updated speaker {matching_speaker_id}: {speaker.sample_count} samples") + return matching_speaker_id + + else: + # Create new speaker + self.speaker_counter += 1 + new_speaker_id = f"SPEAKER_GLOBAL_{self.speaker_counter:03d}" + + new_speaker = SpeakerEmbedding( + speaker_id=new_speaker_id, + embedding=embedding.copy(), + confidence=confidence, + source_files=[source_file], + sample_count=1, + created_at=datetime.now().isoformat(), + updated_at=datetime.now().isoformat() + ) + + self.speakers[new_speaker_id] = new_speaker + + print(f"🆕 Created new speaker {new_speaker_id}") + return new_speaker_id + + async def map_local_to_global_speakers( + self, + local_embeddings: Dict[str, np.ndarray], + source_file: str + ) -> Dict[str, str]: + """Map local speaker labels to global speaker IDs""" + + mapping = {} + + for local_label, embedding in local_embeddings.items(): + global_id = await self.add_or_update_speaker( + embedding=embedding, + source_file=source_file, + original_label=local_label + ) + mapping[local_label] = global_id + + # Save updated speaker data + await self.save_speakers() + + return mapping + + async def get_speaker_info(self, speaker_id: str) -> Optional[SpeakerEmbedding]: + """Get speaker information by ID""" + return self.speakers.get(speaker_id) + + async def get_all_speakers_summary(self) -> Dict[str, Any]: + """Get summary of all speakers""" + + return { + "total_speakers": len(self.speakers), + "speakers": { + speaker_id: { + "speaker_id": speaker.speaker_id, + "confidence": speaker.confidence, + "source_files_count": len(speaker.source_files), + "sample_count": speaker.sample_count, + "created_at": speaker.created_at, + "updated_at": speaker.updated_at + } + for speaker_id, speaker in self.speakers.items() + } + } + + def _read_speakers_file(self) -> Dict[str, Any]: + """Read speakers file synchronously""" + with open(self.storage_path, 'r', encoding='utf-8') as f: + return json.load(f) + + 
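# --- Editorial illustration (not part of the committed file) ---------------
# The weighted-average update in add_or_update_speaker() above is an
# incremental mean: after folding in N samples one at a time, the stored
# embedding equals the plain average of all N samples. A small
# self-contained check of that property:

import numpy as np

def running_mean_update(current: np.ndarray, new_sample: np.ndarray, sample_count: int) -> np.ndarray:
    """Fold one new embedding into the stored average (same weighting as above)."""
    weight = 1.0 / (sample_count + 1)
    return current * (1 - weight) + new_sample * weight

samples = [np.array([0.0, 3.0]), np.array([3.0, 0.0]), np.array([3.0, 3.0])]
avg, count = samples[0], 1
for sample in samples[1:]:
    avg = running_mean_update(avg, sample, count)
    count += 1
assert np.allclose(avg, np.mean(samples, axis=0))
# ---------------------------------------------------------------------------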
def _write_speakers_file(self, data: Dict[str, Any]) -> None: + """Write speakers file synchronously""" + # Atomic write + temp_path = self.storage_path.with_suffix('.tmp') + with open(temp_path, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + temp_path.replace(self.storage_path) + + +class SpeakerIdentificationService(ISpeakerIdentificationService): + """Speaker identification service using pyannote.audio""" + + def __init__( + self, + embedding_manager: ISpeakerEmbeddingManager, + config: Optional[AudioProcessingConfig] = None + ): + self.embedding_manager = embedding_manager + self.config = config or AudioProcessingConfig() + self.auth_token = None + self.pipeline = None + self.embedding_model = None + + # Check for HF token + import os + self.auth_token = os.environ.get(self.config.hf_token_env_var) + self.available = self.auth_token is not None + + if not self.available: + print("⚠️ No Hugging Face token found. Speaker identification will be disabled.") + + async def extract_speaker_embeddings( + self, + audio_path: str, + segments: List[SpeakerSegment] + ) -> Dict[str, np.ndarray]: + """Extract speaker embeddings from audio segments""" + + if not self.available: + raise SpeakerDiarizationError("Speaker identification not available - missing HF token") + + try: + # Load models if needed + if self.embedding_model is None: + await self._load_models() + + # Create inference object for embedding extraction + from pyannote.audio.core.inference import Inference + from pyannote.core import Segment + import torchaudio + + inference = Inference(self.embedding_model, window="whole") + + # Load audio file + waveform, sample_rate = torchaudio.load(audio_path) + + embeddings = {} + + # Extract embeddings for each unique speaker + for segment in segments: + if segment.speaker_id not in embeddings: + # Create audio segment for embedding extraction + audio_segment = Segment(segment.start, segment.end) + + # Extract embedding using inference.crop + embedding = inference.crop(waveform, audio_segment) + + # Convert to numpy array and store + if isinstance(embedding, torch.Tensor): + embedding_np = embedding.detach().cpu().numpy() + else: + embedding_np = embedding + + embeddings[segment.speaker_id] = embedding_np + print(f"🎯 Extracted embedding for {segment.speaker_id}: shape {embedding_np.shape}") + + return embeddings + + except Exception as e: + raise SpeakerDiarizationError(f"Embedding extraction failed: {str(e)}") + + async def identify_speakers_in_audio( + self, + audio_path: str, + transcription_segments: List[Dict[str, Any]] + ) -> List[SpeakerSegment]: + """Identify speakers in audio file""" + + if not self.available: + print("⚠️ Speaker identification skipped - not available") + return [] + + try: + # Load pipeline if needed + if self.pipeline is None: + await self._load_models() + + # Perform diarization + loop = asyncio.get_event_loop() + diarization = await loop.run_in_executor( + None, + self.pipeline, + audio_path + ) + + # Convert to speaker segments + speaker_segments = [] + + for turn, _, speaker in diarization.itertracks(yield_label=True): + speaker_id = f"SPEAKER_{speaker.split('_')[-1].zfill(2)}" + speaker_segments.append(SpeakerSegment( + start=turn.start, + end=turn.end, + speaker_id=speaker_id, + confidence=1.0 # pyannote doesn't provide confidence + )) + + return speaker_segments + + except Exception as e: + raise SpeakerDiarizationError(f"Speaker identification failed: {str(e)}") + + async def map_transcription_to_speakers( + self, + 
transcription_segments: List[Dict[str, Any]], + speaker_segments: List[SpeakerSegment] + ) -> List[Dict[str, Any]]: + """Map transcription segments to speaker information""" + + result_segments = [] + + for trans_seg in transcription_segments: + trans_start = trans_seg["start"] + trans_end = trans_seg["end"] + + # Find overlapping speaker segment + best_speaker = None + best_overlap = 0 + + for speaker_seg in speaker_segments: + # Calculate overlap + overlap_start = max(trans_start, speaker_seg.start) + overlap_end = min(trans_end, speaker_seg.end) + overlap = max(0, overlap_end - overlap_start) + + if overlap > best_overlap: + best_overlap = overlap + best_speaker = speaker_seg.speaker_id + + # Add speaker information to transcription segment + result_segment = trans_seg.copy() + result_segment["speaker"] = best_speaker + result_segments.append(result_segment) + + return result_segments + + async def unify_distributed_speakers( + self, + chunk_results: List[Dict[str, Any]], + audio_file_path: str + ) -> Dict[str, str]: + """ + Unify speaker identifications across distributed chunks using embedding similarity + + Args: + chunk_results: List of chunk transcription results with speaker information + audio_file_path: Path to the original audio file for embedding extraction + + Returns: + Mapping from local chunk speaker IDs to unified global speaker IDs + """ + if not self.available: + print("⚠️ Speaker unification skipped - embedding service not available") + return {} + + try: + # Load models if needed + if self.embedding_model is None: + await self._load_models() + + from pyannote.audio.core.inference import Inference + from pyannote.core import Segment + import torchaudio + from scipy.spatial.distance import cosine + + inference = Inference(self.embedding_model, window="whole") + waveform, sample_rate = torchaudio.load(audio_file_path) + + # Collect all speaker segments from chunks with their chunk context + all_speaker_segments = [] + + for chunk_idx, chunk in enumerate(chunk_results): + if chunk.get("processing_status") != "success": + continue + + chunk_start_time = chunk.get("chunk_start_time", 0) + segments = chunk.get("segments", []) + + for segment in segments: + if "speaker" in segment and segment["speaker"]: + # Create unique chunk-local speaker ID + chunk_speaker_id = f"chunk_{chunk_idx}_{segment['speaker']}" + + all_speaker_segments.append({ + "chunk_speaker_id": chunk_speaker_id, + "original_speaker_id": segment["speaker"], + "chunk_index": chunk_idx, + "start": segment["start"] + chunk_start_time, + "end": segment["end"] + chunk_start_time, + "text": segment.get("text", "") + }) + + if not all_speaker_segments: + return {} + + # Extract embeddings for each unique chunk speaker + speaker_embeddings = {} + + for seg in all_speaker_segments: + chunk_speaker_id = seg["chunk_speaker_id"] + + if chunk_speaker_id not in speaker_embeddings: + try: + # Create audio segment for embedding extraction + audio_segment = Segment(seg["start"], seg["end"]) + + # Extract embedding using inference.crop + embedding = inference.crop(waveform, audio_segment) + + # Convert to numpy array + if hasattr(embedding, 'detach'): + embedding_np = embedding.detach().cpu().numpy() + else: + embedding_np = embedding + + speaker_embeddings[chunk_speaker_id] = embedding_np + print(f"🎯 Extracted embedding for {chunk_speaker_id}: shape {embedding_np.shape}") + + except Exception as e: + print(f"⚠️ Failed to extract embedding for {chunk_speaker_id}: {e}") + continue + + # Perform speaker clustering based on 
embedding similarity + unified_mapping = {} + global_speaker_counter = 1 + similarity_threshold = 0.3 # Cosine distance threshold + + for chunk_speaker_id, embedding in speaker_embeddings.items(): + best_match_id = None + best_distance = float('inf') + + # Compare with existing unified speakers + for existing_id, mapped_global_id in unified_mapping.items(): + if existing_id != chunk_speaker_id and existing_id in speaker_embeddings: + existing_embedding = speaker_embeddings[existing_id] + + try: + # Calculate cosine distance + distance = cosine(embedding.flatten(), existing_embedding.flatten()) + + if distance < best_distance: + best_distance = distance + best_match_id = mapped_global_id + except Exception as e: + print(f"⚠️ Error calculating distance: {e}") + continue + + # Assign speaker ID based on similarity + if best_match_id and best_distance <= similarity_threshold: + unified_mapping[chunk_speaker_id] = best_match_id + print(f"🎯 Unified {chunk_speaker_id} -> {best_match_id} (distance: {best_distance:.3f})") + else: + # Create new unified speaker ID + new_global_id = f"SPEAKER_GLOBAL_{global_speaker_counter:03d}" + unified_mapping[chunk_speaker_id] = new_global_id + global_speaker_counter += 1 + print(f"🆕 New speaker {chunk_speaker_id} -> {new_global_id}") + + # Create final mapping from original speaker IDs to global IDs + final_mapping = {} + for seg in all_speaker_segments: + chunk_speaker_id = seg["chunk_speaker_id"] + original_id = seg["original_speaker_id"] + + if chunk_speaker_id in unified_mapping: + # Create a key that includes chunk context for uniqueness + mapping_key = f"chunk_{seg['chunk_index']}_{original_id}" + final_mapping[mapping_key] = unified_mapping[chunk_speaker_id] + + print(f"🎤 Speaker unification completed: {len(set(unified_mapping.values()))} global speakers from {len(speaker_embeddings)} chunk speakers") + return final_mapping + + except Exception as e: + print(f"❌ Speaker unification failed: {e}") + return {} + + async def _load_models(self) -> None: + """Load pyannote.audio models""" + + try: + # Suppress warnings + import warnings + warnings.filterwarnings("ignore", category=UserWarning, module="pyannote") + warnings.filterwarnings("ignore", category=UserWarning, module="pytorch_lightning") + warnings.filterwarnings("ignore", category=FutureWarning, module="pytorch_lightning") + + from pyannote.audio import Model, Pipeline + from pyannote.audio.core.inference import Inference + import torch + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load embedding model + loop = asyncio.get_event_loop() + + self.embedding_model = await loop.run_in_executor( + None, + Model.from_pretrained, + "pyannote/embedding", + self.auth_token + ) + self.embedding_model.to(device) + self.embedding_model.eval() + + # Load diarization pipeline + self.pipeline = await loop.run_in_executor( + None, + Pipeline.from_pretrained, + "pyannote/speaker-diarization-3.1", + self.auth_token + ) + self.pipeline.to(device) + + print("✅ Speaker identification models loaded") + + except Exception as e: + raise ModelLoadError(f"Failed to load speaker models: {str(e)}") \ No newline at end of file diff --git a/src/services/transcription_service.py b/src/services/transcription_service.py new file mode 100644 index 0000000000000000000000000000000000000000..507b55dad6068f72e49cd9bee7bf722df3f14efa --- /dev/null +++ b/src/services/transcription_service.py @@ -0,0 +1,680 @@ +""" +Transcription Service +Handles audio transcription logic with support for parallel processing 
+""" + +import whisper +import os +import json +import tempfile +import subprocess +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Any, List + + +class TranscriptionService: + """Service for handling audio transcription""" + + def __init__(self, cache_dir: str = "/tmp"): + self.cache_dir = cache_dir + + def _load_cached_model(self, model_size: str = "turbo"): + """Load Whisper model from cache directory if available""" + try: + # Try to load from preloaded cache first + model_cache_dir = "/model" + if os.path.exists(model_cache_dir): + print(f"📦 Loading {model_size} model from cache: {model_cache_dir}") + # Set download root to cache directory + model = whisper.load_model(model_size, download_root=model_cache_dir) + print(f"✅ Successfully loaded {model_size} model from cache") + return model + else: + print(f"⚠️ Cache directory not found, downloading {model_size} model...") + return whisper.load_model(model_size) + except Exception as e: + print(f"⚠️ Failed to load cached model, downloading: {e}") + return whisper.load_model(model_size) + + def _load_speaker_diarization_pipeline(self): + """Load speaker diarization pipeline from cache if available""" + try: + speaker_cache_dir = "/model/speaker-diarization" + config_file = os.path.join(speaker_cache_dir, "download_complete.json") + + # Set proper cache directory for pyannote + os.environ["PYANNOTE_CACHE"] = "/model/speaker-diarization" + + # Check if cached speaker diarization models exist + if os.path.exists(config_file): + print(f"📦 Loading speaker diarization from cache: {speaker_cache_dir}") + # Load from cache with proper cache_dir + from pyannote.audio import Pipeline + + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=os.environ.get("HF_TOKEN"), + cache_dir="/model/speaker-diarization" + ) + print("✅ Successfully loaded speaker diarization pipeline from cache") + return pipeline + else: + print("⚠️ Speaker diarization cache not found, downloading...") + # Download fresh if cache not available + from pyannote.audio import Pipeline + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=os.environ.get("HF_TOKEN"), + cache_dir="/model/speaker-diarization" + ) + + # Create marker file to indicate successful download + import json + config = { + "model_name": "pyannote/speaker-diarization-3.1", + "cached_at": speaker_cache_dir, + "cache_complete": True, + "runtime_download": True + } + with open(config_file, "w") as f: + json.dump(config, f) + + return pipeline + except Exception as e: + print(f"⚠️ Failed to load speaker diarization pipeline: {e}") + return None + + def transcribe_audio( + self, + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False + ) -> Dict[str, Any]: + """ + Transcribe audio file using Whisper + + Args: + audio_file_path: Path to audio file + model_size: Whisper model size + language: Language code (optional) + output_format: Output format + enable_speaker_diarization: Enable speaker identification + + Returns: + Transcription result dictionary + """ + try: + print(f"🎤 Starting transcription for: {audio_file_path}") + print(f"🚀 Using model: {model_size}") + + # Check if file exists + if not os.path.exists(audio_file_path): + raise FileNotFoundError(f"Audio file not found: {audio_file_path}") + + # Load Whisper model from cache + model = self._load_cached_model(model_size) + + # Load 
speaker diarization pipeline if enabled + speaker_pipeline = None + if enable_speaker_diarization: + speaker_pipeline = self._load_speaker_diarization_pipeline() + if speaker_pipeline is None: + print("⚠️ Speaker diarization disabled due to loading failure") + enable_speaker_diarization = False + + # Transcribe audio + transcribe_options = { + "language": language if language and language != "auto" else None, + "task": "transcribe", + "verbose": True + } + + print(f"🔄 Transcribing with options: {transcribe_options}") + result = model.transcribe(audio_file_path, **transcribe_options) + + # Extract information + text = result.get("text", "").strip() + segments = result.get("segments", []) + language_detected = result.get("language", "unknown") + + # Apply speaker diarization if enabled + speaker_segments = [] + global_speaker_count = 0 + speaker_summary = {} + + if enable_speaker_diarization and speaker_pipeline: + try: + print("👥 Applying speaker diarization...") + diarization_result = speaker_pipeline(audio_file_path) + + # Process diarization results + speakers = set() + for turn, _, speaker in diarization_result.itertracks(yield_label=True): + speakers.add(speaker) + speaker_segments.append({ + "start": turn.start, + "end": turn.end, + "speaker": speaker + }) + + global_speaker_count = len(speakers) + speaker_summary = {f"SPEAKER_{i:02d}": speaker for i, speaker in enumerate(sorted(speakers))} + + # Merge speaker information with transcription segments + segments = self._merge_speaker_segments(segments, speaker_segments) + + print(f"✅ Speaker diarization completed: {global_speaker_count} speakers detected") + + except Exception as e: + print(f"⚠️ Speaker diarization failed: {e}") + enable_speaker_diarization = False + + # Generate output files + output_files = self._generate_output_files( + audio_file_path, text, segments, enable_speaker_diarization + ) + + # Get audio duration + audio_duration = 0.0 + if segments: + audio_duration = max(seg.get("end", 0) for seg in segments) + + print(f"✅ Transcription completed successfully") + print(f" Text length: {len(text)} characters") + print(f" Segments: {len(segments)}") + print(f" Duration: {audio_duration:.2f}s") + print(f" Language: {language_detected}") + + return { + "txt_file_path": output_files.get("txt_file"), + "srt_file_path": output_files.get("srt_file"), + "audio_file": audio_file_path, + "model_used": model_size, + "segment_count": len(segments), + "audio_duration": audio_duration, + "processing_status": "success", + "saved_files": [f for f in output_files.values() if f], + "speaker_diarization_enabled": enable_speaker_diarization, + "global_speaker_count": global_speaker_count, + "speaker_summary": speaker_summary, + "language_detected": language_detected, + "text": text, + "segments": [ + { + "start": seg.get("start", 0), + "end": seg.get("end", 0), + "text": seg.get("text", "").strip(), + "speaker": seg.get("speaker", None) + } + for seg in segments + ] + } + + except Exception as e: + print(f"❌ Transcription failed: {e}") + return self._create_error_result(audio_file_path, model_size, str(e)) + + def _merge_speaker_segments(self, transcription_segments: List[Dict], speaker_segments: List[Dict]) -> List[Dict]: + """ + Merge speaker information with transcription segments, splitting transcription segments + when multiple speakers are detected within a single segment + """ + merged_segments = [] + + for trans_seg in transcription_segments: + trans_start = trans_seg.get("start", 0) + trans_end = trans_seg.get("end", 0) + trans_text 
= trans_seg.get("text", "").strip() + + # Find all overlapping speaker segments + overlapping_speakers = [] + for speaker_seg in speaker_segments: + speaker_start = speaker_seg["start"] + speaker_end = speaker_seg["end"] + + # Check if there's any overlap + overlap_start = max(trans_start, speaker_start) + overlap_end = min(trans_end, speaker_end) + overlap_duration = max(0, overlap_end - overlap_start) + + if overlap_duration > 0: + overlapping_speakers.append({ + "speaker": speaker_seg["speaker"], + "start": speaker_start, + "end": speaker_end, + "overlap_start": overlap_start, + "overlap_end": overlap_end, + "overlap_duration": overlap_duration + }) + + if not overlapping_speakers: + # No speaker detected, keep original segment + merged_seg = trans_seg.copy() + merged_seg["speaker"] = None + merged_segments.append(merged_seg) + continue + + # Sort overlapping speakers by start time + overlapping_speakers.sort(key=lambda x: x["overlap_start"]) + + if len(overlapping_speakers) == 1: + # Single speaker for this transcription segment + merged_seg = trans_seg.copy() + merged_seg["speaker"] = overlapping_speakers[0]["speaker"] + merged_segments.append(merged_seg) + else: + # Multiple speakers detected - split the transcription segment + print(f"🔄 Splitting segment ({trans_start:.2f}s-{trans_end:.2f}s) with {len(overlapping_speakers)} speakers") + split_segments = self._split_transcription_segment( + trans_seg, overlapping_speakers, trans_text + ) + merged_segments.extend(split_segments) + + return merged_segments + + def _split_transcription_segment(self, trans_seg: Dict, overlapping_speakers: List[Dict], trans_text: str) -> List[Dict]: + """ + Split a transcription segment into multiple segments based on speaker changes + """ + split_segments = [] + trans_start = trans_seg.get("start", 0) + trans_end = trans_seg.get("end", 0) + trans_duration = trans_end - trans_start + + # Calculate the proportion of text for each speaker based on overlap duration + total_overlap_duration = sum(sp["overlap_duration"] for sp in overlapping_speakers) + + if total_overlap_duration == 0: + # Fallback: equal distribution + text_per_speaker = len(trans_text) // len(overlapping_speakers) + + current_text_pos = 0 + for i, speaker_info in enumerate(overlapping_speakers): + # Calculate text portion for this speaker + if total_overlap_duration > 0: + text_proportion = speaker_info["overlap_duration"] / total_overlap_duration + else: + text_proportion = 1.0 / len(overlapping_speakers) + + # Calculate text length for this speaker + if i == len(overlapping_speakers) - 1: + # Last speaker gets remaining text + speaker_text_length = len(trans_text) - current_text_pos + else: + speaker_text_length = int(len(trans_text) * text_proportion) + + # Extract text for this speaker + speaker_text_end = min(current_text_pos + speaker_text_length, len(trans_text)) + speaker_text = trans_text[current_text_pos:speaker_text_end].strip() + + # Adjust word boundaries to avoid cutting words in half + if speaker_text_end < len(trans_text) and i < len(overlapping_speakers) - 1: + # Find the last complete word + last_space = speaker_text.rfind(' ') + if last_space > 0: + speaker_text = speaker_text[:last_space] + speaker_text_end = current_text_pos + last_space + 1 # +1 to skip the space + else: + # If no space found, keep original text but update position + speaker_text_end = current_text_pos + speaker_text_length + + # Use actual speaker diarization timing directly + segment_start = speaker_info["overlap_start"] + segment_end = 
speaker_info["overlap_end"] + + # Always create segment if we have valid timing, even with empty text + if segment_start < segment_end: + split_segment = { + "start": segment_start, + "end": segment_end, + "text": speaker_text, + "speaker": speaker_info["speaker"] + } + split_segments.append(split_segment) + print(f" → {speaker_info['speaker']}: {segment_start:.2f}s-{segment_end:.2f}s: \"{speaker_text[:50]}{'...' if len(speaker_text) > 50 else ''}\"") + + current_text_pos = speaker_text_end + + return split_segments + + def transcribe_audio_parallel( + self, + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False, + chunk_duration: int = 300 + ) -> Dict[str, Any]: + """ + Transcribe audio with parallel processing for long files + + Args: + audio_file_path: Path to audio file + model_size: Whisper model size + language: Language code (optional) + output_format: Output format + enable_speaker_diarization: Enable speaker identification + chunk_duration: Duration of chunks in seconds + + Returns: + Transcription result dictionary + """ + try: + print(f"🎤 Starting parallel transcription for: {audio_file_path}") + print(f"🚀 Using model: {model_size}") + print(f"⚡ Chunk duration: {chunk_duration}s") + + # Check if file exists + if not os.path.exists(audio_file_path): + raise FileNotFoundError(f"Audio file not found: {audio_file_path}") + + # Get audio duration + total_duration = self._get_audio_duration(audio_file_path) + print(f"📊 Total audio duration: {total_duration:.2f}s") + + # If audio is shorter than chunk duration, use single processing + if total_duration <= chunk_duration: + print("📝 Audio is short, using single processing") + return self.transcribe_audio( + audio_file_path, model_size, language, output_format, enable_speaker_diarization + ) + + # Split audio into chunks + chunks = self._split_audio_into_chunks(audio_file_path, chunk_duration, total_duration) + print(f"🔀 Created {len(chunks)} chunks for parallel processing") + + # Load Whisper model once from cache + model = self._load_cached_model(model_size) + + # Process chunks in parallel + chunk_results = self._process_chunks_parallel(chunks, model, language) + + # Combine results + combined_result = self._combine_chunk_results( + chunk_results, audio_file_path, model_size, + enable_speaker_diarization, total_duration + ) + + # Cleanup chunk files + self._cleanup_chunks(chunks) + + return combined_result + + except Exception as e: + print(f"❌ Parallel transcription failed: {e}") + result = self._create_error_result(audio_file_path, model_size, str(e)) + result["parallel_processing"] = True + return result + + def normalize_audio_file(self, input_file: str, output_file: str = None) -> str: + """ + Normalize audio file for better Whisper compatibility + + Args: + input_file: Input audio file path + output_file: Output file path (optional) + + Returns: + Path to normalized audio file + """ + if output_file is None: + temp_dir = tempfile.mkdtemp() + output_file = os.path.join(temp_dir, "normalized_audio.wav") + + # Convert to standardized format: 16kHz, mono, PCM + cmd = [ + "ffmpeg", "-i", input_file, + "-ar", "16000", # 16kHz sample rate (Whisper's native) + "-ac", "1", # Mono channel + "-c:a", "pcm_s16le", # PCM 16-bit encoding + "-y", # Overwrite output file + output_file + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print(f"⚠️ FFmpeg normalization failed: {result.stderr}") + return 
input_file # Return original file if normalization fails + else: + print("✅ Audio normalized for Whisper") + return output_file + + def _get_audio_duration(self, audio_file_path: str) -> float: + """Get audio duration using ffprobe""" + cmd = ["ffprobe", "-v", "quiet", "-show_entries", "format=duration", "-of", "csv=p=0", audio_file_path] + duration_output = subprocess.run(cmd, capture_output=True, text=True) + return float(duration_output.stdout.strip()) + + def _split_audio_into_chunks(self, audio_file_path: str, chunk_duration: int, total_duration: float) -> List[Dict]: + """Split audio file into chunks for parallel processing""" + chunks = [] + temp_dir = tempfile.mkdtemp() + + for i in range(0, int(total_duration), chunk_duration): + chunk_start = i + chunk_end = min(i + chunk_duration, total_duration) + chunk_file = os.path.join(temp_dir, f"chunk_{i//chunk_duration:03d}.wav") + + # Extract chunk using ffmpeg + cmd = [ + "ffmpeg", "-i", audio_file_path, + "-ss", str(chunk_start), + "-t", str(chunk_end - chunk_start), + "-c:a", "pcm_s16le", # Use PCM encoding for better quality + "-ar", "16000", # 16kHz sample rate for Whisper + chunk_file + ] + + subprocess.run(cmd, capture_output=True) + + if os.path.exists(chunk_file): + chunks.append({ + "file": chunk_file, + "start_time": chunk_start, + "end_time": chunk_end, + "index": len(chunks), + "temp_dir": temp_dir + }) + print(f"📦 Created chunk {len(chunks)}: {chunk_start:.1f}s-{chunk_end:.1f}s") + + return chunks + + def _process_chunks_parallel(self, chunks: List[Dict], model, language: str) -> List[Dict]: + """Process audio chunks in parallel""" + def process_chunk(chunk_info): + try: + print(f"🔄 Processing chunk {chunk_info['index']}: {chunk_info['start_time']:.1f}s-{chunk_info['end_time']:.1f}s") + + transcribe_options = { + "language": language if language and language != "auto" else None, + "task": "transcribe", + "verbose": False # Reduce verbosity for parallel processing + } + + result = model.transcribe(chunk_info["file"], **transcribe_options) + + # Adjust segment timing to global timeline + segments = [] + for seg in result.get("segments", []): + adjusted_seg = { + "start": seg["start"] + chunk_info["start_time"], + "end": seg["end"] + chunk_info["start_time"], + "text": seg["text"].strip(), + "speaker": None + } + segments.append(adjusted_seg) + + print(f"✅ Chunk {chunk_info['index']} completed: {len(segments)} segments") + + return { + "text": result.get("text", "").strip(), + "segments": segments, + "language": result.get("language", "unknown"), + "chunk_index": chunk_info["index"] + } + + except Exception as e: + print(f"❌ Chunk {chunk_info['index']} failed: {e}") + return { + "text": "", + "segments": [], + "language": "unknown", + "chunk_index": chunk_info["index"], + "error": str(e) + } + + # Process chunks in parallel using ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=min(len(chunks), 8)) as executor: + chunk_results = list(executor.map(process_chunk, chunks)) + + # Sort results by chunk index + chunk_results.sort(key=lambda x: x["chunk_index"]) + return chunk_results + + def _combine_chunk_results( + self, + chunk_results: List[Dict], + audio_file_path: str, + model_size: str, + enable_speaker_diarization: bool, + total_duration: float + ) -> Dict[str, Any]: + """Combine results from multiple chunks""" + # Combine results + full_text = " ".join([chunk["text"] for chunk in chunk_results if chunk["text"]]) + all_segments = [] + for chunk in chunk_results: + all_segments.extend(chunk["segments"]) + + # Sort 
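# Illustrative sketch (not part of this changeset): how per-chunk Whisper
# segments are mapped back onto the global timeline before sorting, the same
# idea applied inside _process_chunks_parallel above. The input shape
# (a chunk start offset plus that chunk's raw segments) is assumed here.
def offset_chunk_segments(chunk_start: float, raw_segments: list) -> list:
    adjusted = []
    for seg in raw_segments:
        adjusted.append({
            "start": seg["start"] + chunk_start,  # chunk-local -> global time
            "end": seg["end"] + chunk_start,
            "text": seg["text"].strip(),
            "speaker": None,
        })
    return adjusted

# A segment at 12.0s inside the chunk that begins at 300s ends up at 312.0s in
# the combined transcript; all chunks' segments are then sorted by "start".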
segments by start time + all_segments.sort(key=lambda x: x["start"]) + + # Get detected language (use most common one) + languages = [chunk["language"] for chunk in chunk_results if chunk["language"] != "unknown"] + language_detected = max(set(languages), key=languages.count) if languages else "unknown" + + # Generate output files + output_files = self._generate_output_files( + audio_file_path, full_text, all_segments, enable_speaker_diarization + ) + + print(f"✅ Parallel transcription completed successfully") + print(f" Text length: {len(full_text)} characters") + print(f" Total segments: {len(all_segments)}") + print(f" Duration: {total_duration:.2f}s") + print(f" Language: {language_detected}") + print(f" Chunks processed: {len(chunk_results)}") + + return { + "txt_file_path": output_files.get("txt_file"), + "srt_file_path": output_files.get("srt_file"), + "audio_file": audio_file_path, + "model_used": model_size, + "segment_count": len(all_segments), + "audio_duration": total_duration, + "processing_status": "success", + "saved_files": [f for f in output_files.values() if f], + "speaker_diarization_enabled": enable_speaker_diarization, + "global_speaker_count": 0, + "speaker_summary": {}, + "language_detected": language_detected, + "text": full_text, + "segments": all_segments, + "chunks_processed": len(chunk_results), + "parallel_processing": True + } + + def _cleanup_chunks(self, chunks: List[Dict]): + """Clean up temporary chunk files""" + temp_dirs = set() + for chunk in chunks: + try: + if os.path.exists(chunk["file"]): + os.remove(chunk["file"]) + temp_dirs.add(chunk["temp_dir"]) + except Exception as e: + print(f"⚠️ Failed to cleanup chunk file: {e}") + + # Remove temp directories + for temp_dir in temp_dirs: + try: + os.rmdir(temp_dir) + except Exception as e: + print(f"⚠️ Failed to cleanup temp directory: {e}") + + def _generate_output_files( + self, + audio_file_path: str, + text: str, + segments: List[Dict], + enable_speaker_diarization: bool + ) -> Dict[str, str]: + """Generate output files (TXT and SRT)""" + base_path = Path(audio_file_path).with_suffix("") + output_files = {} + + # Generate TXT file + if text: + txt_file = f"{base_path}.txt" + with open(txt_file, 'w', encoding='utf-8') as f: + f.write(text) + output_files["txt_file"] = txt_file + + # Generate SRT file + if segments: + srt_file = f"{base_path}.srt" + srt_content = self._generate_srt_content(segments, enable_speaker_diarization) + with open(srt_file, 'w', encoding='utf-8') as f: + f.write(srt_content) + output_files["srt_file"] = srt_file + + return output_files + + def _generate_srt_content(self, segments: List[Dict], include_speakers: bool = False) -> str: + """Generate SRT format content from segments""" + srt_lines = [] + + for i, segment in enumerate(segments, 1): + start_time = self._format_timestamp(segment.get('start', 0)) + end_time = self._format_timestamp(segment.get('end', 0)) + text = segment.get('text', '').strip() + + if include_speakers and segment.get('speaker'): + text = f"[{segment['speaker']}] {text}" + + srt_lines.append(f"{i}") + srt_lines.append(f"{start_time} --> {end_time}") + srt_lines.append(text) + srt_lines.append("") # Empty line between segments + + return "\n".join(srt_lines) + + def _format_timestamp(self, seconds: float) -> str: + """Format timestamp for SRT format""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + milliseconds = int((seconds % 1) * 1000) + + return 
f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}" + + def _create_error_result(self, audio_file_path: str, model_size: str, error_message: str) -> Dict[str, Any]: + """Create error result dictionary""" + return { + "txt_file_path": None, + "srt_file_path": None, + "audio_file": audio_file_path, + "model_used": model_size, + "segment_count": 0, + "audio_duration": 0, + "processing_status": "failed", + "saved_files": [], + "speaker_diarization_enabled": False, + "global_speaker_count": 0, + "speaker_summary": {}, + "error_message": error_message + } \ No newline at end of file diff --git a/src/start_local.py b/src/start_local.py new file mode 100644 index 0000000000000000000000000000000000000000..c76ff7d68edc415591d38d333b63d2caf739d6d3 --- /dev/null +++ b/src/start_local.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +""" +Local mode startup script +Sets environment variables and starts the application in local mode +""" + +import os +import sys + +def main(): + """Start application in local mode""" + + print("🏠 Starting Gradio MCP Server in LOCAL mode") + print("💡 GPU functions will be routed to Modal endpoints") + + # Set deployment mode to local + os.environ["DEPLOYMENT_MODE"] = "local" + + # Add parent directory to path for src imports + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + sys.path.insert(0, parent_dir) + + # Import and run the app + from src.app import run_local + + try: + run_local() + except KeyboardInterrupt: + print("\n🛑 Server stopped by user") + except Exception as e: + print(f"❌ Error starting server: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/start_modal.py b/src/start_modal.py new file mode 100644 index 0000000000000000000000000000000000000000..8359225a022c4e2b3b70a4fa7db64e1945453c86 --- /dev/null +++ b/src/start_modal.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" +Modal mode deployment script +Sets environment variables and deploys the application to Modal +""" + +import os +import sys +import subprocess + +def main(): + """Deploy application to Modal""" + + print("☁️ Deploying Gradio MCP Server to MODAL") + print("🚀 GPU functions will run locally on Modal") + + # Set deployment mode to modal + os.environ["DEPLOYMENT_MODE"] = "modal" + + try: + # Deploy to Modal using modal deploy command + print("🚀 Deploying to Modal...") + result = subprocess.run([ + "modal", "deploy", "src.app::gradio_mcp_app" + ], check=True, capture_output=True, text=True) + + print("✅ Successfully deployed to Modal!") + print("Output:", result.stdout) + + except subprocess.CalledProcessError as e: + print(f"❌ Modal deployment failed: {e}") + print("Error output:", e.stderr) + sys.exit(1) + except FileNotFoundError: + print("❌ Modal CLI not found. 
Please install it with: pip install modal") + sys.exit(1) + except Exception as e: + print(f"❌ Unexpected error: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/test_deployment.py b/src/test_deployment.py new file mode 100644 index 0000000000000000000000000000000000000000..0337b0abd8a0b8fd52329a8d93be43d78e5bd385 --- /dev/null +++ b/src/test_deployment.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Test script to verify deployment configuration +""" + +import os +import sys + +def test_local_mode(): + """Test local mode configuration""" + print("🧪 Testing LOCAL mode configuration...") + + # Set local mode + os.environ["DEPLOYMENT_MODE"] = "local" + + try: + from config import is_local_mode, is_modal_mode, get_cache_dir + + assert is_local_mode() == True, "Should be in local mode" + assert is_modal_mode() == False, "Should not be in modal mode" + + cache_dir = get_cache_dir() + assert "gradio_mcp_cache" in cache_dir, f"Cache dir should be local: {cache_dir}" + + print("✅ Local mode configuration OK") + return True + + except Exception as e: + print(f"❌ Local mode test failed: {e}") + return False + +def test_modal_mode(): + """Test modal mode configuration""" + print("🧪 Testing MODAL mode configuration...") + + # Set modal mode + os.environ["DEPLOYMENT_MODE"] = "modal" + + try: + # Clear config module cache to reload with new env var + if 'config' in sys.modules: + del sys.modules['config'] + + from config import is_local_mode, is_modal_mode, get_cache_dir + + assert is_local_mode() == False, "Should not be in local mode" + assert is_modal_mode() == True, "Should be in modal mode" + + cache_dir = get_cache_dir() + assert cache_dir == "/root/cache", f"Cache dir should be modal: {cache_dir}" + + print("✅ Modal mode configuration OK") + return True + + except Exception as e: + print(f"❌ Modal mode test failed: {e}") + return False + +def test_gpu_adapters(): + """Test GPU adapters""" + print("🧪 Testing GPU adapters...") + + try: + from gpu_adapters import transcribe_audio_adaptive_sync + + # This should not crash, even if endpoints are not configured + result = transcribe_audio_adaptive_sync( + "test_file.mp3", + "turbo", + None, + "srt", + False + ) + + # Should return error result but not crash + assert "processing_status" in result or "error_message" in result, "Should return valid result structure" + + print("✅ GPU adapters OK") + return True + + except Exception as e: + print(f"❌ GPU adapters test failed: {e}") + return False + +def test_imports(): + """Test all imports work correctly""" + print("🧪 Testing imports...") + + try: + # Test config imports + from config import DeploymentMode, get_deployment_mode + + # Test MCP tools imports + from mcp_tools import get_mcp_server + + # Test app imports (should work in both modes) + from app import create_app, main, get_app + + print("✅ All imports OK") + return True + + except Exception as e: + print(f"❌ Import test failed: {e}") + return False + +def test_hf_spaces_mode(): + """Test Hugging Face Spaces mode""" + print("🧪 Testing HF Spaces mode...") + + try: + # Clear deployment mode to simulate HF Spaces + old_mode = os.environ.get("DEPLOYMENT_MODE") + if "DEPLOYMENT_MODE" in os.environ: + del os.environ["DEPLOYMENT_MODE"] + + # Clear config module cache + if 'config' in sys.modules: + del sys.modules['config'] + if 'app' in sys.modules: + del sys.modules['app'] + + from app import get_app + + app = get_app() + assert app is not None, "Should create app for HF Spaces" + + # Restore 
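# Illustrative sketch (not part of this changeset): the reload trick used by
# test_modal_mode and test_hf_spaces_mode above. A module that reads
# environment variables at import time has to be evicted from sys.modules
# before re-importing, otherwise the cached copy keeps the old setting. The
# helper below is hypothetical; the module and variable names mirror the tests.
import importlib
import os
import sys

def import_config_with_mode(mode):
    if mode is None:
        os.environ.pop("DEPLOYMENT_MODE", None)
    else:
        os.environ["DEPLOYMENT_MODE"] = mode
    sys.modules.pop("config", None)  # drop any cached copy
    return importlib.import_module("config")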
environment + if old_mode: + os.environ["DEPLOYMENT_MODE"] = old_mode + + print("✅ HF Spaces mode OK") + return True + + except Exception as e: + print(f"❌ HF Spaces test failed: {e}") + return False + +def main(): + """Run all tests""" + print("🚀 Running deployment configuration tests...\n") + + tests = [ + test_imports, + test_local_mode, + test_modal_mode, + test_hf_spaces_mode, + test_gpu_adapters, + ] + + passed = 0 + total = len(tests) + + for test in tests: + try: + if test(): + passed += 1 + except Exception as e: + print(f"❌ Test {test.__name__} crashed: {e}") + print() + + print(f"📊 Test Results: {passed}/{total} passed") + + if passed == total: + print("🎉 All tests passed! Deployment configuration is ready.") + return 0 + else: + print("⚠️ Some tests failed. Please check the configuration.") + return 1 + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/src/tools/__init__.py b/src/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a9ebb6e7f348d597adb89544090893bb89a3476 --- /dev/null +++ b/src/tools/__init__.py @@ -0,0 +1,5 @@ +""" +Tools Module - MCP tools and utilities +""" + +__all__ = [] \ No newline at end of file diff --git a/src/tools/__pycache__/__init__.cpython-310.pyc b/src/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..026e3d2e841bdca9c855db97071f007f798253cd Binary files /dev/null and b/src/tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/tools/__pycache__/download_tools.cpython-310.pyc b/src/tools/__pycache__/download_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..022a091de30cbc745cb0321b4bf92fd164325c29 Binary files /dev/null and b/src/tools/__pycache__/download_tools.cpython-310.pyc differ diff --git a/src/tools/__pycache__/mcp_tools.cpython-310.pyc b/src/tools/__pycache__/mcp_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5e9aeea043f4102dd9ad8fb025d8a3cb5b49d1b Binary files /dev/null and b/src/tools/__pycache__/mcp_tools.cpython-310.pyc differ diff --git a/src/tools/__pycache__/storage_tools.cpython-310.pyc b/src/tools/__pycache__/storage_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc8bbcab1da5a4b7c9d2d39acdd3e63d30ac3bf0 Binary files /dev/null and b/src/tools/__pycache__/storage_tools.cpython-310.pyc differ diff --git a/src/tools/__pycache__/transcription_tools.cpython-310.pyc b/src/tools/__pycache__/transcription_tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0512507fde234eb7708b7c7a7fdea252c3e888 Binary files /dev/null and b/src/tools/__pycache__/transcription_tools.cpython-310.pyc differ diff --git a/src/tools/download_tools.py b/src/tools/download_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddde36519a8ed864159be0c2504e0a649cfc488 --- /dev/null +++ b/src/tools/download_tools.py @@ -0,0 +1,220 @@ +""" +Download tools using local service architecture +Updated to use PodcastDownloadService for local execution only +""" + +import asyncio +import os +import json +import time +from pathlib import Path +from typing import Dict, Any + +from ..services import PodcastDownloadService, FileManagementService +from ..models.services import PodcastDownloadRequest + + +# Global service instances for reuse +_podcast_download_service = None +_file_management_service = None + + +def 
get_podcast_download_service() -> PodcastDownloadService: + """Get or create global PodcastDownloadService instance for local downloads""" + global _podcast_download_service + if _podcast_download_service is None: + # Use storage config for download folder + _podcast_download_service = PodcastDownloadService() # Will use storage config defaults + return _podcast_download_service + + +def get_file_management_service() -> FileManagementService: + """Get or create global FileManagementService instance""" + global _file_management_service + if _file_management_service is None: + _file_management_service = FileManagementService() + return _file_management_service + + +async def download_apple_podcast_tool(url: str) -> Dict[str, Any]: + """ + Download Apple Podcast audio files and save to specified directory (LOCAL EXECUTION). + + Args: + url: Complete URL of Apple Podcast page + + Returns: + Download result dictionary containing the following key fields: + - "status" (str): Download status, "success" or "failed" + - "original_url" (str): Input original podcast URL + - "audio_file_path" (str|None): Complete MP3 file path when successful, None when failed + - "error_message" (str): Only exists when failed, contains specific error description + """ + try: + print(f"🏠 Downloading Apple Podcast locally: {url}") + service = get_podcast_download_service() + + # Use local download service + result = await service.download_podcast( + url=url, + output_folder="downloads", + convert_to_mp3=True, + keep_original=False + ) + + if result.success: + return { + "status": "success", + "original_url": url, + "audio_file_path": result.file_path, + "podcast_info": { + "title": result.podcast_info.title if result.podcast_info else "Unknown", + "platform": "Apple Podcasts" + } + } + else: + return { + "status": "failed", + "original_url": url, + "audio_file_path": None, + "error_message": result.error_message + } + + except Exception as e: + return { + "status": "failed", + "original_url": url, + "audio_file_path": None, + "error_message": f"Local download tool error: {str(e)}" + } + + +async def download_xyz_podcast_tool(url: str) -> Dict[str, Any]: + """ + Download XiaoYuZhou podcast audio files and save to specified directory (LOCAL EXECUTION). 
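# Illustrative usage sketch (not part of this changeset): consuming the result
# dictionary documented above ("status", "audio_file_path", "error_message").
# The URL is a placeholder; the import path follows this changeset's layout.
import asyncio

from src.tools.download_tools import download_apple_podcast_tool

async def demo_download():
    result = await download_apple_podcast_tool("https://podcasts.apple.com/us/podcast/example")
    if result["status"] == "success":
        print(f"Saved audio to {result['audio_file_path']}")
    else:
        print(f"Download failed: {result.get('error_message', 'unknown error')}")

# asyncio.run(demo_download())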
+ + Args: + url: Complete URL of XiaoYuZhou podcast page, format: https://www.xiaoyuzhoufm.com/episode/xxxxx + + Returns: + Download result dictionary containing the following key fields: + - "status" (str): Download status, "success" or "failed" + - "original_url" (str): Input original podcast URL + - "audio_file_path" (str|None): Complete MP3 file path when successful, None when failed + - "error_message" (str): Only exists when failed, contains specific error description + """ + try: + print(f"🏠 Downloading XiaoYuZhou Podcast locally: {url}") + service = get_podcast_download_service() + + # Use local download service + result = await service.download_podcast( + url=url, + output_folder="downloads", + convert_to_mp3=True, + keep_original=False + ) + + if result.success: + return { + "status": "success", + "original_url": url, + "audio_file_path": result.file_path, + "podcast_info": { + "title": result.podcast_info.title if result.podcast_info else "Unknown", + "platform": "XiaoYuZhou" + } + } + else: + return { + "status": "failed", + "original_url": url, + "audio_file_path": None, + "error_message": result.error_message + } + + except Exception as e: + return { + "status": "failed", + "original_url": url, + "audio_file_path": None, + "error_message": f"Local download tool error: {str(e)}" + } + + +async def get_mp3_files_tool(directory: str) -> Dict[str, Any]: + """ + Scan specified directory to get detailed information list of all MP3 audio files (LOCAL EXECUTION). + + Args: + directory: Absolute or relative path of directory to scan + + Returns: + Dictionary containing MP3 file information + """ + try: + service = get_file_management_service() + return await service.scan_mp3_files(directory) + + except Exception as e: + return { + "total_files": 0, + "scanned_directory": directory, + "file_list": [], + "error_message": f"Local file scan tool error: {str(e)}" + } + + +async def get_file_info_tool(file_path: str) -> Dict[str, Any]: + """ + Get basic file information including size, modification time, etc (LOCAL EXECUTION). + + Args: + file_path: File path to query + + Returns: + File information dictionary + """ + try: + service = get_file_management_service() + return await service.get_file_info(file_path) + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "file_exists": False, + "error_message": f"Local file info tool error: {str(e)}" + } + + +async def read_text_file_segments_tool( + file_path: str, + chunk_size: int = 65536, + start_position: int = 0 +) -> Dict[str, Any]: + """ + Read text file content in segments, intelligently handling text boundaries (LOCAL EXECUTION). 
+ + Args: + file_path: Path to file to read (supports TXT, SRT and other text files) + chunk_size: Byte size to read each time, default 64KB + start_position: Starting position to read from (byte offset), default 0 + + Returns: + Read result dictionary + """ + try: + service = get_file_management_service() + return await service.read_text_file_segments( + file_path=file_path, + chunk_size=chunk_size, + start_position=start_position + ) + + except Exception as e: + return { + "status": "failed", + "file_path": file_path, + "error_message": f"Local file read tool error: {str(e)}" + } \ No newline at end of file diff --git a/src/tools/mcp_tools.py b/src/tools/mcp_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..834df388a3a1ec24520973bc78a6e6ec69791b1e --- /dev/null +++ b/src/tools/mcp_tools.py @@ -0,0 +1,124 @@ +""" +MCP Tools using the new service architecture +""" + +from typing import Dict, Any + +from .transcription_tools import transcribe_audio_file_tool +from .download_tools import ( + download_apple_podcast_tool, + download_xyz_podcast_tool, + get_mp3_files_tool, + get_file_info_tool, + read_text_file_segments_tool +) + + +# ==================== Transcription Tools ==================== + +async def transcribe_audio_file( + audio_file_path: str, + model_size: str = "turbo", + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False +) -> Dict[str, Any]: + """ + Transcribe audio files to text using Whisper model with new service architecture + + Args: + audio_file_path: Complete path to audio file + model_size: Whisper model size (tiny, base, small, medium, large, turbo) + language: Audio language code (e.g. "zh" for Chinese, "en" for English) + output_format: Output format (srt, txt, json) + enable_speaker_diarization: Whether to enable speaker identification + + Returns: + Transcription result dictionary with file paths and metadata + """ + return await transcribe_audio_file_tool( + audio_file_path=audio_file_path, + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization + ) + + +# ==================== Download Tools ==================== + +async def download_apple_podcast(url: str) -> Dict[str, Any]: + """ + Download Apple Podcast audio files using new service architecture + + Args: + url: Complete URL of Apple Podcast page + + Returns: + Download result dictionary with file path and metadata + """ + return await download_apple_podcast_tool(url) + + +async def download_xyz_podcast(url: str) -> Dict[str, Any]: + """ + Download XiaoYuZhou podcast audio files using new service architecture + + Args: + url: Complete URL of XiaoYuZhou podcast page + + Returns: + Download result dictionary with file path and metadata + """ + return await download_xyz_podcast_tool(url) + + +# ==================== File Management Tools ==================== + +async def get_mp3_files(directory: str) -> Dict[str, Any]: + """ + Scan directory to get detailed information list of all MP3 audio files + + Args: + directory: Absolute or relative path of directory to scan + + Returns: + MP3 file information dictionary with detailed file list + """ + return await get_mp3_files_tool(directory) + + +async def get_file_info(file_path: str) -> Dict[str, Any]: + """ + Get basic file information including size, modification time, etc. 
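# Illustrative sketch (not part of this changeset): reading a long transcript
# in segments with read_text_file_segments_tool. The result fields used here
# ("status", "content", "current_position", "end_of_file_reached") follow the
# way gradio_ui.py consumes the same service later in this changeset; the
# import path matches this changeset's layout and the file path is a placeholder.
import asyncio

from src.tools.download_tools import read_text_file_segments_tool

async def read_whole_text_file(path: str, chunk_size: int = 65536) -> str:
    parts, position = [], 0
    while True:
        result = await read_text_file_segments_tool(path, chunk_size, position)
        if result.get("status") != "success":
            break
        parts.append(result["content"])
        position = result["current_position"]
        if result.get("end_of_file_reached"):
            break
    return "".join(parts)

# asyncio.run(read_whole_text_file("transcripts/episode.srt"))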
+ + Args: + file_path: File path to query + + Returns: + File information dictionary with detailed metadata + """ + return await get_file_info_tool(file_path) + + +async def read_text_file_segments( + file_path: str, + chunk_size: int = 65536, + start_position: int = 0 +) -> Dict[str, Any]: + """ + Read text file content in segments with intelligent boundary detection + + Args: + file_path: Path to file to read (supports TXT, SRT and other text files) + chunk_size: Byte size to read each time (default 64KB) + start_position: Starting position to read from (byte offset, default 0) + + Returns: + Read result dictionary with content and metadata + """ + return await read_text_file_segments_tool( + file_path=file_path, + chunk_size=chunk_size, + start_position=start_position + ) \ No newline at end of file diff --git a/src/tools/storage_tools.py b/src/tools/storage_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c3bbe3407fb563259af15f734d8501cca5d925 --- /dev/null +++ b/src/tools/storage_tools.py @@ -0,0 +1,292 @@ +""" +Storage Management Tools +Provides tools for managing download and transcript storage +""" + +from typing import Dict, Any, List +from pathlib import Path + +from ..utils.storage_config import get_storage_config + + +async def get_storage_info_tool() -> Dict[str, Any]: + """ + Get comprehensive storage information including directory sizes and file counts + + Returns: + Storage information dictionary + """ + try: + storage_config = get_storage_config() + storage_info = storage_config.get_storage_info() + + print(f"📊 Storage Information:") + print(f" Downloads: {storage_info['downloads_dir']}") + print(f" Transcripts: {storage_info['transcripts_dir']}") + print(f" Cache: {storage_info['cache_dir']}") + + return {"status": "success", **storage_info} + + except Exception as e: + return {"status": "failed", "error_message": str(e)} + + +async def list_audio_files_tool() -> Dict[str, Any]: + """ + List all audio files in the downloads directory + + Returns: + List of audio files with metadata + """ + try: + storage_config = get_storage_config() + audio_files = storage_config.get_audio_files() + + file_list = [] + total_size = 0 + + for audio_file in audio_files: + file_size = audio_file.stat().st_size + total_size += file_size + + # Check for corresponding transcript files + transcript_files = storage_config.get_transcript_files(audio_file.name) + has_transcripts = { + 'txt': transcript_files['txt'].exists(), + 'srt': transcript_files['srt'].exists(), + 'json': transcript_files['json'].exists() + } + + file_info = { + "filename": audio_file.name, + "path": str(audio_file), + "size_mb": round(file_size / (1024 * 1024), 2), + "modified": audio_file.stat().st_mtime, + "has_transcripts": has_transcripts, + "transcript_count": sum(has_transcripts.values()) + } + file_list.append(file_info) + + print(f"📁 Found {len(audio_files)} audio files ({round(total_size / (1024 * 1024), 2)} MB total)") + + return { + "status": "success", + "audio_files_count": len(audio_files), + "total_size_mb": round(total_size / (1024 * 1024), 2), + "downloads_directory": str(storage_config.downloads_dir), + "audio_files": file_list + } + + except Exception as e: + return { + "status": "failed", + "error_message": f"Audio files listing tool error: {str(e)}" + } + + +async def list_transcript_files_tool() -> Dict[str, Any]: + """ + List all transcript files in the transcripts directory + + Returns: + List of transcript files organized by format + """ + try: + storage_config = 
get_storage_config() + transcript_files = storage_config.get_transcript_files() + + organized_files = {} + total_files = 0 + total_size = 0 + + for format_type, files in transcript_files.items(): + format_info = [] + format_size = 0 + + for transcript_file in files: + file_size = transcript_file.stat().st_size + format_size += file_size + total_size += file_size + + # Check if corresponding audio file exists + base_name = transcript_file.stem + audio_files = storage_config.get_audio_files() + has_audio = any(af.stem == base_name for af in audio_files) + + file_info = { + "filename": transcript_file.name, + "path": str(transcript_file), + "size_kb": round(file_size / 1024, 2), + "modified": transcript_file.stat().st_mtime, + "base_name": base_name, + "has_audio": has_audio + } + format_info.append(file_info) + + organized_files[format_type] = { + "count": len(files), + "size_kb": round(format_size / 1024, 2), + "files": format_info + } + total_files += len(files) + + print(f"📄 Found {total_files} transcript files ({round(total_size / 1024, 2)} KB total)") + + return { + "status": "success", + "total_files": total_files, + "total_size_kb": round(total_size / 1024, 2), + "transcripts_directory": str(storage_config.transcripts_dir), + "formats": organized_files + } + + except Exception as e: + return { + "status": "failed", + "error_message": f"Transcript files listing tool error: {str(e)}" + } + + +async def cleanup_cache_tool(pattern: str = "temp_*") -> Dict[str, Any]: + """ + Clean up temporary files in cache directory + + Args: + pattern: File pattern to match for cleanup (default: temp_*) + + Returns: + Cleanup result + """ + try: + storage_config = get_storage_config() + + # Get cache size before cleanup + cache_info_before = storage_config.get_storage_info() + cache_size_before = cache_info_before['cache_size_mb'] + + # Perform cleanup + storage_config.cleanup_temp_files(pattern) + + # Get cache size after cleanup + cache_info_after = storage_config.get_storage_info() + cache_size_after = cache_info_after['cache_size_mb'] + + cleaned_mb = cache_size_before - cache_size_after + + print(f"🗑️ Cache cleanup completed") + print(f" Pattern: {pattern}") + print(f" Cleaned: {cleaned_mb:.2f} MB") + print(f" Cache size: {cache_size_before:.2f} MB → {cache_size_after:.2f} MB") + + return { + "status": "success", + "cleanup_pattern": pattern, + "cache_directory": str(storage_config.cache_dir), + "size_before_mb": cache_size_before, + "size_after_mb": cache_size_after, + "cleaned_mb": cleaned_mb + } + + except Exception as e: + return { + "status": "failed", + "error_message": f"Cache cleanup tool error: {str(e)}" + } + + +async def check_transcript_status_tool(audio_filename: str = None) -> Dict[str, Any]: + """ + Check transcript status for audio files + + Args: + audio_filename: Specific audio file to check (optional) + + Returns: + Transcript status information + """ + try: + storage_config = get_storage_config() + + if audio_filename: + # Check specific file + audio_path = storage_config.get_download_path(audio_filename) + if not audio_path.exists(): + return { + "status": "failed", + "error_message": f"Audio file not found: {audio_filename}" + } + + transcript_files = storage_config.get_transcript_files(audio_filename) + status = { + "audio_file": audio_filename, + "audio_exists": True, + "transcripts": { + format_type: { + "exists": file_path.exists(), + "path": str(file_path), + "size_kb": round(file_path.stat().st_size / 1024, 2) if file_path.exists() else 0 + } + for format_type, file_path 
in transcript_files.items() + } + } + + return { + "status": "success", + "mode": "single_file", + **status + } + else: + # Check all audio files + audio_files = storage_config.get_audio_files() + statuses = [] + + summary = { + "total_audio_files": len(audio_files), + "files_with_transcripts": 0, + "files_without_transcripts": 0, + "transcript_formats": {"txt": 0, "srt": 0, "json": 0} + } + + for audio_file in audio_files: + transcript_files = storage_config.get_transcript_files(audio_file.name) + + has_any_transcript = any(tf.exists() for tf in transcript_files.values()) + if has_any_transcript: + summary["files_with_transcripts"] += 1 + else: + summary["files_without_transcripts"] += 1 + + file_status = { + "audio_file": audio_file.name, + "has_transcripts": has_any_transcript, + "transcript_formats": { + format_type: file_path.exists() + for format_type, file_path in transcript_files.items() + } + } + + # Count transcript formats + for format_type, exists in file_status["transcript_formats"].items(): + if exists: + summary["transcript_formats"][format_type] += 1 + + statuses.append(file_status) + + print(f"📊 Transcript Status Summary:") + print(f" Total audio files: {summary['total_audio_files']}") + print(f" With transcripts: {summary['files_with_transcripts']}") + print(f" Without transcripts: {summary['files_without_transcripts']}") + print(f" Format counts: TXT({summary['transcript_formats']['txt']}) SRT({summary['transcript_formats']['srt']}) JSON({summary['transcript_formats']['json']})") + + return { + "status": "success", + "mode": "all_files", + "summary": summary, + "file_statuses": statuses + } + + except Exception as e: + return { + "status": "failed", + "error_message": f"Transcript status tool error: {str(e)}" + } \ No newline at end of file diff --git a/src/tools/transcription_tools.py b/src/tools/transcription_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cf6471ba6e7dbda7e154e02ee76777f33bebdd --- /dev/null +++ b/src/tools/transcription_tools.py @@ -0,0 +1,216 @@ +""" +Transcription tools using the enhanced service architecture +Updated to use ModalTranscriptionService for better separation of concerns +""" + +import asyncio +from typing import Dict, Any + +from ..services import ModalTranscriptionService + + +# Global service instance for reuse +_modal_transcription_service = None + + +def _format_srt_time(seconds: float) -> str: + """Format seconds to SRT time format (HH:MM:SS,mmm)""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + millisecs = int((seconds % 1) * 1000) + return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}" + + +def get_modal_transcription_service() -> ModalTranscriptionService: + """Get or create global ModalTranscriptionService instance""" + global _modal_transcription_service + if _modal_transcription_service is None: + _modal_transcription_service = ModalTranscriptionService(use_direct_modal_calls=True) + return _modal_transcription_service + + +async def transcribe_audio_file_tool( + audio_file_path: str, + model_size: str = "turbo", # Default to turbo model + language: str = None, + output_format: str = "srt", + enable_speaker_diarization: bool = False, + use_parallel_processing: bool = True, # Enable parallel processing by default + chunk_duration: int = 60, # 60 seconds chunks for parallel processing + use_intelligent_segmentation: bool = True # Enable intelligent segmentation by default +) -> Dict[str, Any]: + """ + MCP tool function for audio 
transcription using Modal endpoints with intelligent processing + Enhanced to save transcription results to local files + + Args: + audio_file_path: Path to audio file + model_size: Whisper model size (tiny, base, small, medium, large, turbo) + language: Language code (e.g., 'en', 'zh', None for auto-detect) + output_format: Output format (srt, txt, json) + enable_speaker_diarization: Whether to enable speaker diarization + use_parallel_processing: Whether to use distributed processing for long audio + chunk_duration: Duration of each chunk in seconds for parallel processing + use_intelligent_segmentation: Whether to use intelligent silence-based segmentation + + Returns: + Transcription result dictionary with local file paths + """ + try: + import os + import pathlib + + service = get_modal_transcription_service() + modal_result = await service.transcribe_audio_file( + audio_file_path=audio_file_path, + model_size=model_size, + language=language, + output_format=output_format, + enable_speaker_diarization=enable_speaker_diarization, + use_parallel_processing=use_parallel_processing, + chunk_duration=chunk_duration, + use_intelligent_segmentation=use_intelligent_segmentation + ) + + # Check if transcription was successful + if modal_result.get("processing_status") != "success": + return modal_result + + # Debug: Print modal result structure + print(f"🔍 Modal result keys: {list(modal_result.keys())}") + print(f"🔍 Has text: {bool(modal_result.get('text'))}") + print(f"🔍 Has segments: {bool(modal_result.get('segments'))}") + if modal_result.get("segments"): + print(f"🔍 Segments count: {len(modal_result['segments'])}") + + # Save transcription results to local files using storage config + from ..utils.storage_config import get_storage_config + storage_config = get_storage_config() + + base_name = pathlib.Path(audio_file_path).stem + output_dir = storage_config.transcripts_dir + + saved_files = [] + txt_file_path = None + srt_file_path = None + json_file_path = None + + # Generate SRT content if segments are available + if modal_result.get("segments"): + segments = modal_result["segments"] + srt_content = "" + for i, segment in enumerate(segments, 1): + start_time = _format_srt_time(segment.get("start", 0)) + end_time = _format_srt_time(segment.get("end", 0)) + text = segment.get("text", "").strip() + + if text: + if enable_speaker_diarization and segment.get("speaker"): + text = f"[{segment['speaker']}] {text}" + + srt_content += f"{i}\n{start_time} --> {end_time}\n{text}\n\n" + + if srt_content: + srt_file_path = output_dir / f"{base_name}.srt" + with open(srt_file_path, 'w', encoding='utf-8') as f: + f.write(srt_content) + saved_files.append(str(srt_file_path)) + print(f"💾 Saved SRT file: {srt_file_path}") + + # Generate TXT content if text is available + if modal_result.get("text"): + txt_file_path = output_dir / f"{base_name}.txt" + with open(txt_file_path, 'w', encoding='utf-8') as f: + f.write(modal_result["text"]) + saved_files.append(str(txt_file_path)) + print(f"💾 Saved TXT file: {txt_file_path}") + + # Save JSON file with full results (always save for debugging) + import json + json_file_path = output_dir / f"{base_name}.json" + with open(json_file_path, 'w', encoding='utf-8') as f: + json.dump(modal_result, f, indent=2, ensure_ascii=False) + saved_files.append(str(json_file_path)) + print(f"💾 Saved JSON file: {json_file_path}") + + # Warn if no text/segments found + if not modal_result.get("segments") and not modal_result.get("text"): + print("⚠️ Warning: No text or segments found 
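# Illustrative sketch (not part of this changeset): one SRT entry as assembled
# above — running index, "start --> end" timestamps from _format_srt_time, an
# optional "[SPEAKER_xx]" prefix, and a blank line between entries. The helper
# name and the sample segment values are hypothetical.
def segment_to_srt_entry(index: int, segment: dict, with_speaker: bool = True) -> str:
    start = _format_srt_time(segment.get("start", 0))
    end = _format_srt_time(segment.get("end", 0))
    text = segment.get("text", "").strip()
    if with_speaker and segment.get("speaker"):
        text = f"[{segment['speaker']}] {text}"
    return f"{index}\n{start} --> {end}\n{text}\n\n"

# segment_to_srt_entry(1, {"start": 3725.456, "end": 3728.0,
#                          "text": "Hello there", "speaker": "SPEAKER_00"})
# produces:
# 1
# 01:02:05,456 --> 01:02:08,000
# [SPEAKER_00] Hello there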
in transcription result") + + # Update result with local file paths + result = modal_result.copy() + result["txt_file_path"] = str(txt_file_path) if txt_file_path else None + result["srt_file_path"] = str(srt_file_path) if srt_file_path else None + result["json_file_path"] = str(json_file_path) if json_file_path else None + result["saved_files"] = saved_files + result["local_files_saved"] = len(saved_files) + + print(f"✅ Transcription completed and saved {len(saved_files)} local files") + + return result + + except Exception as e: + return { + "processing_status": "failed", + "error_message": f"Tool error: {str(e)}" + } + + +async def check_modal_endpoints_health() -> Dict[str, Any]: + """ + Check the health status of Modal endpoints + + Returns: + Health status dictionary for all endpoints + """ + try: + service = get_modal_transcription_service() + return await service.check_endpoints_health() + + except Exception as e: + return { + "status": "failed", + "error_message": f"Health check tool error: {str(e)}" + } + + +async def get_system_status() -> Dict[str, Any]: + """ + Get comprehensive system status including health checks + + Returns: + System status dictionary + """ + try: + service = get_modal_transcription_service() + return await service.get_system_status() + + except Exception as e: + return { + "status": "failed", + "error_message": f"System status tool error: {str(e)}" + } + + +def get_modal_endpoint_url(endpoint_label: str) -> str: + """ + Get Modal endpoint URL for given label + + Args: + endpoint_label: Modal endpoint label + + Returns: + Full endpoint URL + """ + try: + service = get_modal_transcription_service() + return service.get_endpoint_url(endpoint_label) + + except Exception as e: + # Fallback to default URL pattern + return f"https://richardsucran--{endpoint_label}.modal.run" + + +# Note: Download functionality has been moved to download_tools.py +# These functions are now implemented there using PodcastDownloadService for local downloads \ No newline at end of file diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..075f4f9107bf6d6f61966a6f629b600d3852b2b8 --- /dev/null +++ b/src/ui/__init__.py @@ -0,0 +1,5 @@ +""" +UI Module - User interfaces and interactive components +""" + +__all__ = [] \ No newline at end of file diff --git a/src/ui/__pycache__/__init__.cpython-310.pyc b/src/ui/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfb313c4e967411261a63663de6359aacadf995b Binary files /dev/null and b/src/ui/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/ui/__pycache__/gradio_ui.cpython-310.pyc b/src/ui/__pycache__/gradio_ui.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d26e9b4a00d732d724811f036236e889aed69953 Binary files /dev/null and b/src/ui/__pycache__/gradio_ui.cpython-310.pyc differ diff --git a/src/ui/gradio_ui.py b/src/ui/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c944648104bc2274ccbfe9b0e54b74e2af6055 --- /dev/null +++ b/src/ui/gradio_ui.py @@ -0,0 +1,546 @@ +""" +Gradio interface module +Contains all UI components and interface logic +""" + +import gradio as gr +import asyncio +from ..tools import mcp_tools +from ..tools.download_tools import get_file_info_tool, get_mp3_files_tool, read_text_file_segments_tool +from ..tools.transcription_tools import transcribe_audio_file_tool +import os + +def create_gradio_interface(): + """Create 
Gradio interface + + Returns: + gr.Blocks: Configured Gradio interface + """ + + with gr.Blocks(title="MCP Tool Server") as demo: + gr.Markdown("# 🤖 Gradio + FastMCP Server") + gr.Markdown("This server provides both Gradio UI and FastMCP tools!") + + # ==================== Podcast Download Tab ==================== + with gr.Tab("Podcast Download"): + gr.Markdown("### 🎙️ Download Podcast Audio") + + url_input = gr.Textbox( + label="Podcast Link", + placeholder="Enter podcast page URL", + lines=1 + ) + + platform_choice = gr.Radio( + choices=["Apple Podcast", "XiaoYuZhou"], + label="Select Podcast Platform", + value="Apple Podcast" + ) + + # Transcription options + with gr.Row(): + auto_transcribe = gr.Checkbox( + label="Auto-transcribe after download", + value=True, + info="Start transcription immediately after download" + ) + enable_speaker_diarization = gr.Checkbox( + label="Enable speaker diarization", + value=False, + info="Identify different speakers (requires Hugging Face Token)" + ) + + download_btn = gr.Button("📥 Start Download", variant="primary") + result_output = gr.JSON(label="Download Results") + + async def download_podcast_and_transcribe(url, platform, auto_transcribe, enable_speaker): + """Call corresponding download tool based on selected platform""" + if platform == "Apple Podcast": + download_result = await mcp_tools.download_apple_podcast(url) + else: + download_result = await mcp_tools.download_xyz_podcast(url) + + # 2. Check if download was successful + if download_result["status"] != "success": + return { + "download_status": "failed", + "error_message": download_result.get("error_message", "Download failed"), + "transcription_status": "not_started" + } + + # 3. If not auto-transcribing, return only download results + if not auto_transcribe: + return { + "download_status": "success", + "audio_file": download_result["audio_file_path"], + "transcription_status": "skipped (user chose not to auto-transcribe)" + } + + # 4. Start transcription + try: + audio_path = download_result["audio_file_path"] + print(f"Transcribing audio file: {audio_path}") + transcribe_result = await mcp_tools.transcribe_audio_file( + audio_path, + model_size="turbo", + language=None, + output_format="srt", + enable_speaker_diarization=enable_speaker + ) + + # 5. Merge results + result = { + "download_status": "success", + "audio_file": audio_path, + "transcription_status": "success", + "txt_file_path": transcribe_result.get("txt_file_path"), + "srt_file_path": transcribe_result.get("srt_file_path"), + "transcription_details": { + "model_used": transcribe_result.get("model_used"), + "segment_count": transcribe_result.get("segment_count"), + "audio_duration": transcribe_result.get("audio_duration"), + "saved_files": transcribe_result.get("saved_files", []), + "speaker_diarization_enabled": transcribe_result.get("speaker_diarization_enabled", False) + } + } + + # 6. 
Add speaker diarization info if enabled + if enable_speaker and transcribe_result.get("speaker_diarization_enabled", False): + result["speaker_diarization"] = { + "global_speaker_count": transcribe_result.get("global_speaker_count", 0), + "speaker_summary": transcribe_result.get("speaker_summary", {}) + } + + return result + + except Exception as e: + return { + "download_status": "success", + "audio_file": download_result["audio_file_path"], + "transcription_status": "failed", + "error_message": str(e) + } + + # Bind callback function + download_btn.click( + download_podcast_and_transcribe, + inputs=[url_input, platform_choice, auto_transcribe, enable_speaker_diarization], + outputs=result_output + ) + + # ==================== Audio Transcription Tab ==================== + with gr.Tab("Audio Transcription"): + gr.Markdown("### 🎤 Audio Transcription and Speaker Diarization") + gr.Markdown("Upload audio files for high-quality transcription with speaker diarization support") + + with gr.Row(): + with gr.Column(scale=2): + # Audio file input + audio_file_input = gr.Textbox( + label="Audio File Path", + placeholder="Enter complete path to audio file (supports mp3, wav, m4a, etc.)", + lines=1 + ) + + # Transcription parameter settings + with gr.Row(): + model_size_choice = gr.Dropdown( + choices=["tiny", "base", "small", "medium", "large", "turbo"], + value="turbo", + label="Model Size", + info="Affects transcription accuracy and speed" + ) + language_choice = gr.Dropdown( + choices=["auto", "zh", "en", "ja", "ko", "fr", "de", "es"], + value="auto", + label="Language", + info="auto=auto-detect" + ) + + with gr.Row(): + output_format_choice = gr.Radio( + choices=["srt", "txt", "json"], + value="srt", + label="Output Format" + ) + enable_speaker_separation = gr.Checkbox( + label="Enable speaker diarization", + value=False, + info="Requires Hugging Face Token" + ) + + transcribe_btn = gr.Button("🎤 Start Transcription", variant="primary", size="lg") + + with gr.Column(scale=1): + # Audio file information + audio_info_output = gr.JSON(label="Audio File Information", visible=False) + + # Transcription progress and status + transcribe_status = gr.Textbox( + label="Transcription Status", + value="Waiting to start transcription...", + interactive=False, + lines=3 + ) + + # Transcription results display + transcribe_result_output = gr.JSON( + label="Transcription Results", + visible=True + ) + + # Speaker diarization results (if enabled) + speaker_info_output = gr.JSON( + label="Speaker Diarization Information", + visible=False + ) + + def perform_transcription(audio_path, model_size, language, output_format, enable_speaker): + """Execute audio transcription""" + if not audio_path.strip(): + return { + "error": "Please enter audio file path" + }, "Transcription failed: No audio file selected", gr.update(visible=False) + + # Check if file exists + import asyncio + file_info = asyncio.run(get_file_info_tool(audio_path)) + if file_info["status"] != "success": + return { + "error": f"File does not exist or cannot be accessed: {file_info.get('error_message', 'Unknown error')}" + }, "Transcription failed: File inaccessible", gr.update(visible=False) + + try: + # Process language parameter + lang = None if language == "auto" else language + + # Call transcription tool + result = asyncio.run(transcribe_audio_file_tool( + audio_file_path=audio_path, + model_size=model_size, + language=lang, + output_format=output_format, + enable_speaker_diarization=enable_speaker + )) + + # Prepare status information + if 
result.get("processing_status") == "success": + status_text = f"""✅ Transcription completed! +📁 Generated files: {len(result.get('saved_files', []))} files +🎵 Audio duration: {result.get('audio_duration', 0):.2f} seconds +📝 Transcription segments: {result.get('segment_count', 0)} segments +🎯 Model used: {result.get('model_used', 'N/A')} +🎭 Speaker diarization: {'Enabled' if result.get('speaker_diarization_enabled', False) else 'Disabled'}""" + + # Show speaker information + speaker_visible = result.get('speaker_diarization_enabled', False) and result.get('global_speaker_count', 0) > 0 + speaker_info = result.get('speaker_summary', {}) if speaker_visible else {} + + return result, status_text, gr.update(visible=speaker_visible, value=speaker_info) + else: + error_msg = result.get('error_message', 'Unknown error') + return result, f"❌ Transcription failed: {error_msg}", gr.update(visible=False) + + except Exception as e: + return { + "error": f"Exception occurred during transcription: {str(e)}" + }, f"❌ Transcription exception: {str(e)}", gr.update(visible=False) + + # Bind transcription button + transcribe_btn.click( + perform_transcription, + inputs=[ + audio_file_input, + model_size_choice, + language_choice, + output_format_choice, + enable_speaker_separation + ], + outputs=[ + transcribe_result_output, + transcribe_status, + speaker_info_output + ] + ) + + # ==================== MP3 File Management Tab ==================== + with gr.Tab("MP3 File Management"): + gr.Markdown("### 🎵 MP3 File Management") + + dir_input = gr.Dropdown( + label="Directory Path", + choices=[ + "/root/cache/apple_podcasts", + "/root/cache/xyz_podcasts" + ], + value="/root/cache/apple_podcasts" + ) + + file_list = gr.Textbox( + label="MP3 File List", + interactive=False, + lines=10, + max_lines=20, + show_copy_button=True, + autoscroll=True + ) + + def list_mp3_files(directory): + """List MP3 files in directory""" + files = asyncio.run(get_mp3_files_tool(directory)) + return "\n".join(files) if files else "No MP3 files found in directory" + + # Bind callback function + dir_input.change( + list_mp3_files, + inputs=[dir_input], + outputs=[file_list] + ) + + # ==================== Transcription Text Management Tab ==================== + with gr.Tab("Transcription Text Management"): + gr.Markdown("### 📝 Transcription Text File Management") + gr.Markdown("Manage and edit TXT and SRT files generated from audio transcription") + + with gr.Row(): + with gr.Column(scale=2): + # File path input + file_path_input = gr.Textbox( + label="File Path", + placeholder="Enter path to TXT or SRT file to read", + lines=1 + ) + + # File information display + file_info_output = gr.JSON(label="File Information", visible=False) + + with gr.Row(): + load_file_btn = gr.Button("📂 Load File", variant="secondary") + save_file_btn = gr.Button("💾 Save File", variant="primary") + refresh_btn = gr.Button("🔄 Refresh", variant="secondary") + + with gr.Column(scale=1): + # Read control + gr.Markdown("#### 📖 Segmented Reading Control") + current_position = gr.Number( + label="Current Position (bytes)", + value=0, + minimum=0 + ) + chunk_size = gr.Number( + label="Chunk Size (bytes)", + value=65536, # 64KB + minimum=1024, + maximum=1048576 # Max 1MB + ) + + with gr.Row(): + prev_chunk_btn = gr.Button("⬅️ Previous", size="sm") + next_chunk_btn = gr.Button("➡️ Next", size="sm") + + # Progress display + progress_display = gr.Textbox( + label="Reading Progress", + value="No file loaded", + interactive=False, + lines=3 + ) + + # Write control + 
gr.Markdown("#### ✏️ Write Control") + write_mode = gr.Radio( + choices=["w", "a", "r+"], + value="w", + label="Write Mode", + info="w=overwrite, a=append, r+=position" + ) + write_position = gr.Number( + label="Write Position (bytes)", + value=0, + minimum=0, + visible=False + ) + + # Text content editor + content_editor = gr.Textbox( + label="File Content", + placeholder="File content will be displayed here after loading...", + lines=20, + max_lines=30, + show_copy_button=True, + autoscroll=False + ) + + # Status information + status_output = gr.Textbox( + label="Operation Status", + interactive=False, + lines=2 + ) + + # Internal state variables + file_state = gr.State({ + "file_path": "", + "file_size": 0, + "current_pos": 0, + "chunk_size": 65536, + "content": "" + }) + + def load_file_info(file_path): + """Load file information""" + if not file_path.strip(): + return {}, "Please enter file path", "No file selected", gr.update(visible=False) + + info = asyncio.run(get_file_info_tool(file_path)) + if info["status"] == "success": + return ( + info, + f"File: {info['filename']} | Size: {info['file_size_mb']} MB", + "File information loaded successfully", + gr.update(visible=True) + ) + else: + return ( + {}, + f"Error: {info.get('error_message', 'Unknown error')}", + "Failed to load file information", + gr.update(visible=False) + ) + + def read_file_content(file_path, position, chunk_size): + """Read file content""" + if not file_path.strip(): + return "", 0, "No file selected", { + "file_path": "", + "file_size": 0, + "current_pos": 0, + "chunk_size": chunk_size, + "content": "" + } + + result = asyncio.run(read_text_file_segments_tool(file_path, int(chunk_size), int(position))) + + if result["status"] == "success": + new_state = { + "file_path": file_path, + "file_size": result["file_size"], + "current_pos": result["current_position"], + "chunk_size": chunk_size, + "content": result["content"] + } + + progress_text = ( + f"Progress: {result['progress_percentage']:.1f}% " + f"({result['current_position']}/{result['file_size']} bytes)\n" + f"Boundary type: {result.get('actual_boundary', 'Unknown')}\n" + f"{'End of file reached' if result['end_of_file_reached'] else 'More content available'}" + ) + + return ( + result["content"], + result["current_position"], + progress_text, + new_state + ) + else: + return ( + "", + position, + f"Read failed: {result.get('error_message', 'Unknown error')}", + { + "file_path": file_path, + "file_size": 0, + "current_pos": position, + "chunk_size": chunk_size, + "content": "" + } + ) + + def save_file_content(file_path, content, mode, position): + """Save file content""" + if not file_path.strip(): + return "Please select a file first" + + if not content.strip(): + return "No content to save" + + # Determine whether to use position parameter based on mode + write_pos = position if mode == "r+" else None + result = write_text_file_content(file_path, content, mode, write_pos) + + if result["status"] == "success": + operation_info = f"Operation: {result.get('operation_type', mode)}" + size_info = f"Size change: {result.get('size_change', 0):+d} bytes" + return f"Save successful!\n{operation_info}\nWrote {result['characters_written']} characters\n{size_info}" + else: + return f"Save failed: {result.get('error_message', 'Unknown error')}" + + def navigate_chunks(file_state, direction): + """Navigate to previous or next chunk""" + if not file_state["file_path"]: + return file_state["current_pos"], "Please load a file first" + + chunk_size = 
file_state["chunk_size"] + current_pos = file_state["current_pos"] + + if direction == "prev": + new_pos = max(0, current_pos - chunk_size * 2) # Go back two chunks + elif direction == "next": + new_pos = current_pos # Next chunk starts from current position + else: + return current_pos, "Invalid navigation direction" + + return new_pos, f"Navigated to position: {new_pos}" + + # Bind event handlers + load_file_btn.click( + load_file_info, + inputs=[file_path_input], + outputs=[file_info_output, progress_display, status_output, file_info_output] + ).then( + read_file_content, + inputs=[file_path_input, current_position, chunk_size], + outputs=[content_editor, current_position, progress_display, file_state] + ) + + refresh_btn.click( + read_file_content, + inputs=[file_path_input, current_position, chunk_size], + outputs=[content_editor, current_position, progress_display, file_state] + ) + + # Control position input visibility when write mode changes + write_mode.change( + lambda mode: gr.update(visible=(mode == "r+")), + inputs=[write_mode], + outputs=[write_position] + ) + + save_file_btn.click( + save_file_content, + inputs=[file_path_input, content_editor, write_mode, write_position], + outputs=[status_output] + ) + + prev_chunk_btn.click( + lambda state: navigate_chunks(state, "prev"), + inputs=[file_state], + outputs=[current_position, status_output] + ).then( + read_file_content, + inputs=[file_path_input, current_position, chunk_size], + outputs=[content_editor, current_position, progress_display, file_state] + ) + + next_chunk_btn.click( + lambda state: navigate_chunks(state, "next"), + inputs=[file_state], + outputs=[current_position, status_output] + ).then( + read_file_content, + inputs=[file_path_input, current_position, chunk_size], + outputs=[content_editor, current_position, progress_display, file_state] + ) + + return demo \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b02d1e439edf5b9ceb974197fa15fcd07172d23 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,32 @@ +""" +Utility modules for audio processing +""" + +from .config import AudioProcessingConfig +from .errors import ( + AudioProcessingError, + TranscriptionError, + SpeakerDiarizationError, + FileProcessingError, + SpeakerDetectionError, + AudioSplittingError, + ModelLoadError, + ConfigurationError, + DeploymentError +) +from .formatters import SRTFormatter, TextFormatter + +__all__ = [ + "AudioProcessingConfig", + "AudioProcessingError", + "TranscriptionError", + "SpeakerDiarizationError", + "FileProcessingError", + "SpeakerDetectionError", + "AudioSplittingError", + "ModelLoadError", + "ConfigurationError", + "DeploymentError", + "SRTFormatter", + "TextFormatter" +] \ No newline at end of file diff --git a/src/utils/__pycache__/__init__.cpython-310.pyc b/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48d18ecf5b182d563adca4cdab70d5302b2e2dcd Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/utils/__pycache__/config.cpython-310.pyc b/src/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bb2f6e7038c9fa6219a82966734c2ffcac8a55a Binary files /dev/null and b/src/utils/__pycache__/config.cpython-310.pyc differ diff --git a/src/utils/__pycache__/errors.cpython-310.pyc b/src/utils/__pycache__/errors.cpython-310.pyc new file mode 100644 
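# Illustrative sketch (not part of this changeset): the click-then-chain
# pattern used above for the previous/next buttons — the first step moves the
# read position, and the chained second step re-reads the file from there.
# Component names and the read function are hypothetical; Button.click(...).then(...)
# is standard Gradio.
import gradio as gr

def build_pager(read_fn, step: int = 65536):
    with gr.Blocks() as pager:
        position = gr.Number(value=0, label="Position (bytes)")
        content = gr.Textbox(label="Content", lines=10)
        next_btn = gr.Button("Next")
        next_btn.click(
            lambda pos: pos + step,   # step 1: advance the position state
            inputs=[position], outputs=[position],
        ).then(
            read_fn,                  # step 2: read from the new position
            inputs=[position], outputs=[content],
        )
    return pager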
index 0000000000000000000000000000000000000000..2bd9e59c2f05adb05b0bfe298ff2b9b30b549674 Binary files /dev/null and b/src/utils/__pycache__/errors.cpython-310.pyc differ diff --git a/src/utils/__pycache__/formatters.cpython-310.pyc b/src/utils/__pycache__/formatters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..258c40acfdbd1836c63b5beca3358e225e7b937b Binary files /dev/null and b/src/utils/__pycache__/formatters.cpython-310.pyc differ diff --git a/src/utils/__pycache__/storage_config.cpython-310.pyc b/src/utils/__pycache__/storage_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4397c4ae98a267fac269f3cd8dd65f49b153505 Binary files /dev/null and b/src/utils/__pycache__/storage_config.cpython-310.pyc differ diff --git a/src/utils/config.py b/src/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e157cb6a0cbf995e19b313158f14bf499c9f6c78 --- /dev/null +++ b/src/utils/config.py @@ -0,0 +1,107 @@ +""" +Configuration management for audio processing +""" + +import os +from dataclasses import dataclass, field +from typing import Dict, Optional, List +import json + + +@dataclass +class ModelConfig: + """Model configuration""" + name: str + params: str + description: str = "" + + +@dataclass +class AudioProcessingConfig: + """Centralized configuration for audio processing""" + + # Model configurations + whisper_models: Dict[str, ModelConfig] = field(default_factory=lambda: { + "tiny": ModelConfig("tiny", "39M", "Fastest, lowest accuracy"), + "base": ModelConfig("base", "74M", "Fast, low accuracy"), + "small": ModelConfig("small", "244M", "Medium speed, medium accuracy"), + "medium": ModelConfig("medium", "769M", "Slow, high accuracy"), + "large": ModelConfig("large", "1550M", "Slowest, highest accuracy"), + "turbo": ModelConfig("turbo", "809M", "Balanced speed and accuracy") + }) + + # Default settings + default_model: str = "turbo" + default_language: Optional[str] = None + min_segment_length: float = 30.0 + min_silence_length: float = 1.0 + + # Processing settings + max_parallel_segments: int = 10 + timeout_seconds: int = 600 + + # File paths + cache_dir: str = "./cache" + model_dir: str = "./models" + + # Modal settings + modal_app_name: str = "podcast-transcription" + modal_gpu_type: str = "A10G" + modal_memory: int = 10240 + modal_cpu: int = 4 + + # Speaker diarization + hf_token_env_var: str = "HF_TOKEN" + speaker_embedding_model: str = "pyannote/embedding" + speaker_diarization_model: str = "pyannote/speaker-diarization-3.1" + + # Output formats + supported_output_formats: List[str] = field(default_factory=lambda: ["txt", "srt", "json"]) + + @classmethod + def from_file(cls, config_path: str) -> "AudioProcessingConfig": + """Load configuration from JSON file""" + if not os.path.exists(config_path): + return cls() + + with open(config_path, 'r', encoding='utf-8') as f: + config_dict = json.load(f) + + # Convert model configs + if 'whisper_models' in config_dict: + models = {} + for name, model_data in config_dict['whisper_models'].items(): + models[name] = ModelConfig(**model_data) + config_dict['whisper_models'] = models + + return cls(**config_dict) + + def to_file(self, config_path: str): + """Save configuration to JSON file""" + config_dict = {} + for key, value in self.__dict__.items(): + if key == 'whisper_models': + config_dict[key] = { + name: model.__dict__ for name, model in value.items() + } + else: + config_dict[key] = value + + with open(config_path, 'w', 
encoding='utf-8') as f: + json.dump(config_dict, f, indent=2, ensure_ascii=False) + + def get_model_config(self, model_name: str) -> ModelConfig: + """Get model configuration by name""" + if model_name not in self.whisper_models: + raise ValueError(f"Unsupported model: {model_name}") + return self.whisper_models[model_name] + + @property + def is_speaker_diarization_available(self) -> bool: + """Check if speaker diarization is available""" + return os.environ.get(self.hf_token_env_var) is not None + + @property + def hf_token(self) -> Optional[str]: + """Get Hugging Face token""" + return os.environ.get(self.hf_token_env_var) \ No newline at end of file diff --git a/src/utils/errors.py b/src/utils/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..59ecd873f364ceb396d387e20eed1fa35e2d9af0 --- /dev/null +++ b/src/utils/errors.py @@ -0,0 +1,95 @@ +""" +Custom error classes for audio processing +""" + + +class AudioProcessingError(Exception): + """Base exception for audio processing errors""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code or "AUDIO_PROCESSING_ERROR" + self.details = details or {} + + def to_dict(self) -> dict: + """Convert error to dictionary format""" + return { + "error": self.error_code, + "message": self.message, + "details": self.details + } + + +class TranscriptionError(AudioProcessingError): + """Exception for transcription-related errors""" + + def __init__(self, message: str, model: str = None, audio_file: str = None, **kwargs): + super().__init__(message, error_code="TRANSCRIPTION_ERROR", **kwargs) + if model: + self.details["model"] = model + if audio_file: + self.details["audio_file"] = audio_file + + +class SpeakerDetectionError(AudioProcessingError): + """Exception for speaker detection-related errors""" + + def __init__(self, message: str, audio_file: str = None, **kwargs): + super().__init__(message, error_code="SPEAKER_DETECTION_ERROR", **kwargs) + if audio_file: + self.details["audio_file"] = audio_file + + +class SpeakerDiarizationError(AudioProcessingError): + """Exception for speaker diarization-related errors""" + + def __init__(self, message: str, audio_file: str = None, **kwargs): + super().__init__(message, error_code="SPEAKER_DIARIZATION_ERROR", **kwargs) + if audio_file: + self.details["audio_file"] = audio_file + + +class AudioSplittingError(AudioProcessingError): + """Exception for audio splitting-related errors""" + + def __init__(self, message: str, audio_file: str = None, **kwargs): + super().__init__(message, error_code="AUDIO_SPLITTING_ERROR", **kwargs) + if audio_file: + self.details["audio_file"] = audio_file + + +class FileProcessingError(AudioProcessingError): + """Exception for file processing-related errors""" + + def __init__(self, message: str, file_path: str = None, **kwargs): + super().__init__(message, error_code="FILE_PROCESSING_ERROR", **kwargs) + if file_path: + self.details["file_path"] = file_path + + +class ModelLoadError(AudioProcessingError): + """Exception for model loading errors""" + + def __init__(self, message: str, model_name: str = None, **kwargs): + super().__init__(message, error_code="MODEL_LOAD_ERROR", **kwargs) + if model_name: + self.details["model_name"] = model_name + + +class ConfigurationError(AudioProcessingError): + """Exception for configuration-related errors""" + + def __init__(self, message: str, config_key: str = None, **kwargs): + 
super().__init__(message, error_code="CONFIGURATION_ERROR", **kwargs) + if config_key: + self.details["config_key"] = config_key + + +class DeploymentError(AudioProcessingError): + """Exception for deployment-related errors""" + + def __init__(self, message: str, service: str = None, **kwargs): + super().__init__(message, error_code="DEPLOYMENT_ERROR", **kwargs) + if service: + self.details["service"] = service \ No newline at end of file diff --git a/src/utils/formatters.py b/src/utils/formatters.py new file mode 100644 index 0000000000000000000000000000000000000000..6942af7abd94c157700a3897e9abe5116b3dde82 --- /dev/null +++ b/src/utils/formatters.py @@ -0,0 +1,149 @@ +""" +Formatting utilities for audio processing outputs +""" + +from typing import List, Dict, Any +from ..interfaces.transcriber import TranscriptionSegment + + +class TimestampFormatter: + """Utility for formatting timestamps""" + + @staticmethod + def format_srt_timestamp(seconds: float) -> str: + """Format timestamp for SRT format""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + millis = int((seconds % 1) * 1000) + return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" + + @staticmethod + def format_readable_timestamp(seconds: float) -> str: + """Format timestamp for human reading""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + if hours > 0: + return f"{hours:02d}:{minutes:02d}:{secs:02d}" + else: + return f"{minutes:02d}:{secs:02d}" + + +class SRTFormatter: + """Utility for SRT subtitle formatting""" + + @staticmethod + def format_segments( + segments: List[TranscriptionSegment], + include_speakers: bool = False + ) -> str: + """Format transcription segments as SRT""" + srt_content = "" + srt_index = 1 + + for segment in segments: + if not segment.text.strip(): + continue + + start_time = TimestampFormatter.format_srt_timestamp(segment.start) + end_time = TimestampFormatter.format_srt_timestamp(segment.end) + + # Format text with optional speaker information + if include_speakers and segment.speaker: + formatted_text = f"{segment.speaker}: {segment.text.strip()}" + else: + formatted_text = segment.text.strip() + + srt_content += f"{srt_index}\n{start_time} --> {end_time}\n{formatted_text}\n\n" + srt_index += 1 + + return srt_content + + +class TextFormatter: + """Utility for plain text formatting""" + + @staticmethod + def format_segments( + segments: List[TranscriptionSegment], + include_timestamps: bool = False, + include_speakers: bool = False + ) -> str: + """Format transcription segments as plain text""" + lines = [] + + for segment in segments: + if not segment.text.strip(): + continue + + parts = [] + + if include_timestamps: + timestamp = TimestampFormatter.format_readable_timestamp(segment.start) + parts.append(f"[{timestamp}]") + + if include_speakers and segment.speaker: + parts.append(f"{segment.speaker}:") + + parts.append(segment.text.strip()) + + lines.append(" ".join(parts)) + + return "\n".join(lines) + + @staticmethod + def format_continuous_text(segments: List[TranscriptionSegment]) -> str: + """Format segments as continuous text without breaks""" + texts = [segment.text.strip() for segment in segments if segment.text.strip()] + return " ".join(texts) + + +# ==================== Legacy compatibility functions ==================== + +def generate_srt_format(segments: List[Dict[str, Any]], include_speakers: bool = False) -> str: + """ + Legacy function for generating SRT format from 
segment dictionaries + + Args: + segments: List of segment dictionaries with 'start', 'end', 'text', and optional 'speaker' + include_speakers: Whether to include speaker information + + Returns: + SRT formatted string + """ + srt_content = "" + srt_index = 1 + + for segment in segments: + text = segment.get("text", "").strip() + if not text: + continue + + start_time = format_timestamp(segment.get("start", 0)) + end_time = format_timestamp(segment.get("end", 0)) + + # Format text with optional speaker information + if include_speakers and segment.get("speaker"): + formatted_text = f"{segment['speaker']}: {text}" + else: + formatted_text = text + + srt_content += f"{srt_index}\n{start_time} --> {end_time}\n{formatted_text}\n\n" + srt_index += 1 + + return srt_content + + +def format_timestamp(seconds: float) -> str: + """ + Legacy function for formatting timestamps + + Args: + seconds: Time in seconds + + Returns: + SRT formatted timestamp + """ + return TimestampFormatter.format_srt_timestamp(seconds) \ No newline at end of file diff --git a/src/utils/storage_config.py b/src/utils/storage_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ad05bca80592f7133e98ae7128d9dc3206dc4ac6 --- /dev/null +++ b/src/utils/storage_config.py @@ -0,0 +1,253 @@ +""" +Storage Configuration Management +Centralizes all storage path configurations for downloads and transcripts +""" + +import os +from pathlib import Path +from typing import Optional +from dotenv import load_dotenv + + +class StorageConfig: + """Centralized storage configuration for podcast processing""" + + def __init__(self, config_file: str = "config.env"): + """ + Initialize storage configuration + + Args: + config_file: Path to configuration file + """ + self.config_file = config_file + self._load_config() + self._ensure_directories() + + def _load_config(self): + """Load configuration from environment file or Modal environment""" + # Check if we're running in Modal environment + is_modal_env = ( + os.getenv("MODAL_TASK_ID") or + os.getenv("MODAL_IS_INSIDE_CONTAINER") or + os.getenv("DEPLOYMENT_MODE") == "modal" + ) + + if is_modal_env: + print("🔧 Using Modal environment configuration") + # Use Modal defaults - don't load config files + self.downloads_dir = Path("/root/downloads").resolve() + self.transcripts_dir = Path("/root/transcripts").resolve() + self.cache_dir = Path("/root/cache").resolve() + else: + print("🔧 Using local environment configuration") + # Load from config file if it exists + if os.path.exists(self.config_file): + load_dotenv(self.config_file, override=False) + print(f"📄 Loaded config from {self.config_file}") + + # Load from .env if it exists + if os.path.exists(".env"): + load_dotenv(".env", override=False) + print("📄 Loaded config from .env") + + # Set defaults for local environment + self.downloads_dir = Path(os.getenv("DOWNLOADS_DIR", "./downloads")).resolve() + self.transcripts_dir = Path(os.getenv("TRANSCRIPTS_DIR", "./transcripts")).resolve() + self.cache_dir = Path(os.getenv("CACHE_DIR", "./cache")).resolve() + + # Common settings (apply to both environments) + self.download_quality = os.getenv("DOWNLOAD_QUALITY", "highest") + self.convert_to_mp3 = os.getenv("CONVERT_TO_MP3", "true").lower() == "true" + self.default_model_size = os.getenv("DEFAULT_MODEL_SIZE", "turbo") + self.default_output_format = os.getenv("DEFAULT_OUTPUT_FORMAT", "srt") + self.enable_speaker_diarization = os.getenv("ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true" + self.use_parallel_processing = 
os.getenv("USE_PARALLEL_PROCESSING", "true").lower() == "true" + self.chunk_duration = int(os.getenv("CHUNK_DURATION", "60")) + + # Store environment type for reference + self.is_modal_env = is_modal_env + + def _ensure_directories(self): + """Ensure all configured directories exist""" + for directory in [self.downloads_dir, self.transcripts_dir, self.cache_dir]: + try: + directory.mkdir(parents=True, exist_ok=True) + if self.is_modal_env: + print(f"📁 Modal storage directory ready: {directory}") + else: + print(f"📁 Local storage directory ready: {directory}") + except Exception as e: + print(f"⚠️ Failed to create directory {directory}: {e}") + # In Modal environment, some directories might be managed differently + if not self.is_modal_env: + raise + + def get_download_path(self, filename: str) -> Path: + """ + Get full path for downloaded audio file + + Args: + filename: Audio filename + + Returns: + Full path for downloaded file + """ + return self.downloads_dir / filename + + def get_transcript_path(self, audio_filename: str, output_format: str = None) -> Path: + """ + Get full path for transcript file + + Args: + audio_filename: Original audio filename + output_format: Output format (txt, srt, json) + + Returns: + Full path for transcript file + """ + if output_format is None: + output_format = self.default_output_format + + # Remove audio extension and add transcript extension + base_name = Path(audio_filename).stem + transcript_filename = f"{base_name}.{output_format}" + + return self.transcripts_dir / transcript_filename + + def get_cache_path(self, filename: str) -> Path: + """ + Get full path for cache file + + Args: + filename: Cache filename + + Returns: + Full path for cache file + """ + return self.cache_dir / filename + + def get_audio_files(self) -> list[Path]: + """ + Get list of all audio files in downloads directory + + Returns: + List of audio file paths + """ + audio_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg'} + audio_files = [] + + for file_path in self.downloads_dir.iterdir(): + if file_path.is_file() and file_path.suffix.lower() in audio_extensions: + audio_files.append(file_path) + + return sorted(audio_files) + + def get_transcript_files(self, audio_filename: str = None) -> dict[str, Path]: + """ + Get paths for all transcript formats for a given audio file + + Args: + audio_filename: Audio filename (optional) + + Returns: + Dictionary mapping format to file path + """ + if audio_filename: + base_name = Path(audio_filename).stem + return { + 'txt': self.get_transcript_path(audio_filename, 'txt'), + 'srt': self.get_transcript_path(audio_filename, 'srt'), + 'json': self.get_transcript_path(audio_filename, 'json') + } + else: + # Return all transcript files + transcript_files = {'txt': [], 'srt': [], 'json': []} + for file_path in self.transcripts_dir.iterdir(): + if file_path.is_file(): + ext = file_path.suffix[1:] # Remove the dot + if ext in transcript_files: + transcript_files[ext].append(file_path) + return transcript_files + + def cleanup_temp_files(self, pattern: str = "temp_*"): + """ + Clean up temporary files in cache directory + + Args: + pattern: File pattern to match for cleanup + """ + import glob + temp_files = glob.glob(str(self.cache_dir / pattern)) + for temp_file in temp_files: + try: + os.remove(temp_file) + print(f"🗑️ Cleaned up temp file: {temp_file}") + except Exception as e: + print(f"⚠️ Failed to cleanup {temp_file}: {e}") + + def get_storage_info(self) -> dict: + """ + Get storage configuration information + + Returns: + 
Dictionary with storage information + """ + audio_files = self.get_audio_files() + transcript_files = self.get_transcript_files() + + def get_dir_size(directory: Path) -> int: + """Get total size of directory in bytes""" + total_size = 0 + try: + for file_path in directory.rglob('*'): + if file_path.is_file(): + total_size += file_path.stat().st_size + except Exception: + pass + return total_size + + return { + "environment": "modal" if self.is_modal_env else "local", + "downloads_dir": str(self.downloads_dir), + "transcripts_dir": str(self.transcripts_dir), + "cache_dir": str(self.cache_dir), + "audio_files_count": len(audio_files), + "transcript_txt_count": len(transcript_files.get('txt', [])), + "transcript_srt_count": len(transcript_files.get('srt', [])), + "transcript_json_count": len(transcript_files.get('json', [])), + "downloads_size_mb": round(get_dir_size(self.downloads_dir) / (1024 * 1024), 2), + "transcripts_size_mb": round(get_dir_size(self.transcripts_dir) / (1024 * 1024), 2), + "cache_size_mb": round(get_dir_size(self.cache_dir) / (1024 * 1024), 2), + } + + +# Global storage configuration instance +_storage_config: Optional[StorageConfig] = None + + +def get_storage_config() -> StorageConfig: + """ + Get global storage configuration instance + + Returns: + StorageConfig instance + """ + global _storage_config + if _storage_config is None: + _storage_config = StorageConfig() + return _storage_config + + +def get_downloads_dir() -> Path: + """Get downloads directory path""" + return get_storage_config().downloads_dir + + +def get_transcripts_dir() -> Path: + """Get transcripts directory path""" + return get_storage_config().transcripts_dir + + +def get_cache_dir() -> Path: + """Get cache directory path""" + return get_storage_config().cache_dir \ No newline at end of file
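
For reference, the chunked read/navigate pattern used by the file-editor event wiring earlier in this diff (a cursor held in `gr.State`, a button `.click()` that moves it, and a chained `.then()` that re-reads the file from the new position) can be sketched in isolation. The handlers, chunk size, and component names below are hypothetical stand-ins for illustration only, not the project's actual `load_file_info` / `read_file_content` functions.

```python
# Standalone sketch of the .click(...).then(...) chunk-navigation pattern.
import gradio as gr

CHUNK_SIZE = 1024  # bytes per chunk (assumed value for this sketch)


def read_chunk(path: str, position: int):
    """Read one chunk of bytes starting at `position` and report the new cursor."""
    try:
        with open(path, "rb") as f:
            f.seek(position)
            data = f.read(CHUNK_SIZE)
        new_pos = position + len(data)
        text = data.decode("utf-8", errors="replace")
        return text, new_pos, f"Read {len(data)} bytes, cursor now at {new_pos}"
    except OSError as e:
        return "", position, f"Failed to read file: {e}"


def step_back(position: int):
    """Rewind the cursor by two chunks, mirroring the 'prev' navigation branch."""
    return max(0, position - CHUNK_SIZE * 2)


with gr.Blocks() as demo:
    path_box = gr.Textbox(label="File path")
    position = gr.State(0)
    content = gr.Textbox(label="Content", lines=10)
    status = gr.Textbox(label="Status")

    load_btn = gr.Button("Load / Next chunk")
    prev_btn = gr.Button("Previous chunk")

    # Each read writes the new cursor back into the State component, so the
    # next click naturally continues from where the previous read stopped.
    load_btn.click(read_chunk, inputs=[path_box, position],
                   outputs=[content, position, status])

    # "Previous" first rewinds the cursor, then chains a re-read via .then(),
    # the same two-step pattern the event bindings in the diff rely on.
    prev_btn.click(step_back, inputs=[position], outputs=[position]).then(
        read_chunk, inputs=[path_box, position],
        outputs=[content, position, status])

if __name__ == "__main__":
    demo.launch()
```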
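
A minimal sketch of how `src/utils/config.py` can be used, based only on the methods shown above. The JSON path `my_audio_config.json` is invented for the example.

```python
import os
from src.utils.config import AudioProcessingConfig

config = AudioProcessingConfig()                  # built-in defaults ("turbo", ./cache, ...)
turbo = config.get_model_config("turbo")
print(turbo.name, turbo.params, turbo.description)

config.max_parallel_segments = 4                  # tweak a processing setting
config.to_file("my_audio_config.json")            # persist as JSON

reloaded = AudioProcessingConfig.from_file("my_audio_config.json")
print(reloaded.max_parallel_segments)             # -> 4
print(reloaded.get_model_config("large").params)  # ModelConfig entries survive the round trip

# Speaker diarization is gated on the HF_TOKEN environment variable
print(config.is_speaker_diarization_available, bool(os.environ.get("HF_TOKEN")))
```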
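
The error classes in `src/utils/errors.py` all funnel extra context into `.details` and expose `.to_dict()` for structured responses. A short sketch with an invented file name and a simulated failure:

```python
from src.utils.errors import AudioProcessingError, TranscriptionError


def transcribe_or_fail(audio_file: str) -> None:
    # Simulate a failure inside a transcription backend
    raise TranscriptionError(
        "Whisper backend timed out",
        model="turbo",
        audio_file=audio_file,
    )


try:
    transcribe_or_fail("episode_042.mp3")
except AudioProcessingError as e:   # the base class catches every subtype
    payload = e.to_dict()
    print(payload["error"])         # -> TRANSCRIPTION_ERROR
    print(payload["message"])       # -> Whisper backend timed out
    print(payload["details"])       # -> {'model': 'turbo', 'audio_file': 'episode_042.mp3'}
```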
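
The formatters in `src/utils/formatters.py` accept either `TranscriptionSegment` objects (current API) or plain dictionaries (legacy API). A sketch with invented sample segments, assuming `TranscriptionSegment` takes the same `start`/`end`/`text`/`speaker` keyword arguments the local adapter passes it:

```python
from src.interfaces.transcriber import TranscriptionSegment
from src.utils.formatters import SRTFormatter, TextFormatter, generate_srt_format

segments = [
    TranscriptionSegment(start=0.0, end=4.2, text="Welcome back to the show.", speaker="SPEAKER_00"),
    TranscriptionSegment(start=4.2, end=9.8, text="Today we talk about transcription.", speaker="SPEAKER_01"),
]

# Numbered SRT blocks with HH:MM:SS,mmm timestamps and speaker labels
print(SRTFormatter.format_segments(segments, include_speakers=True))

# Plain text, one line per segment, with readable [MM:SS] timestamps
print(TextFormatter.format_segments(segments, include_timestamps=True, include_speakers=True))
print(TextFormatter.format_continuous_text(segments))

# The legacy helper accepts plain dictionaries instead of segment objects
legacy_segments = [{"start": 0.0, "end": 4.2, "text": "Welcome back to the show.", "speaker": "SPEAKER_00"}]
print(generate_srt_format(legacy_segments, include_speakers=True))
```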
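
Finally, a sketch of the storage helpers in `src/utils/storage_config.py`. Note that constructing the config has side effects: it loads `config.env` / `.env` (requires `python-dotenv`) and creates the downloads/transcripts/cache directories. The episode file name is invented.

```python
from src.utils.storage_config import get_storage_config, get_transcripts_dir

storage = get_storage_config()               # module-level singleton, loads env config once

audio_path = storage.get_download_path("episode_042.mp3")
srt_path = storage.get_transcript_path("episode_042.mp3", "srt")
print(audio_path)                            # <downloads_dir>/episode_042.mp3
print(srt_path)                              # <transcripts_dir>/episode_042.srt
print(get_transcripts_dir())                 # same directory, via the module helper

info = storage.get_storage_info()
print(info["environment"], info["audio_files_count"], info["downloads_size_mb"])
```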