Spaces:
Building
Building
"""Audio API endpoints for Flare | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
Provides text-to-speech (TTS) and speech-to-text (STT) endpoints. | |
""" | |
from fastapi import APIRouter, HTTPException, Response, Body | |
from pydantic import BaseModel | |
from typing import Optional | |
from datetime import datetime | |
import sys | |
from logger import log_info, log_error, log_warning, log_debug | |
from tts_factory import TTSFactory | |
from tts_preprocessor import TTSPreprocessor | |
from config_provider import ConfigProvider | |
# Router for the audio (TTS / STT) endpoints; every route registered on it
# carries the "audio" tag.
# NOTE(review): no route decorator is visible on the handler functions below
# in this view — confirm they are actually registered on this router.
router: APIRouter = APIRouter(tags=["audio"])
# ===================== Models =====================
class TTSRequest(BaseModel):
    """Request body for text-to-speech synthesis.

    Attributes:
        text: Text to be spoken.
        voice_id: Provider-specific voice identifier, forwarded to the
            provider as-is; ``None`` presumably selects the provider's
            default voice (TODO confirm).
        language: Language code used when preprocessing the text;
            defaults to ``"tr-TR"``. The type also admits an explicit
            ``None``.
    """

    text: str
    voice_id: Optional[str] = None  # provider-specific voice id, or None
    language: Optional[str] = "tr-TR"  # language code for preprocessing
class STTRequest(BaseModel):
    """Request body for speech-to-text transcription.

    Attributes:
        audio_data: The audio clip, base64-encoded.
        language: Language code for recognition; defaults to ``"tr-TR"``.
            The type also admits an explicit ``None``.
        format: Audio container format — ``"webm"``, ``"wav"`` or
            ``"mp3"``; defaults to ``"webm"``.
    """

    audio_data: str  # base64-encoded audio bytes
    language: Optional[str] = "tr-TR"  # recognition language code
    format: Optional[str] = "webm"  # one of: webm, wav, mp3
# ===================== Helpers ===================== | |
def log(message: str):
    """Print *message* to stdout behind a millisecond wall-clock stamp.

    Output looks like ``[14:03:07.123] message``. Stdout is flushed right
    away so the line shows up even when stdout is block-buffered.
    """
    # isoformat with timespec="milliseconds" truncates microseconds to three
    # digits, i.e. the same "HH:MM:SS.mmm" as strftime("%f") minus 3 chars.
    stamp = datetime.now().time().isoformat(timespec="milliseconds")
    print(f"[{stamp}] {message}", flush=True)
