File size: 4,027 Bytes
6ff1f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7f04e1
6ff1f88
 
 
 
 
 
 
c7f04e1
 
6ff1f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import json
import logging
import logging.config
import os

from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
from core.exceptions import APIError, handle_api_error
from core.text_generation import text_generator

# Configure logging
logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)

# Application instance; title/description/version surface in the OpenAPI docs.
app = FastAPI(title="AI Text Generation API",
             description="API for text generation using multiple AI providers",
             version="1.0.0")

# Enable CORS with specific headers for SSE
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Exposed so browser EventSource clients can read the SSE-related headers.
    expose_headers=["Content-Type", "Cache-Control"]
)

# API configuration and setup

class PromptRequest(BaseModel):
    """Request body schema for a generation call: a model id and a prompt.

    NOTE(review): no endpoint in this file currently uses this model —
    /generate parses the request manually to support both GET and POST.
    Kept for documentation/clients; confirm before removing.
    """
    model: str
    prompt: str

@app.get("/")
async def read_root():
    """Health-check root endpoint: report that the API is up."""
    payload = {"status": "ok", "message": "API is running"}
    return payload

@app.get("/models")
async def get_models():
    """Return the list of available models as a JSON array.

    Known APIErrors map to their configured status codes; anything else
    is logged and surfaced as a 500.
    """
    try:
        models = text_generator.get_available_models()
    except APIError as api_err:
        info = handle_api_error(api_err)
        raise HTTPException(
            status_code=info["status_code"],
            detail=info["detail"]
        )
    except Exception as unexpected:
        logger.error(f"Unexpected error in get_models: {str(unexpected)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    return JSONResponse(content=models)

async def generate_stream(model: str, prompt: str):
    """Async generator producing Server-Sent Events for a generation run.

    Each chunk is emitted as a `data: {json}` event; errors are reported
    in-band as `{'error': ...}` events, and a final `[DONE]` sentinel is
    always sent so clients can close the stream.
    """
    def sse_event(payload: dict) -> str:
        # Double newline terminates an SSE event.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        async for piece in text_generator.generate_stream(model, prompt):
            yield sse_event({'content': piece})
    except APIError as api_err:
        info = handle_api_error(api_err)
        yield sse_event({'error': info['detail']})
    except Exception as unexpected:
        logger.error(f"Unexpected error in generate_stream: {str(unexpected)}")
        yield sse_event({'error': 'Internal server error'})
    finally:
        yield "data: [DONE]\n\n"

@app.get("/generate")
@app.post("/generate")
async def generate_response(request: Request):
    """Generate response using selected model (supports both GET and POST).

    GET passes `model` and `prompt` as query parameters; POST passes them
    in a JSON body. Returns a text/event-stream StreamingResponse.

    Raises:
        HTTPException 400: when model or prompt is missing.
        HTTPException: mapped status for known APIErrors.
        HTTPException 500: for unexpected failures.
    """
    try:
        # Handle both GET and POST methods
        if request.method == "GET":
            params = dict(request.query_params)
            model = params.get("model")
            prompt = params.get("prompt")
        else:
            body = await request.json()
            model = body.get("model")
            prompt = body.get("prompt")

        if not model or not prompt:
            raise HTTPException(status_code=400, detail="Missing model or prompt parameter")

        logger.info(f"Received {request.method} request for model: {model}")

        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable buffering for nginx
        }

        return StreamingResponse(
            generate_stream(model, prompt),
            media_type="text/event-stream",
            headers=headers
        )

    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so the generic handler
        # below used to swallow the 400 above and re-raise it as a 500.
        # Re-raise deliberate HTTP errors unchanged.
        raise
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"]
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate_response: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

# Script entry point: run the app with uvicorn using host/port from core.config.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=API_HOST, port=API_PORT)