flare / audio_routes.py
ciyidogan's picture
Update audio_routes.py
5d32118 verified
raw
history blame
8.42 kB
"""Audio API endpoints for Flare
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Provides text-to-speech (TTS) and speech-to-text (STT) endpoints.
"""
from fastapi import APIRouter, HTTPException, Response, Body
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
import sys
from logger import log_info, log_error, log_warning, log_debug
from tts_factory import TTSFactory
from tts_preprocessor import TTSPreprocessor
from config_provider import ConfigProvider
# Router holding every audio endpoint in this module (TTS and STT). All routes
# are registered via the @router decorators below and are grouped under the
# "audio" tag in FastAPI's generated OpenAPI docs. Any URL prefix is applied
# wherever this router is included, not here.
router = APIRouter(tags=["audio"])
# ===================== Models =====================
class TTSRequest(BaseModel):
    """Request body for ``POST /tts/generate``."""

    # Text to synthesize into speech.
    text: str
    # Voice identifier forwarded to the provider's synthesize() call; None
    # presumably lets the provider pick its default voice — TODO confirm.
    voice_id: Optional[str] = None
    # Language used by TTSPreprocessor and reported back in the X-TTS-Language
    # response header. Declared Optional, so clients may send an explicit null.
    language: Optional[str] = "tr-TR"
class STTRequest(BaseModel):
    """Request body for ``POST /stt/transcribe``."""

    # Base64 encoded audio clip to transcribe.
    audio_data: str  # Base64 encoded audio
    # Recognition language; when null/empty the STT provider's configured
    # "language" setting (or "tr-TR") is used instead.
    language: Optional[str] = "tr-TR"
    # Audio container/codec of audio_data; used to derive the STTConfig
    # encoding. Expected values: webm, wav, mp3.
    format: Optional[str] = "webm"  # webm, wav, mp3
# ===================== Helpers =====================
def log(message: str):
    """Write ``message`` to stdout behind a ``[HH:MM:SS.mmm]`` local-time stamp.

    Output is flushed right away so lines show up promptly even when stdout
    is buffered (e.g. when redirected to a file or a container log).
    """
    # %f yields microseconds (6 digits); trimming 3 leaves milliseconds.
    stamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
    print(f"[{stamp}] {message}", flush=True)
