AZAA / main.py
rkihacker's picture
Update main.py
36f72ba verified
raw
history blame
4.01 kB
import json
import logging
import os
from typing import List, Optional, Union

import requests
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# TypeGPT API settings
API_URL = "https://api.typegpt.net/v1/chat/completions"
# SECURITY: the literal below was committed in plain text (this file appears to
# be publicly hosted), so treat the key as leaked: revoke it and provide the
# replacement via the TYPEGPT_API_KEY environment variable. The literal is kept
# only as a fallback so existing deployments keep working until that is done.
API_KEY = os.environ.get(
    "TYPEGPT_API_KEY", "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
)
# Model actually requested from the backend, whatever name the client sent.
BACKEND_MODEL = "pixtral-large-latest"

# Load model -> system prompt mappings (resolved relative to the working directory).
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
# Request schema
class Message(BaseModel):
    """A single chat message supplied by the client.

    ``content`` is either a plain string or a list of content-part objects
    (the list-of-parts form used for multimodal input, e.g. text plus image
    parts). Either form is copied into the backend payload unchanged.
    """

    role: str
    # Accepting the list form keeps multimodal requests from failing validation.
    # NOTE(review): presumably the backend model (pixtral) accepts image parts
    # in this format. TODO confirm against the TypeGPT API.
    content: Union[str, List[dict]]
class ChatRequest(BaseModel):
    """Validated body of a chat-completions request.

    ``model`` is the model name requested by the client and ``messages`` is the
    conversation. The remaining generation controls are optional. Each one falls
    back to the default declared here when the client omits it.
    """

    model: str
    messages: List[Message]

    # Optional generation controls.
    stream: Union[bool, None] = False
    temperature: Union[float, None] = 1.0
    top_p: Union[float, None] = 1.0
    n: Union[int, None] = 1
    stop: Union[str, List[str], None] = None
    presence_penalty: Union[float, None] = 0.0
    frequency_penalty: Union[float, None] = 0.0
# Construct payload with enforced system prompt
def build_payload(chat: ChatRequest):
    """Translate a validated client request into the backend request body.

    Every ``system`` message from the client is discarded. It is replaced by a
    single leading system message looked up in ``MODEL_PROMPTS`` by the
    requested model name, falling back to a generic assistant prompt. The
    request is always addressed to ``BACKEND_MODEL``. The generation controls
    are copied through unchanged.

    Args:
        chat: The parsed client request.

    Returns:
        A dict holding ``model``, ``messages`` and the generation controls.
    """
    prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")

    # Client system messages are dropped so the mapped prompt is always enforced.
    conversation = [{"role": "system", "content": prompt}]
    conversation.extend(
        {"role": m.role, "content": m.content}
        for m in chat.messages
        if m.role != "system"
    )

    payload = {"model": BACKEND_MODEL, "messages": conversation}
    forwarded_fields = (
        "stream",
        "temperature",
        "top_p",
        "n",
        "stop",
        "presence_penalty",
        "frequency_penalty",
    )
    for field_name in forwarded_fields:
        payload[field_name] = getattr(chat, field_name)
    return payload