# ===================== TTS Endpoints ===================== | |
async def generate_tts(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return the audio bytes.

    Public endpoint used by the chat UI.

    Args:
        request: Text to speak, plus an optional voice id and language.

    Returns:
        A ``Response`` carrying the provider's audio as ``audio/mpeg`` with
        provider/language headers. When ``TTSFactory.create_provider()``
        yields no provider, the body is empty and the header
        ``X-TTS-Status: disabled`` is set instead of raising.

    Raises:
        HTTPException: 500 when preprocessing or synthesis fails;
            ``HTTPException``s raised inside the handler pass through as-is.
    """
    # NOTE(review): no route decorator is visible here — confirm this handler
    # is registered on `router`.
    try:
        tts_provider = TTSFactory.create_provider()
        if not tts_provider:
            # TTS switched off: answer with an empty audio body rather than an
            # error so the client can simply skip playback.
            log_info("π΅ TTS disabled - returning empty response")
            return Response(
                content=b"",
                media_type="audio/mpeg",
                headers={"X-TTS-Status": "disabled"},
            )

        provider_name = tts_provider.get_provider_name()
        # The request model allows an explicit null language. Fall back to the
        # model default so neither the preprocessor nor the response header
        # (header values must be str) ever receives None.
        language = request.language or "tr-TR"

        log_info(f"π€ TTS request: '{request.text[:50]}...' with provider: {provider_name}")

        # Normalize the text according to the flags this provider asks for.
        preprocessor = TTSPreprocessor(language=language)
        processed_text = preprocessor.preprocess(
            request.text,
            tts_provider.get_preprocessing_flags(),
        )
        log_debug(f"π Preprocessed text: {processed_text[:100]}...")

        audio_data = await tts_provider.synthesize(
            text=processed_text,
            voice_id=request.voice_id,
        )
        log_info(f"β TTS generated {len(audio_data)} bytes of audio")

        return Response(
            content=audio_data,
            media_type="audio/mpeg",
            headers={
                "Content-Disposition": 'inline; filename="tts_output.mp3"',
                "X-TTS-Provider": provider_name,
                "X-TTS-Language": language,
                "Cache-Control": "no-cache",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        log_error("β TTS generation error", e)
        raise HTTPException(
            status_code=500,
            detail=f"TTS generation failed: {str(e)}",
        ) from e
async def get_tts_voices():
    """List the voices offered by the configured TTS provider.

    Public endpoint. Failures are logged and reported in the payload
    instead of being raised.

    Returns:
        ``{"voices": [{"id": ..., "name": ...}, ...], "provider": <name>,
        "enabled": True}`` on success. Returns ``provider="none"`` with
        ``enabled=False`` when no provider is configured. Returns
        ``provider="error"``, ``enabled=False`` and an ``"error"`` message
        when the lookup fails.
    """
    try:
        provider = TTSFactory.create_provider()
        if not provider:
            return {"voices": [], "provider": "none", "enabled": False}

        # The provider reports voices as a {voice_id: voice_name} mapping;
        # reshape it into a list of {"id", "name"} objects for the client.
        catalog = provider.get_supported_voices()
        entries = [dict(id=vid, name=vname) for vid, vname in catalog.items()]

        return {
            "voices": entries,
            "provider": provider.get_provider_name(),
            "enabled": True,
        }
    except Exception as exc:
        log_error("β Error getting TTS voices", exc)
        return {
            "voices": [],
            "provider": "error",
            "enabled": False,
            "error": str(exc),
        }
async def get_tts_status():
    """Describe the configured TTS provider using configuration only.

    The provider itself is not contacted.

    Returns:
        ``enabled`` (False only when the provider name is ``"no_tts"``), the
        provider name, whether an API key is set (the key itself is never
        returned) and the configured endpoint.
    """
    tts_settings = ConfigProvider.get().global_config.tts_provider
    provider_name = tts_settings.name
    provider_config = {
        "name": provider_name,
        "has_api_key": bool(tts_settings.api_key),
        "endpoint": tts_settings.endpoint,
    }
    return {
        "enabled": provider_name != "no_tts",
        "provider": provider_name,
        "provider_config": provider_config,
    }
# ===================== STT Endpoints ===================== | |
async def transcribe_audio(request: STTRequest):
    """Transcribe a base64-encoded audio clip to text.

    The clip is pushed through the configured STT provider's streaming API
    as a single utterance, and the first final result is returned.

    Args:
        request: Base64 audio (an optional ``data:<mime>;base64,`` prefix is
            tolerated), its container format, and an optional language.

    Returns:
        ``{"text", "confidence", "language", "provider"}``. ``text`` is ``""``
        and ``confidence`` is ``0.0`` when no final result arrives.
        ``language`` is the language actually sent to the provider.

    Raises:
        HTTPException: 400 if ``audio_data`` is not decodable base64; 503 if
            no STT provider is available or it lacks realtime support; 500
            for any other failure.
    """
    # NOTE(review): no route decorator is visible here — confirm this handler
    # is registered on `router`.
    try:
        import base64

        from stt_factory import STTFactory
        from stt_interface import STTConfig

        stt_provider = STTFactory.create_provider()
        if not stt_provider or not stt_provider.supports_realtime():
            log_warning("π΅ STT disabled or doesn't support transcription")
            raise HTTPException(
                status_code=503,
                detail="STT service not available"
            )

        # Missing provider settings mean "use the defaults below" rather than
        # an AttributeError on `.get`.
        cfg = ConfigProvider.get()
        stt_settings = cfg.global_config.stt_provider.settings or {}

        # Browsers commonly send data URLs ("data:audio/webm;base64,...").
        # The lenient decoder would silently fold the prefix's letters into
        # the audio, so strip everything up to the first comma first.
        payload = request.audio_data
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        try:
            audio_bytes = base64.b64decode(payload)
        except ValueError as e:  # binascii.Error, or non-ASCII input
            raise HTTPException(
                status_code=400,
                detail=f"Invalid base64 audio data: {e}"
            ) from e

        # A missing format falls back to "WEBM_OPUS", so the model's default
        # format "webm" must map to the same name (plain upper-casing gave
        # "WEBM"). Other formats are upper-cased as before.
        # NOTE(review): confirm the encoding names the provider expects for
        # "wav"/"mp3" — some APIs want e.g. "LINEAR16" for WAV.
        fmt = (request.format or "").lower()
        encoding = {"webm": "WEBM_OPUS"}.get(fmt, fmt.upper()) if fmt else "WEBM_OPUS"

        language = request.language or stt_settings.get("language", "tr-TR")
        config = STTConfig(
            language=language,
            # NOTE(review): fixed rate regardless of format — confirm.
            sample_rate=16000,
            encoding=encoding,
            enable_punctuation=stt_settings.get("enable_punctuation", True),
            enable_word_timestamps=False,
            model=stt_settings.get("model", "latest_long"),
            use_enhanced=stt_settings.get("use_enhanced", True),
            single_utterance=True,
            interim_results=False,
        )

        await stt_provider.start_streaming(config)

        transcription = ""
        confidence = 0.0
        try:
            async for result in stt_provider.stream_audio(audio_bytes):
                if result.is_final:
                    transcription = result.text or ""
                    confidence = result.confidence
                    break
        finally:
            # Close the session even if streaming fails midway.
            await stt_provider.stop_streaming()

        log_info(f"β STT transcription completed: '{transcription[:50]}...'")
        return {
            "text": transcription,
            "confidence": confidence,
            "language": language,
            "provider": stt_provider.get_provider_name(),
        }
    except HTTPException:
        raise
    except Exception as e:
        log_error("β STT transcription error", e)
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(e)}"
        ) from e
async def get_stt_languages():
    """List the languages supported by the configured STT provider.

    Failures are logged and reported in the payload instead of being raised.

    Returns:
        ``{"languages": [...], "provider": <name>, "enabled": True}`` on
        success. Returns ``provider="none"`` with ``enabled=False`` when no
        provider is configured. Returns ``provider="error"``,
        ``enabled=False`` and an ``"error"`` message when the lookup fails.
    """
    try:
        # Imported at call time rather than at module load.
        from stt_factory import STTFactory

        provider = STTFactory.create_provider()
        if provider:
            return {
                "languages": provider.get_supported_languages(),
                "provider": provider.get_provider_name(),
                "enabled": True,
            }
        return {"languages": [], "provider": "none", "enabled": False}
    except Exception as exc:
        log_error("β Error getting STT languages", exc)
        return {
            "languages": [],
            "provider": "error",
            "enabled": False,
            "error": str(exc),
        }
async def get_stt_status():
    """Describe the configured STT provider using configuration only.

    The provider itself is not contacted.

    Returns:
        ``enabled`` (False only when the provider name is ``"no_stt"``), the
        provider name, whether an API key is set (the key itself is never
        returned) and the configured endpoint.
    """
    stt_settings = ConfigProvider.get().global_config.stt_provider
    name = stt_settings.name
    return {
        "enabled": name != "no_stt",
        "provider": name,
        "provider_config": {
            "name": name,
            "has_api_key": bool(stt_settings.api_key),
            "endpoint": stt_settings.endpoint,
        },
    }