File size: 4,294 Bytes
6ff1f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
import json
import logging
import logging.config
import os

from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
from core.exceptions import APIError, handle_api_error
from core.text_generation import text_generator

# Configure logging. dictConfig() runs before getLogger() on purpose: by
# default dictConfig disables loggers that already exist
# (disable_existing_loggers=True), so creating this module's logger first
# could silence it depending on what LOG_CONFIG sets.
logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)

# Application instance; title/description/version are the metadata shown in
# the generated OpenAPI docs.
app = FastAPI(title="AI Text Generation API",
             description="API for text generation using multiple AI providers",
             version="1.0.0")

# Enable CORS. expose_headers lets browser scripts on other origins read the
# Content-Type and Cache-Control headers of the SSE responses.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin, so any site can make
# credentialed requests. CORS_SETTINGS is imported from core.config but not
# used anywhere in this file — presumably it was meant to supply these
# values; confirm, and restrict origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "Cache-Control"]
)

# Serve frontend assets under /static. frontend_dir resolves to
# dirname(dirname(abspath(__file__)))/frontend, i.e. a 'frontend' directory
# beside the directory containing this file. The root endpoint also serves
# index.html from here.
frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'frontend')
app.mount("/static", StaticFiles(directory=frontend_dir), name="static")

class PromptRequest(BaseModel):
    """Request schema pairing a model identifier with a prompt.

    NOTE(review): no endpoint in this file uses this model — the /generate
    handler reads ``model`` and ``prompt`` from the raw request instead.
    Confirm whether it is referenced elsewhere before removing it.
    """

    model: str  # model identifier; presumably one of the names /models returns — verify
    prompt: str  # prompt text to send to the model

@app.get("/")
async def read_root():
    """Answer the site root with the frontend's ``index.html`` file."""
    index_html = os.path.join(frontend_dir, 'index.html')
    return FileResponse(index_html)

@app.get("/models")
async def get_models():
    """Return every available model as JSON.

    The payload is whatever ``text_generator.get_available_models()``
    returns (presumably a list, serialized as a JSON array), wrapped in a
    ``JSONResponse``.

    Raises:
        HTTPException: with the status code and detail produced by
            ``handle_api_error`` when an ``APIError`` is raised; 500
            ("Internal server error") for any other exception.
    """
    try:
        return JSONResponse(content=text_generator.get_available_models())
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"],
        ) from e
    except Exception as e:
        # logger.exception keeps the traceback in the server log; the client
        # only receives a generic message.
        logger.exception("Unexpected error in get_models: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

async def generate_stream(model: str, prompt: str):
    """Yield Server-Sent Events for one text-generation request.

    Each chunk from ``text_generator.generate_stream(model, prompt)`` is
    emitted as ``data: {"content": <chunk>}``. A failure is reported as a
    single ``data: {"error": <message>}`` event instead of being raised.
    StreamingResponse has already sent the HTTP status before the body is
    iterated, so the status can no longer carry the error. After a normal or
    error-handled finish the stream ends with ``data: [DONE]``. It does not
    when the generator is closed or cancelled (e.g. the client disconnected).

    Args:
        model: Identifier of the model to use.
        prompt: Prompt text passed to the model.

    Yields:
        str: SSE-formatted events, each terminated by a blank line.
    """
    try:
        async for chunk in text_generator.generate_stream(model, prompt):
            # The trailing blank line ("\n\n") terminates each SSE event.
            yield f"data: {json.dumps({'content': chunk})}\n\n"
    except APIError as e:
        error_response = handle_api_error(e)
        yield f"data: {json.dumps({'error': error_response['detail']})}\n\n"
    except Exception as e:
        logger.exception("Unexpected error in generate_stream: %s", e)
        yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
    # Deliberately NOT in a `finally`. Yielding while the generator is being
    # closed (GeneratorExit from aclose()) raises "async generator ignored
    # GeneratorExit". Yielding during asyncio.CancelledError would swallow
    # the cancellation. Both handlers above absorb their exception, so this
    # still runs after every normal or error-handled completion.
    yield "data: [DONE]\n\n"

@app.get("/generate")
@app.post("/generate")
async def generate_response(request: Request):
    """Stream a model's answer to a prompt as Server-Sent Events.

    Both methods are served on the same path. GET reads ``model`` and
    ``prompt`` from the query string; POST reads them from a JSON object
    body. The text is produced lazily by ``generate_stream``, so generation
    errors show up inside the event stream, not as an HTTP status.

    Args:
        request: The incoming request.

    Returns:
        StreamingResponse: a ``text/event-stream`` response. Its headers
        disable caching and nginx proxy buffering.

    Raises:
        HTTPException: 400 if the POST body is not valid JSON or not a JSON
            object, or if ``model``/``prompt`` is missing or empty; the
            status and detail from ``handle_api_error`` for an
            ``APIError``; 500 for any other unexpected error.
    """
    try:
        # Handle both GET and POST methods
        if request.method == "GET":
            params = dict(request.query_params)
            model = params.get("model")
            prompt = params.get("prompt")
        else:
            try:
                body = await request.json()
            except ValueError as e:
                # json.JSONDecodeError and UnicodeDecodeError both subclass
                # ValueError: a malformed body is a client error, not a 500.
                raise HTTPException(status_code=400, detail="Request body must be valid JSON") from e
            if not isinstance(body, dict):
                # A JSON list/string/number has no .get(); reject it as 400
                # instead of letting AttributeError become a 500.
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = body.get("model")
            prompt = body.get("prompt")

        if not model or not prompt:
            raise HTTPException(status_code=400, detail="Missing model or prompt parameter")

        logger.info("Received %s request for model: %s", request.method, model)

        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable buffering for nginx
        }

        return StreamingResponse(
            generate_stream(model, prompt),
            media_type="text/event-stream",
            headers=headers,
        )

    except HTTPException:
        # Let the deliberate 400s above reach FastAPI unchanged; otherwise
        # the generic handler below would turn them into 500s.
        raise
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"],
        ) from e
    except Exception as e:
        logger.exception("Unexpected error in generate_response: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

if __name__ == "__main__":
    # Running the module directly starts a uvicorn server. The import is
    # local so uvicorn is only needed in that case.
    from uvicorn import run as serve_app

    serve_app(app, host=API_HOST, port=API_PORT)