"""Seed-VC voice-conversion HTTP API backed by Azure Blob Storage.

Exposes three API-key-protected endpoints:

* ``POST /convert`` - download a source clip from blob storage, convert it to
  one of the reference voices, upload the result and return a time-limited
  SAS URL for it.
* ``GET /voices``   - list the supported target voices.
* ``GET /health``   - report whether the model has been loaded.
"""

import logging
import os
import secrets
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from tempfile import NamedTemporaryFile, gettempdir
from typing import Optional

import torchaudio
from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas
from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from inference import load_models, process_voice_conversion

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
models = None
API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

_BEARER_PREFIX = "Bearer "

# How long a generated download link stays valid.
SAS_URL_TTL = timedelta(hours=1)


async def verify_api_key(authorization: Optional[str] = Header(None)):
    """Validate the ``Authorization`` header against the ``API_KEY`` env value.

    The header may carry the key either bare or as ``Bearer <key>``.

    Returns:
        The presented token when it matches the configured key.

    Raises:
        HTTPException: 401 when the header is missing or the key does not match
            (which is always the case when ``API_KEY`` is not configured).
    """
    if not authorization:
        logger.warning("No API key provided")
        raise HTTPException(status_code=401, detail="API key is missing")

    # Strip only a leading "Bearer "; str.replace would also mangle a key that
    # happens to contain that text somewhere else.
    if authorization.startswith(_BEARER_PREFIX):
        token = authorization[len(_BEARER_PREFIX):]
    else:
        token = authorization

    if not API_KEY:
        logger.error("API_KEY is not configured; rejecting request")
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Constant-time comparison so response timing does not leak key contents.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        logger.warning("Invalid API key provided")
        raise HTTPException(status_code=401, detail="Invalid API key")

    return token


def _get_storage_account_key() -> str:
    """Return the Azure storage account key from ``AZURE_STORAGE_KEY``.

    The key must come from the environment; it is deliberately not given a
    default so that no credential lives in source control.

    Raises:
        RuntimeError: when the variable is unset or empty.
    """
    account_key = os.getenv("AZURE_STORAGE_KEY")
    if not account_key:
        raise RuntimeError(
            "AZURE_STORAGE_KEY environment variable must be set"
        )
    return account_key


def get_azure_blob_client():
    """Build a ``BlobServiceClient`` from environment configuration.

    Reads ``AZURE_BLOB_ENDPOINT`` (with a default endpoint) and the required
    ``AZURE_STORAGE_KEY``.

    Raises:
        RuntimeError: when ``AZURE_STORAGE_KEY`` is not set.
    """
    blob_endpoint = os.getenv(
        "AZURE_BLOB_ENDPOINT", "https://getpoints.blob.core.windows.net/"
    )
    return BlobServiceClient(
        account_url=blob_endpoint,
        credential=_get_storage_account_key(),
    )


blob_client = get_azure_blob_client()
AZURE_CONTAINER_NAME = os.getenv("AZURE_CONTAINER_NAME", "seedvc-outputs")


async def ensure_container_exists():
    """Ensure the Azure container exists, create if it doesn't."""
    try:
        container_client = blob_client.get_container_client(AZURE_CONTAINER_NAME)
        container_client.get_container_properties()
        logger.info("Container '%s' already exists", AZURE_CONTAINER_NAME)
    except Exception:
        # NOTE(review): any failure of the probe (not only "not found") falls
        # through to a create attempt; a persistent error then surfaces from
        # create_container below.
        try:
            blob_client.create_container(AZURE_CONTAINER_NAME)
            logger.info("Created container '%s'", AZURE_CONTAINER_NAME)
        except Exception as e:
            logger.error(
                "Failed to create container '%s': %s", AZURE_CONTAINER_NAME, e
            )
            raise


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare the container and load the model once."""
    global models
    logger.info("Loading Seed-VC model...")
    try:
        await ensure_container_exists()
        models = load_models()
        logger.info("Seed-VC model loaded successfully")
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        raise
    yield
    logger.info("Shutting down Seed-VC API")


app = FastAPI(title="Seed-VC API", lifespan=lifespan)

# Supported target voices mapped to their reference audio files.
TARGET_VOICES = {
    "male": "examples/reference/s1p2.wav",
    "female": "examples/reference/s1p1.wav",
    "trump": "examples/reference/trump_0.wav",
}


class VoiceConversionRequest(BaseModel):
    """Request body for ``POST /convert``."""

    # Name of the source audio blob inside AZURE_CONTAINER_NAME.
    source_audio_key: str
    # One of the keys of TARGET_VOICES.
    target_voice: str


def _remove_quietly(path: str) -> None:
    """Delete *path* as best-effort cleanup; never raises for OS errors."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning("Could not remove temporary file %s: %s", path, exc)