# Properly streamed UTF-8 chunks with model rewrite
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's server-sent-event stream to the client.

    Each ``data:`` event is parsed as JSON. If its ``model`` field equals
    ``BACKEND_MODEL``, that field is rewritten to ``requested_model``, and the
    event is re-emitted. ``[DONE]`` sentinels are forwarded verbatim. Blank
    lines, non-``data`` lines and undecodable chunks are dropped, with a log
    entry for the latter two.

    Args:
        requested_model: Model name the client asked for, echoed back in chunks.
        payload: JSON body for the backend request.
        headers: HTTP headers for the backend request.

    Yields:
        SSE-formatted strings (``"data: ...\\n\\n"``). If the backend answers
        with a non-200 status, or the request fails at the transport level, a
        JSON ``error`` event followed by ``data: [DONE]`` is yielded. The client
        therefore sees a terminated stream rather than an empty or truncated one.
    """
    try:
        with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
            if r.status_code != 200:
                # An error body usually has no "data:" lines. Without this check
                # it would be silently discarded and the client would get nothing.
                logger.error(
                    "Upstream stream request failed (HTTP %s): %.500s",
                    r.status_code,
                    r.text,
                )
                yield from _sse_error_events("Upstream request failed.", r.status_code)
                return

            # requests falls back to ISO-8859-1 for text/* responses without a
            # charset (and to undecoded bytes for some other types). Either would
            # garble non-ASCII tokens, so default to UTF-8 unless a charset is declared.
            if "charset=" not in r.headers.get("Content-Type", "").lower():
                r.encoding = "utf-8"

            for line in r.iter_lines(decode_unicode=True):
                if not line:
                    continue
                if not line.startswith("data:"):
                    logger.debug("Non-data stream line skipped: %s", line)
                    continue

                # The space after "data:" is optional in SSE, so strip it
                # instead of slicing a fixed width.
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue

                try:
                    json_obj = json.loads(content)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
                    continue

                if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                    json_obj["model"] = requested_model
                yield f"data: {json.dumps(json_obj, ensure_ascii=False)}\n\n"
    except requests.RequestException as e:
        # This generator is consumed after the endpoint has already returned its
        # StreamingResponse, so the endpoint's own try/except cannot see this
        # failure. Report it in-band instead.
        logger.error("Upstream stream request error: %s", e)
        yield from _sse_error_events("Upstream connection failed.")


def _sse_error_events(message: str, code: Optional[int] = None):
    """Yield a JSON ``error`` SSE event followed by a ``[DONE]`` sentinel."""
    error = {"error": {"message": message, "code": code}}
    yield f"data: {json.dumps(error)}\n\n"
    yield "data: [DONE]\n\n"
# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """Proxy a chat-completions request to the TypeGPT backend.

    The client's ``model`` selects the system prompt via ``build_payload``, and
    the call always goes to ``BACKEND_MODEL``. In non-streaming mode the
    backend's model name in the response is rewritten back to the requested one.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``stream`` is true.
        Otherwise, a ``JSONResponse`` carrying the upstream body and status code.
        A malformed request body yields HTTP 400, a non-JSON upstream reply
        yields HTTP 502, and any other failure yields HTTP 500.
    """
    # Validate the client body on its own so that client mistakes are reported
    # as 400 rather than being masked as internal server errors.
    try:
        body = await request.json()
        if not isinstance(body, dict):
            raise ValueError("request body must be a JSON object")
        chat_request = ChatRequest(**body)
    except ValueError as e:  # json.JSONDecodeError and pydantic's ValidationError
        logger.warning("Rejected invalid request body: %s", e)
        return JSONResponse(
            content={"error": "Invalid request body.", "detail": str(e)},
            status_code=400,
        )

    try:
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }

        if chat_request.stream:
            # media_type alone sets the Content-Type header (charset included).
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream; charset=utf-8",
            )

        # requests is synchronous. Running it in a worker thread keeps a slow
        # upstream call from blocking the event loop and every other client.
        response = await run_in_threadpool(
            requests.post, API_URL, headers=headers, json=payload
        )
        try:
            data = response.json()
        except ValueError:
            logger.error(
                "Upstream returned a non-JSON response (HTTP %s): %.500s",
                response.status_code,
                response.text,
            )
            return JSONResponse(content={"error": "Bad gateway."}, status_code=502)

        if isinstance(data, dict) and data.get("model") == BACKEND_MODEL:
            data["model"] = chat_request.model
        # Propagate the upstream status so backend errors (auth, rate limits,
        # outages) are not reported to the client as success.
        return JSONResponse(
            content=data,
            status_code=response.status_code,
            media_type="application/json; charset=utf-8",
        )
    except Exception:
        logger.exception("Error in /v1/chat/completions")
        return JSONResponse(content={"error": "Internal server error."}, status_code=500)