# ===================== TTS Endpoints =====================
@router.post("/tts/generate")
async def generate_tts(request: TTSRequest):
    """Generate TTS audio from text - public endpoint for chat.

    If ``TTSFactory.create_provider()`` returns nothing, TTS is treated as
    disabled. An empty ``audio/mpeg`` body is then returned with the header
    ``X-TTS-Status: disabled``.

    Otherwise, the text goes through ``TTSPreprocessor`` using the provider's
    preprocessing flags and is synthesized with ``request.voice_id``. The audio
    is returned as ``audio/mpeg`` with ``X-TTS-Provider`` and
    ``X-TTS-Language`` headers.

    Raises:
        HTTPException: 500 if provider creation, preprocessing or synthesis
            fails. An HTTPException raised along the way is propagated
            unchanged instead of being re-wrapped as a 500.
    """
    # TTSRequest.language is Optional, so a client may send an explicit null.
    # Fall back to the model default. Otherwise None reaches TTSPreprocessor
    # and the X-TTS-Language response header. Header values must be strings,
    # so building the Response would fail and a successful synthesis would be
    # reported as a 500.
    language = request.language or "tr-TR"
    try:
        # A falsy provider means TTS is switched off in configuration.
        tts_provider = TTSFactory.create_provider()
        if not tts_provider:
            log_info("πŸ“΅ TTS disabled - returning empty response")
            return Response(
                content=b"",
                media_type="audio/mpeg",
                headers={"X-TTS-Status": "disabled"},
            )

        provider_name = tts_provider.get_provider_name()
        log_info(f"🎀 TTS request: '{request.text[:50]}...' with provider: {provider_name}")

        # Provider-specific text normalization before synthesis.
        preprocessor = TTSPreprocessor(language=language)
        processed_text = preprocessor.preprocess(
            request.text,
            tts_provider.get_preprocessing_flags(),
        )
        log_debug(f"πŸ“ Preprocessed text: {processed_text[:100]}...")

        audio_data = await tts_provider.synthesize(
            text=processed_text,
            voice_id=request.voice_id,
        )
        log_info(f"βœ… TTS generated {len(audio_data)} bytes of audio")

        return Response(
            content=audio_data,
            media_type="audio/mpeg",
            headers={
                "Content-Disposition": 'inline; filename="tts_output.mp3"',
                "X-TTS-Provider": provider_name,
                "X-TTS-Language": language,
                "Cache-Control": "no-cache",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        log_error("❌ TTS generation error", e)
        raise HTTPException(
            status_code=500,
            detail=f"TTS generation failed: {str(e)}",
        ) from e
@router.get("/tts/voices")
async def get_tts_voices():
    """List the voices offered by the configured TTS provider (public endpoint).

    Response shape: ``{"voices": [{"id", "name"}, ...], "provider", "enabled"}``.

    If no provider is configured, the voice list is empty, ``provider`` is
    ``"none"`` and ``enabled`` is False. Any exception is logged rather than
    raised. The response then has ``provider`` set to ``"error"`` and the
    message under ``"error"``.
    """
    try:
        provider = TTSFactory.create_provider()
        if provider:
            # get_supported_voices() is consumed as a mapping of voice id -> display name.
            available = provider.get_supported_voices()
            return {
                "voices": [{"id": vid, "name": label} for vid, label in available.items()],
                "provider": provider.get_provider_name(),
                "enabled": True,
            }
        return {"voices": [], "provider": "none", "enabled": False}
    except Exception as exc:
        log_error("❌ Error getting TTS voices", exc)
        return {
            "voices": [],
            "provider": "error",
            "enabled": False,
            "error": str(exc),
        }
@router.get("/tts/status")
async def get_tts_status():
    """Summarize the configured TTS provider.

    ``enabled`` is False only when the provider name is ``"no_tts"``. The API
    key itself is never returned, only whether one is set.
    """
    tts_cfg = ConfigProvider.get().global_config.tts_provider
    provider_name = tts_cfg.name
    details = {
        "name": provider_name,
        "has_api_key": bool(tts_cfg.api_key),
        "endpoint": tts_cfg.endpoint,
    }
    return {
        "enabled": provider_name != "no_tts",
        "provider": provider_name,
        "provider_config": details,
    }
# ===================== STT Endpoints =====================
# Request ``format`` values whose STT encoding name is not simply the
# upper-cased format. "WEBM_OPUS" is already the fallback encoding used when no
# format is given, so the default "webm" format maps to it as well. Plain
# upper-casing used to yield "WEBM".
_STT_FORMAT_ENCODINGS = {"webm": "WEBM_OPUS"}


def _stt_encoding(audio_format: Optional[str]) -> str:
    """Return the STTConfig encoding name for a request ``format`` value.

    ``None`` or an empty value falls back to ``"WEBM_OPUS"``. Formats not
    listed in ``_STT_FORMAT_ENCODINGS`` are upper-cased, as before.
    """
    if not audio_format:
        return "WEBM_OPUS"
    # NOTE(review): "wav"/"mp3" still become "WAV"/"MP3"; confirm these are the
    # names the configured provider expects (Google STT uses LINEAR16 for WAV).
    return _STT_FORMAT_ENCODINGS.get(audio_format.lower(), audio_format.upper())


def _decode_audio_payload(audio_data: str) -> bytes:
    """Decode the base64 ``audio_data`` field of an STT request.

    Accepts bare base64 or a ``data:<mime>;base64,<payload>`` data URL (as
    produced by a browser's ``FileReader.readAsDataURL``). For a data URL, the
    prefix is stripped before decoding. Left in place, the lenient decoder
    would fold the prefix's letters into the audio bytes.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
    """
    import base64

    payload = audio_data.strip()
    # ':' is outside the base64 alphabet, so a bare base64 string can never
    # start with "data:" and this branch cannot misfire on it.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    return base64.b64decode(payload)


@router.post("/stt/transcribe")
async def transcribe_audio(request: STTRequest):
    """Transcribe a single base64-encoded audio clip to text.

    The clip is streamed to the configured STT provider as one utterance. The
    first final result is taken as the transcript.

    Returns:
        dict with these keys:
        - ``text``: the final transcript, or ``""`` if none arrived.
        - ``confidence``: the result's confidence, or ``0.0`` if no final result.
        - ``language``: the language actually used for recognition.
        - ``provider``: the provider name.

    Raises:
        HTTPException:
            - 503 if no STT provider is configured or it reports no realtime
              support.
            - 400 if ``audio_data`` is not valid base64 or decodes to nothing.
            - 500 for any other failure.
    """
    try:
        from stt_factory import STTFactory
        from stt_interface import STTConfig

        stt_provider = STTFactory.create_provider()
        if not stt_provider or not stt_provider.supports_realtime():
            log_warning("πŸ“΅ STT disabled or doesn't support transcription")
            raise HTTPException(
                status_code=503,
                detail="STT service not available",
            )

        cfg = ConfigProvider.get()
        # Guard against a missing settings mapping so the .get() calls below
        # cannot raise AttributeError.
        stt_settings = cfg.global_config.stt_provider.settings or {}

        # A malformed upload is a client error, not a server failure.
        try:
            audio_bytes = _decode_audio_payload(request.audio_data)
        except ValueError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid audio_data: {e}",
            ) from e
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="audio_data is empty")

        # Explicit request language wins; otherwise use the provider setting.
        language = request.language or stt_settings.get("language", "tr-TR")
        config = STTConfig(
            language=language,
            # NOTE(review): fixed at 16 kHz regardless of the uploaded format;
            # confirm this matches the audio clients actually send.
            sample_rate=16000,
            encoding=_stt_encoding(request.format),
            enable_punctuation=stt_settings.get("enable_punctuation", True),
            enable_word_timestamps=False,
            model=stt_settings.get("model", "latest_long"),
            use_enhanced=stt_settings.get("use_enhanced", True),
            single_utterance=True,
            interim_results=False,
        )

        await stt_provider.start_streaming(config)
        transcription = ""
        confidence = 0.0
        try:
            async for result in stt_provider.stream_audio(audio_bytes):
                if result.is_final:
                    transcription = result.text
                    confidence = result.confidence
                    break
        finally:
            # Always close the streaming session, even if streaming failed.
            await stt_provider.stop_streaming()

        log_info(f"βœ… STT transcription completed: '{transcription[:50]}...'")
        return {
            "text": transcription,
            "confidence": confidence,
            "language": language,
            "provider": stt_provider.get_provider_name(),
        }
    except HTTPException:
        raise
    except Exception as e:
        log_error("❌ STT transcription error", e)
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(e)}",
        ) from e