def download_blob_to_temp(blob_name):
    """Download *blob_name* from the configured container to a temp ``.wav``.

    Returns:
        Path of the temporary file. The caller owns it and must delete it.

    Raises:
        Exception: whatever the Azure SDK raises when the download fails; the
            partially written temp file is removed before re-raising.
    """
    source_blob = blob_client.get_blob_client(
        container=AZURE_CONTAINER_NAME, blob=blob_name
    )
    temp_file = NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        # Write through the handle we already hold and close it on exit,
        # instead of leaking it and reopening the file by name.
        with temp_file:
            temp_file.write(source_blob.download_blob().readall())
    except Exception:
        _remove_quietly(temp_file.name)
        raise
    return temp_file.name


@app.post("/convert", dependencies=[Depends(verify_api_key)])
async def generate_speech(request: VoiceConversionRequest, background_tasks: BackgroundTasks):
    """Convert a stored source clip to a target voice and return a download URL.

    Args:
        request: source blob name and target voice key.
        background_tasks: retained for interface compatibility; temporary
            files are now removed before the response is returned.

    Returns:
        ``{"audio_url": <SAS URL valid for SAS_URL_TTL>, "blob_name": <blob>}``.

    Raises:
        HTTPException: 500 if the model is not loaded or conversion/upload
            fails, 400 for an unknown target voice, 404 if the source blob
            cannot be downloaded.
    """
    if not models:
        raise HTTPException(status_code=500, detail="Model not loaded")

    if request.target_voice not in TARGET_VOICES:
        raise HTTPException(
            status_code=400,
            detail=f"Target voice not supported. Choose from: {', '.join(TARGET_VOICES)}",
        )

    target_audio_path = TARGET_VOICES[request.target_voice]
    logger.info(
        "Converting voice: %s to %s", request.source_audio_key, request.target_voice
    )

    # Generate a unique filename
    output_filename = f"{uuid.uuid4()}.wav"
    local_path = os.path.join(gettempdir(), output_filename)
    source_temp_path = None

    # NOTE(review): the download, inference and upload below are blocking calls
    # inside an async handler, so requests are effectively processed one at a
    # time. Left as-is because it is unknown whether the model tolerates
    # concurrent use.
    try:
        logger.info("Downloading source audio from Azure Blob Storage")
        try:
            source_temp_path = download_blob_to_temp(request.source_audio_key)
        except Exception as e:
            logger.error("Failed to download source audio: %s", e)
            raise HTTPException(
                status_code=404, detail="Source audio not found"
            ) from e

        vc_wave, sr = process_voice_conversion(
            models=models,
            source=source_temp_path,
            target_name=target_audio_path,
            output=None,
        )
        torchaudio.save(local_path, vc_wave, sr)

        # Upload to Azure Blob Storage
        blob_name = f"seedvc-outputs/{output_filename}"
        output_blob = blob_client.get_blob_client(
            container=AZURE_CONTAINER_NAME, blob=blob_name
        )
        with open(local_path, "rb") as data:
            output_blob.upload_blob(data, overwrite=True)

        # Generate SAS URL for temporary access
        sas_token = generate_blob_sas(
            account_name=blob_client.account_name,
            container_name=AZURE_CONTAINER_NAME,
            blob_name=blob_name,
            account_key=_get_storage_account_key(),
            permission=BlobSasPermissions(read=True),
            # Timezone-aware UTC; datetime.utcnow() is naive and deprecated.
            expiry=datetime.now(timezone.utc) + SAS_URL_TTL,
        )

        return {
            "audio_url": f"{output_blob.url}?{sas_token}",
            "blob_name": blob_name,
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) reach the client
        # instead of being rewritten to a generic 500.
        raise
    except Exception as e:
        logger.exception("Error in voice conversion: %s", e)
        raise HTTPException(
            status_code=500, detail="Error in voice conversion"
        ) from e
    finally:
        # Clean up on success and on every failure path.
        if source_temp_path is not None:
            _remove_quietly(source_temp_path)
        _remove_quietly(local_path)


@app.get("/voices", dependencies=[Depends(verify_api_key)])
async def list_voices():
    """Return the names of the supported target voices."""
    return {"voices": list(TARGET_VOICES)}


@app.get("/health", dependencies=[Depends(verify_api_key)])
async def health_check():
    """Report whether the model has been loaded."""
    if models:
        return {"status": "healthy", "model": "loaded"}
    return {"status": "unhealthy", "model": "not loaded"}