Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import argparse
import asyncio
import io
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional

import requests
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field, field_validator
from slowapi import Limiter
from slowapi.util import get_remote_address

from utils.auth import get_current_user, login, refresh_token, TokenResponse, Settings, LoginRequest
# Assuming these are in your project structure
from config.tts_config import SPEED, ResponseFormat, config as tts_config
from config.logging_config import logger
# Application settings from utils.auth.Settings; handlers below read the
# external service URLs, the LLM model name, and the CLI host/port from it.
settings = Settings()
# FastAPI app setup
app = FastAPI(
    title="Dhwani API",
    description="AI Chat API supporting Indian languages",
    version="1.0.0",
    # Paths are matched exactly: no automatic redirect between "/x" and "/x/".
    redirect_slashes=False,
)
# Permissive CORS: any origin, method and header is allowed; credentials
# (cookies / auth headers sent by the browser) are not.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# slowapi rate limiter keyed on the client's remote address; stored on
# app.state, which is where slowapi expects to find it.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# Request/Response Models | |
class SpeechRequest(BaseModel):
    """Parameters of a text-to-speech request.

    ``input`` may be at most 1000 characters and is stored with surrounding
    whitespace removed. ``response_format`` must be MP3, FLAC or WAV.
    """

    input: str
    voice: str
    model: str
    response_format: ResponseFormat = tts_config.response_format
    speed: float = SPEED

    @field_validator("input")
    @classmethod
    def input_must_be_valid(cls, v):
        """Reject inputs over 1000 characters, then strip whitespace.

        The length check runs on the raw value, before stripping.
        """
        if len(v) > 1000:
            raise ValueError("Input cannot exceed 1000 characters")
        return v.strip()

    @field_validator("response_format")
    @classmethod
    def validate_response_format(cls, v):
        """Accept only the audio formats the speech endpoint can stream."""
        supported_formats = [ResponseFormat.MP3, ResponseFormat.FLAC, ResponseFormat.WAV]
        if v not in supported_formats:
            raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
        return v
class TranscriptionResponse(BaseModel):
    """Response carrying a single piece of text."""

    text: str = Field(...)
class TextGenerationResponse(BaseModel):
    """Response carrying generated text."""

    text: str = Field(...)
class AudioProcessingResponse(BaseModel):
    """Response carrying the result string of an audio-processing call."""

    result: str = Field(...)
# TTS Service Interface | |
class TTSService(ABC):
    """Interface for text-to-speech backends.

    Subclasses must override :meth:`generate_speech`; the base class itself
    cannot be instantiated.
    """

    @abstractmethod
    async def generate_speech(self, payload: dict) -> requests.Response:
        """Send ``payload`` to a speech backend and return its HTTP response.

        The speech endpoint calls ``raise_for_status()`` and
        ``iter_content()`` on the result, so implementations should return a
        response whose body has not yet been consumed.
        """
class ExternalTTSService(TTSService):
    """TTS backend that forwards requests to ``settings.external_tts_url``."""

    async def generate_speech(self, payload: dict) -> requests.Response:
        """POST ``payload`` as JSON to the external TTS API.

        ``requests`` is blocking, so the call runs in a worker thread to keep
        the event loop responsive while waiting on the upstream service
        (``timeout=60``). The response is opened with ``stream=True``; the
        caller is responsible for reading and closing it.

        Raises:
            HTTPException: 504 if the upstream call times out, 500 for any
                other ``requests`` error.
        """
        try:
            return await asyncio.to_thread(
                requests.post,
                settings.external_tts_url,
                json=payload,
                headers={"accept": "application/json", "Content-Type": "application/json"},
                stream=True,
                timeout=60,
            )
        except requests.Timeout as e:
            raise HTTPException(status_code=504, detail="External TTS API timeout") from e
        except requests.RequestException as e:
            raise HTTPException(status_code=500, detail=f"External TTS API error: {str(e)}") from e
def get_tts_service() -> TTSService:
    """Dependency that supplies the TTS backend for the speech endpoint.

    A new ``ExternalTTSService`` is constructed on every call.
    """
    service: TTSService = ExternalTTSService()
    return service
