# Hugging Face Spaces status (page residue captured with this source): Running on CPU Upgrade
import argparse
import io
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional

import requests
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from PIL import Image
from pydantic import BaseModel, field_validator, Field
from slowapi import Limiter
from slowapi.util import get_remote_address

# Import from auth.py
from utils.auth import get_current_user, get_current_user_with_admin, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest, bearer_scheme
# Assuming these are in your project structure
from config.tts_config import SPEED, ResponseFormat, config as tts_config
from config.logging_config import logger
settings = Settings()

# FastAPI app setup with enhanced docs.
# The description is a parenthesised run of adjacent string literals; the comma after the closing
# parenthesis is what separates it from ``version=`` (without it the module does not parse).
app = FastAPI(
    title="Dhwani API",
    description=(
        "A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription. "
        "**Authentication Guide:** \n"
        "1. Obtain an access token by sending a POST request to `/v1/token` with `username` and `password`. \n"
        "2. Click the 'Authorize' button (top-right), enter your access token (e.g., `your_access_token`) in the 'bearerAuth' field, and click 'Authorize'. \n"
        "All protected endpoints require this token for access. \n"
    ),
    version="1.0.0",
    redirect_slashes=False,
    openapi_tags=[
        {"name": "Chat", "description": "Chat-related endpoints"},
        {"name": "Audio", "description": "Audio processing and TTS endpoints"},
        {"name": "Translation", "description": "Text translation endpoints"},
        {"name": "Authentication", "description": "User authentication and registration"},
        {"name": "Utility", "description": "General utility endpoints"},
    ],
)

# Any origin may call the API. Credentials are not allowed cross-origin, which is what makes the
# "*" origin wildcard permissible under the CORS rules.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Rate limiting keyed on the client's remote address.
# slowapi calls ``key_func(request)`` synchronously and needs a string back. ``get_current_user`` is
# async and takes bearer credentials, so it cannot be used here. Calling it without ``await`` only
# produces a coroutine object and never identifies the user.
# NOTE(review): per-user limits need a synchronous way to resolve the user from the request.
limiter = Limiter(key_func=get_remote_address)
# Request/Response Models
class SpeechRequest(BaseModel):
    """Parameters for a text-to-speech request (consumed by ``generate_audio`` via ``Depends()``)."""

    input: str = Field(..., description="Text to convert to speech (max 1000 characters)")
    voice: str = Field(..., description="Voice identifier for the TTS service")
    model: str = Field(..., description="TTS model to use")
    response_format: ResponseFormat = Field(tts_config.response_format, description="Audio format: mp3, flac, or wav")
    speed: float = Field(SPEED, description="Speech speed (default: 1.0)")

    # The validators must be registered with ``field_validator``. As plain methods, pydantic never calls them.
    # NOTE(review): when this model is built through ``Depends()``, a failure here is raised from the
    # dependency call rather than as request validation. Confirm which status code clients actually get.
    @field_validator("input")
    @classmethod
    def input_must_be_valid(cls, v):
        """Reject text longer than 1000 characters, then strip surrounding whitespace."""
        if len(v) > 1000:
            raise ValueError("Input cannot exceed 1000 characters")
        return v.strip()

    @field_validator("response_format")
    @classmethod
    def validate_response_format(cls, v):
        """Accept only the MP3, FLAC and WAV members of ``ResponseFormat``."""
        supported_formats = [ResponseFormat.MP3, ResponseFormat.FLAC, ResponseFormat.WAV]
        if v not in supported_formats:
            raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
        return v

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {
        "json_schema_extra": {
            "example": {
                "input": "Hello, how are you?",
                "voice": "female-1",
                "model": "tts-model-1",
                "response_format": "mp3",
                "speed": 1.0,
            }
        }
    }
class TranscriptionResponse(BaseModel):
    """Response body holding a single ``text`` field (e.g. a transcription)."""

    text: str = Field(..., description="Transcribed text from the audio")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {"json_schema_extra": {"example": {"text": "Hello, how are you?"}}}
class TextGenerationResponse(BaseModel):
    """Response body holding generated text."""

    text: str = Field(..., description="Generated text response")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {"json_schema_extra": {"example": {"text": "Hi there, I'm doing great!"}}}
class AudioProcessingResponse(BaseModel):
    """Response body holding the result string returned by the audio-processing service."""

    result: str = Field(..., description="Processed audio result")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {"json_schema_extra": {"example": {"result": "Processed audio output"}}}
