import os
import time
import random
import asyncio
import json
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional
from dotenv import load_dotenv
from starlette.responses import StreamingResponse
from openai import OpenAI
load_dotenv()
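
# Configuration comes from the environment (loaded from .env): API_GEMINI_1..3 hold
# the upstream Gemini API keys, BASE_URL the OpenAI-compatible endpoint, and
# API_HUGGINGFACE the shared secret that clients must present.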
API_KEYS = [
    os.getenv("API_GEMINI_1"),
    os.getenv("API_GEMINI_2"),
    os.getenv("API_GEMINI_3"),
]
# Drop any unset keys so random.choice below never returns None.
API_KEYS = [key for key in API_KEYS if key]
BASE_URL = os.getenv("BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
EXPECTED_API_KEY = os.getenv("API_HUGGINGFACE")
API_KEY_NAME = "Authorization"
class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str = "gemini-2.0-flash"
    messages: List[Message]
    max_tokens: Optional[int] = 8196
    temperature: Optional[float] = 0.8
    stream: Optional[bool] = False

api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
app = FastAPI(
    title="OpenAI-SDK-compatible API",
    version="1.0.0",
    description="A FastAPI wrapper compatible with the OpenAI API specification.",
)
# Wide-open CORS; access control is handled by the API-key check below.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def verify_api_key(api_key: str = Depends(api_key_header)):
    # Clients must send: Authorization: Bearer <API_HUGGINGFACE>
    if not api_key:
        raise HTTPException(status_code=403, detail="Missing API key")
    if api_key != f"Bearer {EXPECTED_API_KEY}":
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key

def get_openai_client():
    # Pick one of the configured Gemini keys at random to spread load across quotas.
    api_key = random.choice(API_KEYS)
    return OpenAI(api_key=api_key, base_url=BASE_URL)

def call_api_sync(params: ChatCompletionRequest, retries: int = 5):
    """Non-streaming completion, with a bounded retry on upstream rate limits."""
    try:
        client = get_openai_client()
        print(params)
        response = client.chat.completions.create(
            model=params.model,
            messages=[m.model_dump() for m in params.messages],
            max_tokens=params.max_tokens,
            temperature=params.temperature,
            stream=params.stream,
        )
        return response
    except Exception as e:
        # Back off briefly and retry a bounded number of times on rate-limit errors.
        if "429" in str(e) and retries > 0:
            time.sleep(2)
            return call_api_sync(params, retries - 1)
        raise

async def _resp_async_generator(params: ChatCompletionRequest):
    client = get_openai_client()
    try:
        response = client.chat.completions.create(
            model=params.model,
            messages=[m.model_dump() for m in params.messages],
            max_tokens=params.max_tokens,
            temperature=params.temperature,
            stream=True,
        )
        # Relay each upstream chunk as a Server-Sent Events "data:" frame.
        for chunk in response:
            chunk_data = chunk.to_dict() if hasattr(chunk, "to_dict") else chunk
            yield f"data: {json.dumps(chunk_data)}\n\n"
            await asyncio.sleep(0.01)
        yield "data: [DONE]\n\n"
    except Exception as e:
        # Surface upstream failures to the client instead of silently closing the stream.
        error_data = {"error": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"

# ---------------------------------------------------------------------------------------
@app.get("/")
def read_general():
    return {"response": "Welcome"}

@app.get("/health")
async def health_check():
    return {"message": "success"}

# ---------------------------------- Text Generation -------------------------------------
@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
async def chat_completions(req: ChatCompletionRequest):
    if not req.messages:
        raise HTTPException(status_code=400, detail="No messages provided")
    if req.stream:
        # The generator emits SSE frames, so advertise text/event-stream.
        return StreamingResponse(
            _resp_async_generator(req),
            media_type="text/event-stream",
        )
    try:
        response = call_api_sync(req)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
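
# ---------------------------------------------------------------------------------------
# Usage sketch (assumptions: the app is served locally with `uvicorn main:app`, and
# API_HUGGINGFACE is the shared secret checked by verify_api_key above; the SDK puts
# the key in the Authorization: Bearer header this wrapper expects):
#
#   from openai import OpenAI
#   client = OpenAI(api_key="<API_HUGGINGFACE value>", base_url="http://localhost:8000/v1")
#   resp = client.chat.completions.create(
#       model="gemini-2.0-flash",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)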