File size: 3,959 Bytes
0099d95
f0feabf
498d80c
1ad1813
739823d
c53513a
0bddb91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd7197
 
 
 
 
 
 
 
0bddb91
 
 
a0fbf6c
0bddb91
 
 
 
 
 
b5d7c98
0bddb91
 
 
 
 
 
 
 
 
 
 
 
11c5c73
0bddb91
 
 
 
 
 
 
 
 
 
 
 
 
 
11c5c73
0bddb91
11c5c73
0bddb91
 
 
 
 
 
 
 
 
5322bcf
0bddb91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5322bcf
0bddb91
 
 
 
 
 
6159237
c550535
7a4300a
0bddb91
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
import hmac
import json
import os
import random
import time
from typing import List, Optional, Type

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from openai import OpenAI
from pydantic import BaseModel, Field, field_validator
from starlette.responses import StreamingResponse

load_dotenv()

# Pool of upstream Gemini keys; requests are load-balanced across them.
API_KEYS = [
    key
    for key in (
        os.getenv("API_GEMINI_1"),
        os.getenv("API_GEMINI_2"),
        os.getenv("API_GEMINI_3"),
    )
    # Skip unset env vars so random.choice() never hands out a None key.
    if key
]

BASE_URL = os.getenv("BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
# Key clients must present as "Bearer <key>"; the default is for local dev only.
EXPECTED_API_KEY = os.getenv("EXPECTED_API_KEY", "1234")
API_KEY_NAME = "Authorization"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
app = FastAPI(title="OpenAI-SDK-compatible API", version="1.0.0", description="Un wrapper FastAPI compatibile con le specifiche dell'API OpenAI.")
# NOTE(review): browsers reject wildcard origins combined with
# allow_credentials=True per the CORS spec — pin origins if cookies/credentials
# are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class Message(BaseModel):
    """A single chat message in the OpenAI wire format: a role plus its text content."""
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    """Request body for /v1/chat/completions (subset of the OpenAI schema)."""
    model: str = "gemini-2.0-flash" 
    messages: List[Message]
    # NOTE(review): 8196 looks like a typo for 8192 (2**13) — confirm intent.
    max_tokens: Optional[int] = 8196
    temperature: Optional[float] = 0.8
    # When true the endpoint streams SSE chunks instead of one JSON response.
    stream: Optional[bool] = False
        
def verify_api_key(api_key: str = Depends(api_key_header)):
    """Validate the Authorization header for protected endpoints.

    Raises HTTPException(403) when the header is absent or does not equal
    the exact string "Bearer <EXPECTED_API_KEY>"; returns the header value
    unchanged on success.
    """
    if not api_key:
        raise HTTPException(status_code=403, detail="API key mancante")
    # Timing-safe comparison so response latency cannot leak key bytes.
    if not hmac.compare_digest(api_key, f"Bearer {EXPECTED_API_KEY}"):
        raise HTTPException(status_code=403, detail="API key non valida")
    return api_key

def get_openai_client():
    """Return an OpenAI client bound to a randomly chosen upstream key.

    Falsy entries (env vars that were never set) are ignored; raises
    RuntimeError when no key is configured at all, instead of silently
    building a client with api_key=None.
    """
    usable_keys = [k for k in API_KEYS if k]
    if not usable_keys:
        raise RuntimeError("Nessuna API key configurata (API_GEMINI_1..3)")
    return OpenAI(api_key=random.choice(usable_keys), base_url=BASE_URL)

def call_api_sync(params: ChatCompletionRequest, max_retries: int = 5):
    """Forward a chat-completion request to the upstream API synchronously.

    Retries with a growing backoff when the upstream reports a rate limit
    (an error whose text contains "429"), up to *max_retries* extra
    attempts; any other exception — or a 429 that persists past the last
    attempt — is re-raised unchanged.

    The original implementation recursed unboundedly on 429 (risking
    RecursionError under sustained throttling) and printed the full request
    (user messages included) to stdout; both are removed here.
    """
    for attempt in range(max_retries + 1):
        client = get_openai_client()
        try:
            return client.chat.completions.create(
                model=params.model,
                messages=[m.model_dump() for m in params.messages],
                max_tokens=params.max_tokens,
                temperature=params.temperature,
                stream=params.stream,
            )
        except Exception as exc:
            if "429" in str(exc) and attempt < max_retries:
                # Linear backoff: 2s, 4s, 6s, ... before the next attempt.
                time.sleep(2 * (attempt + 1))
                continue
            raise

async def _resp_async_generator(params: ChatCompletionRequest):
    """Yield the upstream completion as Server-Sent-Events-style frames.

    Each upstream chunk is emitted as one "data: <json>" frame followed by a
    blank line, then a final "data: [DONE]" frame. Errors are not raised:
    a single frame carrying {"error": <message>} is emitted instead, so the
    HTTP stream always terminates cleanly for the client.
    """
    client = get_openai_client()
    try:
        # NOTE(review): this SDK call and the chunk iteration below are
        # synchronous and will block the event loop while waiting on the
        # upstream — consider asyncio.to_thread or an async client.
        response = client.chat.completions.create(
            model=params.model,
            messages=[m.model_dump() for m in params.messages],
            max_tokens=params.max_tokens,
            temperature=params.temperature,
            stream=True  # always stream here, regardless of params.stream
        )
        for chunk in response:
            # SDK objects serialize via to_dict(); plain dicts pass through.
            chunk_data = chunk.to_dict() if hasattr(chunk, "to_dict") else chunk
            yield f"data: {json.dumps(chunk_data)}\n\n"
            # Brief yield to the event loop between frames.
            await asyncio.sleep(0.01)
        yield "data: [DONE]\n\n"
    except Exception as e:
        error_data = {"error": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"

@app.get("/health")
async def health_check():
    """Liveness probe: always reports success."""
    status_payload = {"message": "success"}
    return status_payload

@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
async def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint (auth via verify_api_key).

    Returns the full upstream response as JSON, or — when req.stream is
    true — a Server-Sent Events stream produced by _resp_async_generator.
    Raises 400 on an empty message list and 500 on upstream failure.
    """
    if not req.messages:
        raise HTTPException(status_code=400, detail="Nessun messaggio fornito")
    if req.stream:
        # The generator emits SSE "data: ..." frames, so advertise the SSE
        # media type. The previous application/x-ndjson did not match the
        # framing and confused SSE-aware clients.
        return StreamingResponse(
            _resp_async_generator(req),
            media_type="text/event-stream"
        )
    try:
        return call_api_sync(req)
    except Exception as e:
        # Surface upstream failures with their original message.
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
def read_general():
    """Root endpoint: return a static welcome payload."""
    greeting = {"response": "Benvenuto"}
    return greeting

if __name__ == "__main__":
    # Dev entry point: serve with auto-reload on all interfaces, port 8000.
    # Assumes this module is importable as "main"; use a process manager
    # (gunicorn/uvicorn workers) in production instead of reload=True.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)