from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import json
import logging
import logging.config
import os

from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
from core.exceptions import APIError, handle_api_error
from core.text_generation import text_generator
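# Assumed contract, inferred from usage below: handle_api_error(exc) returns a
# dict with "status_code" and "detail" keys, and text_generator exposes
# get_available_models() plus an async generate_stream(model, prompt) generator.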
# Configure logging
logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="AI Text Generation API",
    description="API for text generation using multiple AI providers",
    version="1.0.0"
)
# Enable CORS with specific headers for SSE
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "Cache-Control"]
)
# Request schema documenting the expected POST body for generation requests
class PromptRequest(BaseModel):
    model: str
    prompt: str
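# Example POST body matching the schema above (the model name is hypothetical;
# valid values come from text_generator.get_available_models()):
#   {"model": "some-provider/some-model", "prompt": "Write a haiku about the sea."}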
@app.get("/")
async def read_root():
    """API root endpoint."""
    return {"status": "ok", "message": "API is running"}
@app.get("/models")  # Route path assumed; adjust to match the frontend's expectations
async def get_models():
    """Get list of all available models."""
    try:
        # Return models as a JSON array
        return JSONResponse(content=text_generator.get_available_models())
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"]
        )
    except Exception as e:
        logger.error(f"Unexpected error in get_models: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
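# Example response shape (the identifiers are hypothetical; the actual list
# and element format come from text_generator.get_available_models()):
#   ["provider-a/model-1", "provider-b/model-2"]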
async def generate_stream(model: str, prompt: str):
    """Stream generator for text generation."""
    try:
        async for chunk in text_generator.generate_stream(model, prompt):
            # Add extra newline to ensure proper event separation
            yield f"data: {json.dumps({'content': chunk})}\n\n"
    except APIError as e:
        error_response = handle_api_error(e)
        yield f"data: {json.dumps({'error': error_response['detail']})}\n\n"
    except Exception as e:
        logger.error(f"Unexpected error in generate_stream: {str(e)}")
        yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
    finally:
        yield "data: [DONE]\n\n"
@app.api_route("/generate", methods=["GET", "POST"])  # Route path assumed
async def generate_response(request: Request):
    """Generate a streamed response using the selected model (supports both GET and POST)."""
    try:
        # Handle both GET and POST methods
        if request.method == "GET":
            params = dict(request.query_params)
            model = params.get("model")
            prompt = params.get("prompt")
        else:
            body = await request.json()
            model = body.get("model")
            prompt = body.get("prompt")
        if not model or not prompt:
            raise HTTPException(status_code=400, detail="Missing model or prompt parameter")
        logger.info(f"Received {request.method} request for model: {model}")
        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable buffering for nginx
        }
        return StreamingResponse(
            generate_stream(model, prompt),
            media_type="text/event-stream",
            headers=headers
        )
    except HTTPException:
        # Re-raise as-is so the 400 above is not swallowed by the generic handler below
        raise
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"]
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate_response: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=API_HOST, port=API_PORT)
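# Example usage, assuming API_PORT=8000 and the assumed /generate route above
# (-N disables curl's buffering so SSE events print as they arrive):
#   curl -N "http://localhost:8000/generate?model=some-model&prompt=hello"
#   curl -N -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"model": "some-model", "prompt": "hello"}'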