async def token(login_request: LoginRequest):
    """Authenticate the given credentials and return what ``login`` issues."""
    issued = await login(login_request)
    return issued
async def refresh(token_response: TokenResponse = Depends(refresh_token)):
    """Return the TokenResponse produced by the ``refresh_token`` dependency.

    All of the refresh logic lives in the dependency; this handler only
    passes its result through.
    """
    refreshed = token_response
    return refreshed
async def health_check():
    """Report liveness along with the configured LLM model name."""
    model_name = settings.llm_model_name
    return dict(status="healthy", model=model_name)
async def home():
    """Redirect to the interactive API documentation at /docs."""
    docs_path = "/docs"
    return RedirectResponse(url=docs_path)
async def generate_audio(
    request: Request,
    speech_request: SpeechRequest = Depends(),
    user_id: str = Depends(get_current_user),
    tts_service: TTSService = Depends(get_tts_service)
):
    """Synthesize speech for ``speech_request`` and stream the audio back.

    The audio bytes are relayed from the TTS backend in 8 KiB chunks with an
    ``audio/<format>`` media type.

    Raises:
        HTTPException: 400 for empty or whitespace-only input; 500 when the
            TTS backend answers with an HTTP error status; plus anything the
            TTS service itself raises.
    """
    if not speech_request.input.strip():
        raise HTTPException(status_code=400, detail="Input cannot be empty")
    logger.info("Processing speech request", extra={
        "endpoint": "/v1/audio/speech",
        "input_length": len(speech_request.input),
        "client_ip": get_remote_address(request),
        "user_id": user_id
    })
    payload = {
        "input": speech_request.input,
        "voice": speech_request.voice,
        "model": speech_request.model,
        "response_format": speech_request.response_format.value,
        "speed": speech_request.speed
    }
    response = await tts_service.generate_speech(payload)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # The response may hold an open streaming connection (the external
        # backend requests stream=True); release it before failing.
        response.close()
        logger.error(f"External TTS API returned an error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"External TTS API error: {str(e)}") from e

    audio_format = speech_request.response_format.value
    headers = {
        "Content-Disposition": f"inline; filename=\"speech.{audio_format}\"",
        "Cache-Control": "no-cache",
        "Content-Type": f"audio/{audio_format}"
    }

    def stream_audio():
        # Relay upstream chunks, closing the upstream connection when the
        # stream finishes or the generator is closed early.
        try:
            yield from response.iter_content(chunk_size=8192)
        finally:
            response.close()

    return StreamingResponse(
        stream_audio(),
        media_type=f"audio/{audio_format}",
        headers=headers
    )
class ChatRequest(BaseModel):
    """Body of a text chat request.

    ``prompt`` may be at most 1000 characters and is stored with surrounding
    whitespace removed. ``src_lang`` defaults to ``"kan_Knda"``.
    """

    prompt: str
    src_lang: str = "kan_Knda"

    @field_validator("prompt")
    @classmethod
    def prompt_must_be_valid(cls, v):
        """Reject prompts over 1000 characters, then strip whitespace.

        The length check runs on the raw value, before stripping; a
        whitespace-only prompt therefore becomes "".
        """
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()
class ChatResponse(BaseModel):
    """Reply text returned by the chat endpoint."""

    response: str = Field(...)
async def chat(
    request: Request,
    chat_request: ChatRequest,
    user_id: str = Depends(get_current_user)
):
    """Answer ``chat_request.prompt`` via the internal Dhwani chat API.

    The reply is requested in the prompt's own language (``tgt_lang`` is set
    to ``src_lang``).

    Raises:
        HTTPException: 400 for an empty prompt, 504 on upstream timeout, 500
            for any other upstream or response-handling failure.
    """
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    # Log request metadata rather than the raw prompt text, like the other
    # handlers in this module.
    logger.info("Processing chat request", extra={
        "prompt_length": len(chat_request.prompt),
        "src_lang": chat_request.src_lang,
        "client_ip": get_remote_address(request),
        "user_id": user_id
    })
    try:
        external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
        payload = {
            "prompt": chat_request.prompt,
            "src_lang": chat_request.src_lang,
            "tgt_lang": chat_request.src_lang
        }
        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests while waiting on the upstream API.
        response = await asyncio.to_thread(
            requests.post,
            external_url,
            json=payload,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json"
            },
            timeout=60
        )
        response.raise_for_status()
        response_data = response.json()
        response_text = response_data.get("response", "")
        logger.info("Generated chat response from external API", extra={
            "response_length": len(response_text),
            "user_id": user_id
        })
        return ChatResponse(response=response_text)
    except requests.Timeout as e:
        logger.error("External chat API request timed out")
        raise HTTPException(status_code=504, detail="Chat service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Error calling external chat API: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def process_audio(
    file: UploadFile = File(...),
    language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
    user_id: str = Depends(get_current_user),
    request: Request = None,
):
    """Forward an uploaded audio file to the external audio-processing API.

    Returns the upstream ``result`` field (``""`` if absent).

    Raises:
        HTTPException: 504 on upstream timeout; 500 on any other upstream
            error or malformed upstream data.
    """
    logger.info("Processing audio processing request", extra={
        "endpoint": "/v1/process_audio",
        "filename": file.filename,
        "client_ip": get_remote_address(request),
        "user_id": user_id
    })
    start_time = time()
    try:
        file_content = await file.read()
        files = {"file": (file.filename, file_content, file.content_type)}
        external_url = f"{settings.external_audio_proc_url}/process_audio/"
        # Pass `language` via params= so requests URL-encodes it instead of
        # splicing raw client input into the URL (the Query `enum` list is
        # schema metadata; NOTE(review): confirm it is not enforced here).
        # The blocking call runs in a worker thread to free the event loop.
        response = await asyncio.to_thread(
            requests.post,
            external_url,
            params={"language": language},
            files=files,
            headers={"accept": "application/json"},
            timeout=60
        )
        response.raise_for_status()
        processed_result = response.json().get("result", "")
        logger.info(f"Audio processing completed in {time() - start_time:.2f} seconds")
        return AudioProcessingResponse(result=processed_result)
    except requests.Timeout as e:
        logger.error("Audio processing request timed out")
        raise HTTPException(status_code=504, detail="Audio processing service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Audio processing request failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}") from e
    except ValueError as e:
        # Malformed upstream data: a non-JSON body or a result that fails
        # response-model validation (pydantic's ValidationError is a ValueError).
        logger.error(f"Invalid response from audio processing service: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from audio processing service") from e
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
    user_id: str = Depends(get_current_user),
    request: Request = None,
):
    """Transcribe an uploaded audio file via the external ASR API.

    Returns the upstream ``text`` field (``""`` if absent).

    Raises:
        HTTPException: 504 on upstream timeout; 500 on any other upstream
            error or malformed upstream data.
    """
    logger.info("Processing transcription request", extra={
        "filename": file.filename,
        "language": language,
        "client_ip": get_remote_address(request),
        "user_id": user_id
    })
    start_time = time()
    try:
        file_content = await file.read()
        files = {"file": (file.filename, file_content, file.content_type)}
        external_url = f"{settings.external_asr_url}/transcribe/"
        # params= URL-encodes the client-supplied language; the blocking
        # call runs in a worker thread to keep the event loop free.
        response = await asyncio.to_thread(
            requests.post,
            external_url,
            params={"language": language},
            files=files,
            headers={"accept": "application/json"},
            timeout=60
        )
        response.raise_for_status()
        transcription = response.json().get("text", "")
        logger.info(f"Transcription completed in {time() - start_time:.2f} seconds")
        return TranscriptionResponse(text=transcription)
    except requests.Timeout as e:
        logger.error("Transcription request timed out")
        raise HTTPException(status_code=504, detail="Transcription service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Transcription request failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e
    except ValueError as e:
        # Malformed upstream data: a non-JSON body or a text value that fails
        # response-model validation (pydantic's ValidationError is a ValueError).
        logger.error(f"Invalid response from transcription service: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from transcription service") from e
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    user_id: str = Depends(get_current_user)
):
    """Handle a multipart chat request with an optional image attachment.

    Currently a placeholder: it checks that the image (if any) can be opened
    by Pillow and echoes the prompt, noting whether an image was supplied.
    The result is returned as a ``TranscriptionResponse``.

    Raises:
        HTTPException: 400 for an empty prompt; 500 if the image cannot be
            opened or anything else fails.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info("Processing chat_v2 request", extra={
        "endpoint": "/v1/chat_v2",
        "prompt_length": len(prompt),
        "has_image": bool(image),
        "client_ip": get_remote_address(request),
        "user_id": user_id
    })
    try:
        suffix = ""
        if image:
            image_bytes = await image.read()
            # Image.open() takes a path or a binary file object; raw bytes
            # would be interpreted as a filename, so wrap them in BytesIO.
            # The with-block closes the image once it has been opened.
            with Image.open(io.BytesIO(image_bytes)):
                suffix = " with image"
        response_text = f"Processed: {prompt}{suffix}"
        return TranscriptionResponse(text=response_text)
    except Exception as e:
        logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
class TranslationRequest(BaseModel):
    """Batch of sentences to translate from ``src_lang`` into ``tgt_lang``."""

    sentences: list[str] = Field(...)
    src_lang: str = Field(...)
    tgt_lang: str = Field(...)
class TranslationResponse(BaseModel):
    """Translated sentences, in the same order as the request's sentences."""

    translations: list[str] = Field(...)