@router.get("/stt/languages")
async def get_stt_languages():
    """List the languages supported by the configured STT provider.

    Response shape: ``{"languages": [...], "provider", "enabled"}``.

    If no provider is configured, the list is empty and ``enabled`` is False.
    Any exception is logged rather than raised. The response then has
    ``provider`` set to ``"error"`` and the message under ``"error"``.
    """
    try:
        from stt_factory import STTFactory

        provider = STTFactory.create_provider()
        if provider:
            return {
                "languages": provider.get_supported_languages(),
                "provider": provider.get_provider_name(),
                "enabled": True,
            }
        return {"languages": [], "provider": "none", "enabled": False}
    except Exception as exc:
        log_error("❌ Error getting STT languages", exc)
        return {
            "languages": [],
            "provider": "error",
            "enabled": False,
            "error": str(exc),
        }
@router.get("/stt/status")
async def get_stt_status():
    """Summarize the configured STT provider.

    ``enabled`` is False only when the provider name is ``"no_stt"``. The API
    key itself is never returned, only whether one is set.
    """
    stt_cfg = ConfigProvider.get().global_config.stt_provider
    provider_name = stt_cfg.name
    details = {
        "name": provider_name,
        "has_api_key": bool(stt_cfg.api_key),
        "endpoint": stt_cfg.endpoint,
    }
    return {
        "enabled": provider_name != "no_stt",
        "provider": provider_name,
        "provider_config": details,
    }