# TTS Service Interface
class TTSService(ABC):
    """Abstract text-to-speech backend.

    Subclasses must implement ``generate_speech``. Marking it abstract prevents instantiating a
    backend that would silently return ``None``.
    """

    @abstractmethod
    async def generate_speech(self, payload: dict) -> requests.Response:
        """Send ``payload`` to the backend and return its raw HTTP response."""
        raise NotImplementedError
class ExternalTTSService(TTSService):
    """TTS backend that forwards requests to the HTTP endpoint at ``settings.external_tts_url``."""

    async def generate_speech(self, payload: dict) -> requests.Response:
        """POST ``payload`` as JSON to the external TTS API and return the streamed response.

        ``requests`` is blocking, so the call runs in FastAPI's threadpool instead of stalling the
        event loop for up to the 60 s timeout. The response is opened with ``stream=True``, and the
        caller must consume or close it. HTTP error statuses are not checked here.

        Raises:
            HTTPException: 504 on timeout, 500 on any other ``requests`` failure.
        """
        try:
            return await run_in_threadpool(
                requests.post,
                settings.external_tts_url,
                json=payload,
                headers={"accept": "application/json", "Content-Type": "application/json"},
                stream=True,
                timeout=60,
            )
        except requests.Timeout as e:
            raise HTTPException(status_code=504, detail="External TTS API timeout") from e
        except requests.RequestException as e:
            raise HTTPException(status_code=500, detail=f"External TTS API error: {str(e)}") from e
def get_tts_service() -> TTSService:
    """Dependency provider: build the TTS backend handed to request handlers."""
    backend: TTSService = ExternalTTSService()
    return backend
# Endpoints with enhanced Swagger docs
async def health_check():
    """Report liveness together with the configured LLM model name."""
    status_report = {"status": "healthy"}
    status_report["model"] = settings.llm_model_name
    return status_report
async def home():
    """Redirect the caller to the interactive API docs at ``/docs``."""
    docs_path = "/docs"
    return RedirectResponse(url=docs_path)
async def token(login_request: LoginRequest):
    """Delegate username/password login to ``login`` and return whatever it produces."""
    login_result = await login(login_request)
    return login_result
async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Delegate token refresh to ``refresh_token`` using the caller's bearer credentials."""
    refreshed = await refresh_token(credentials)
    return refreshed
async def register_user(
    register_request: RegisterRequest,
    current_user: str = Depends(get_current_user_with_admin),  # Enforce admin-only access
):
    """Register a new user. The ``get_current_user_with_admin`` dependency gates who may call this.

    The acting user resolved by that dependency is passed to ``register`` explicitly.
    """
    registration = await register(register_request, current_user)
    return registration
# MIME types for formats whose ``audio/<format>`` spelling is not the registered type (MP3 is audio/mpeg).
_AUDIO_MEDIA_TYPES = {"mp3": "audio/mpeg"}