async def translate(
    request: TranslationRequest,
    user_id: str = Depends(get_current_user)
):
    """Translate ``request.sentences`` via the internal translation API.

    ``src_lang``/``tgt_lang`` are sent both as query parameters and in the
    JSON body. The upstream reply must contain one translation per input
    sentence. An empty ``sentences`` list returns an empty result without
    calling upstream.

    Raises:
        HTTPException: 504 on upstream timeout; 500 on any other upstream
            error or a malformed / mismatched upstream reply.
    """
    # Log counts and language codes, not the sentence text itself.
    logger.info(
        f"Received translation request: {len(request.sentences)} sentence(s), "
        f"src_lang: {request.src_lang}, tgt_lang: {request.tgt_lang}, user_id: {user_id}"
    )
    if not request.sentences:
        # Nothing to translate; the length check below would otherwise
        # reject the correct (empty) answer as an invalid response.
        return TranslationResponse(translations=[])
    external_url = "https://slabstech-dhwani-internal-api-server.hf.space/translate"
    # Client-supplied language codes go through params= so requests
    # URL-encodes them; raw interpolation allowed query-string injection.
    query_params = {"src_lang": request.src_lang, "tgt_lang": request.tgt_lang}
    payload = {
        "sentences": request.sentences,
        "src_lang": request.src_lang,
        "tgt_lang": request.tgt_lang
    }
    try:
        # Blocking requests call runs in a worker thread to free the event loop.
        response = await asyncio.to_thread(
            requests.post,
            external_url,
            params=query_params,
            json=payload,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json"
            },
            timeout=60
        )
        response.raise_for_status()
        response_data = response.json()
        translations = response_data.get("translations", []) if isinstance(response_data, dict) else None
        if not isinstance(translations, list) or len(translations) != len(request.sentences):
            logger.warning(f"Unexpected response format: {response_data}")
            raise HTTPException(status_code=500, detail="Invalid response from translation service")
        logger.info(f"Translation successful: {len(translations)} sentence(s)")
        return TranslationResponse(translations=translations)
    except requests.Timeout as e:
        logger.error("Translation request timed out")
        raise HTTPException(status_code=504, detail="Translation service timeout") from e
    except requests.RequestException as e:
        logger.error(f"Error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e
    except ValueError as e:
        logger.error(f"Invalid JSON response: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from translation service") from e
if __name__ == "__main__":
    # CLI entry point: serve `app` with uvicorn, defaulting the bind address
    # and port to the values from settings.
    cli = argparse.ArgumentParser(description="Run the FastAPI server.")
    cli.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    cli.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)