async def generate_audio(
    request: Request,
    speech_request: SpeechRequest = Depends(),
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Synthesize speech for the authenticated user and stream the audio back.

    Raises:
        HTTPException: 400 if the input text is empty or whitespace. 500 if the TTS backend answers
            with an HTTP error status. Also anything raised by ``tts_service.generate_speech``
            (e.g. 504 on timeout).
    """
    user_id = await get_current_user(credentials)
    if not speech_request.input.strip():
        raise HTTPException(status_code=400, detail="Input cannot be empty")
    logger.info("Processing speech request", extra={
        "endpoint": "/v1/audio/speech",
        "input_length": len(speech_request.input),
        "client_ip": get_remote_address(request),
        "user_id": user_id,
    })
    audio_format = speech_request.response_format.value
    payload = {
        "input": speech_request.input,
        "voice": speech_request.voice,
        "model": speech_request.model,
        "response_format": audio_format,
        "speed": speech_request.speed,
    }
    response = await tts_service.generate_speech(payload)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # The response was opened in streaming mode; release the connection before bailing out.
        response.close()
        raise HTTPException(status_code=500, detail=f"External TTS API error: {str(e)}") from e
    headers = {
        "Content-Disposition": f"inline; filename=\"speech.{audio_format}\"",
        "Cache-Control": "no-cache",
    }
    # Content-Type is set by ``media_type``; repeating it in ``headers`` is redundant.
    return StreamingResponse(
        response.iter_content(chunk_size=8192),
        media_type=_AUDIO_MEDIA_TYPES.get(audio_format, f"audio/{audio_format}"),
        headers=headers,
    )
class ChatRequest(BaseModel):
    """JSON body for a text chat request."""

    prompt: str = Field(..., description="Text prompt for chat (max 1000 characters)")
    src_lang: str = Field("kan_Knda", description="Source language code (default: Kannada)")

    # Must be registered with ``field_validator``. As a plain method, pydantic never calls it.
    @field_validator("prompt")
    @classmethod
    def prompt_must_be_valid(cls, v):
        """Reject prompts longer than 1000 characters, then strip surrounding whitespace."""
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {
        "json_schema_extra": {
            "example": {
                "prompt": "Hello, how are you?",
                "src_lang": "kan_Knda",
            }
        }
    }
class ChatResponse(BaseModel):
    """Response body holding the chat reply."""

    response: str = Field(..., description="Generated chat response")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {"json_schema_extra": {"example": {"response": "Hi there, I'm doing great!"}}}
async def chat(
    request: Request,
    chat_request: ChatRequest,
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Forward a chat prompt to the internal Dhwani chat service and return its reply.

    The reply is requested in the prompt's own language (``tgt_lang`` is set to ``src_lang``).

    Raises:
        HTTPException: 400 for an empty or whitespace-only prompt, 504 on upstream timeout, and 500
            for any other upstream or processing failure.
    """
    user_id = await get_current_user(credentials)
    # Strip here as well, so whitespace-only prompts are rejected even without model-level stripping.
    if not chat_request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
    try:
        external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
        payload = {
            "prompt": chat_request.prompt,
            "src_lang": chat_request.src_lang,
            "tgt_lang": chat_request.src_lang,
        }
        # ``requests`` is blocking; run it in the threadpool so the event loop keeps serving other requests.
        response = await run_in_threadpool(
            requests.post,
            external_url,
            json=payload,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json",
            },
            timeout=60,
        )
        response.raise_for_status()
        response_data = response.json()
        response_text = response_data.get("response", "")
        logger.info(f"Generated Chat response from external API: {response_text}")
        return ChatResponse(response=response_text)
    except requests.Timeout as e:
        logger.error("External chat API request timed out")
        raise HTTPException(status_code=504, detail="Chat service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Error calling external chat API: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def process_audio(
    request: Request,
    file: UploadFile = File(..., description="Audio file to process"),
    language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Upload an audio file to the external audio-processing service and return its ``result`` field.

    Raises:
        HTTPException: 504 on upstream timeout, and 500 on any other upstream failure or a
            non-JSON upstream body.
    """
    user_id = await get_current_user(credentials)
    logger.info("Processing audio processing request", extra={
        "endpoint": "/v1/process_audio",
        "filename": file.filename,
        "client_ip": get_remote_address(request),
        "user_id": user_id,
    })
    start_time = time()
    try:
        file_content = await file.read()
        files = {"file": (file.filename, file_content, file.content_type)}
        external_url = f"{settings.external_audio_proc_url}/process_audio/"
        # ``enum=`` on Query only documents the choices and does not enforce them, so ``language`` is
        # untrusted. Passing it via ``params`` URL-encodes it instead of splicing it into the query string.
        # ``requests`` is blocking, so it runs in the threadpool.
        response = await run_in_threadpool(
            requests.post,
            external_url,
            params={"language": language},
            files=files,
            headers={"accept": "application/json"},
            timeout=60,
        )
        response.raise_for_status()
        processed_result = response.json().get("result", "")
    except requests.Timeout as e:
        raise HTTPException(status_code=504, detail="Audio processing service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Audio processing request failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}") from e
    except ValueError as e:
        # Depending on the requests version, a non-JSON body may surface as a plain ValueError.
        logger.error(f"Invalid JSON from audio processing service: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from audio processing service") from e
    logger.info(f"Audio processing completed in {time() - start_time:.2f} seconds")
    return AudioProcessingResponse(result=processed_result)
async def transcribe_audio(
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Send an audio file to the external ASR service and return its ``text`` field.

    Raises:
        HTTPException: 504 on upstream timeout, and 500 on any other upstream failure or a
            non-JSON upstream body.
    """
    # Called for authentication only (presumably raises on bad credentials); the user is not used further.
    await get_current_user(credentials)
    try:
        file_content = await file.read()
        files = {"file": (file.filename, file_content, file.content_type)}
        external_url = f"{settings.external_asr_url}/transcribe/"
        # ``enum=`` on Query does not validate, so encode ``language`` via ``params`` rather than
        # interpolating it. ``requests`` is blocking, so it runs in the threadpool.
        response = await run_in_threadpool(
            requests.post,
            external_url,
            params={"language": language},
            files=files,
            headers={"accept": "application/json"},
            timeout=60,
        )
        response.raise_for_status()
        transcription = response.json().get("text", "")
    except requests.Timeout as e:
        raise HTTPException(status_code=504, detail="Transcription service timeout") from e
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e
    except ValueError as e:
        # Depending on the requests version, a non-JSON body may surface as a plain ValueError.
        raise HTTPException(status_code=500, detail="Invalid response format from transcription service") from e
    return TranscriptionResponse(text=transcription)
async def chat_v2(
    request: Request,
    prompt: str = Form(..., description="Text prompt for chat"),
    image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Placeholder multimodal chat: acknowledge the prompt and, optionally, an uploaded image.

    No model is called. The reply is ``"Processed: <prompt>"``, suffixed with ``" with image"`` when
    an image was uploaded and PIL could open it.

    Raises:
        HTTPException: 400 for an empty or whitespace-only prompt, and 500 if the image cannot be
            opened or any other processing error occurs.
    """
    user_id = await get_current_user(credentials)
    if not prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info("Processing chat_v2 request", extra={
        "endpoint": "/v1/chat_v2",
        "prompt_length": len(prompt),
        "has_image": bool(image),
        "client_ip": get_remote_address(request),
        "user_id": user_id,
    })
    try:
        image_data = None
        if image:
            image_bytes = await image.read()
            # Image.open expects a path or a file object. Raw bytes would be treated as a filename,
            # so wrap them in a BytesIO.
            image_data = Image.open(io.BytesIO(image_bytes))
        response_text = f"Processed: {prompt}" + (" with image" if image_data is not None else "")
        return TranscriptionResponse(text=response_text)
    except Exception as e:
        logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
class TranslationRequest(BaseModel):
    """JSON body for a batch translation request."""

    sentences: List[str] = Field(..., description="List of sentences to translate")
    src_lang: str = Field(..., description="Source language code")
    tgt_lang: str = Field(..., description="Target language code")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {
        "json_schema_extra": {
            "example": {
                "sentences": ["Hello", "How are you?"],
                "src_lang": "en",
                "tgt_lang": "kan_Knda",
            }
        }
    }
class TranslationResponse(BaseModel):
    """Response body holding one translation per input sentence."""

    translations: List[str] = Field(..., description="Translated sentences")

    # Pydantic v2 takes OpenAPI examples from ``json_schema_extra``. The v1 ``schema_extra`` key is ignored.
    model_config = {"json_schema_extra": {"example": {"translations": ["ನಮಸ್ಕಾರ", "ನೀವು ಹೇಗಿದ್ದೀರಿ?"]}}}
async def translate(
    request: TranslationRequest,
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Translate a batch of sentences through the internal Dhwani translation service.

    An empty ``sentences`` list returns an empty result without calling the service.

    Raises:
        HTTPException: 504 on upstream timeout. 500 on any other upstream failure, a non-JSON body,
            or a response whose ``translations`` list does not match the input length.
    """
    user_id = await get_current_user(credentials)
    logger.info(f"Received translation request: {request.model_dump()}, user_id: {user_id}")
    if not request.sentences:
        # Nothing to translate. Previously an empty reply here was reported as an invalid response.
        return TranslationResponse(translations=[])
    external_url = "https://slabstech-dhwani-internal-api-server.hf.space/translate"
    # Language codes are client-supplied; ``params`` URL-encodes them instead of splicing them into the URL.
    query = {"src_lang": request.src_lang, "tgt_lang": request.tgt_lang}
    payload = {
        "sentences": request.sentences,
        "src_lang": request.src_lang,
        "tgt_lang": request.tgt_lang,
    }
    try:
        # ``requests`` is blocking; run it in the threadpool so the event loop is not stalled.
        response = await run_in_threadpool(
            requests.post,
            external_url,
            params=query,
            json=payload,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json",
            },
            timeout=60,
        )
        response.raise_for_status()
        response_data = response.json()
    except requests.Timeout as e:
        logger.error("Translation request timed out")
        raise HTTPException(status_code=504, detail="Translation service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e
    except ValueError as e:
        logger.error(f"Invalid JSON response: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from translation service") from e
    # Guard against non-object JSON bodies, which would otherwise fail on ``.get`` with an unhandled error.
    translations = response_data.get("translations") if isinstance(response_data, dict) else None
    if not isinstance(translations, list) or len(translations) != len(request.sentences):
        logger.warning(f"Unexpected response format: {response_data}")
        raise HTTPException(status_code=500, detail="Invalid response from translation service")
    logger.info(f"Translation successful: {translations}")
    return TranslationResponse(translations=translations)
if __name__ == "__main__":
    # Command-line entry point: host/port default to the values in ``settings``.
    cli = argparse.ArgumentParser(description="Run the FastAPI server.")
    cli.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    cli.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)