# AZAA / main.py
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, ValidationError
from typing import List, Optional, Union
import requests
import json
import logging
app = FastAPI()
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")
# Configuration
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"
# Load system prompt mappings
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
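# For illustration only: model_map.json is expected to map an exposed model name
# to the system prompt injected for it. A hypothetical shape (the real file ships
# alongside this script and may differ):
#
#   {
#       "gpt-4": "You are a helpful assistant.",
#       "creative-writer": "You are an imaginative fiction writer."
#   }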
# Request schema
class Message(BaseModel):
    role: str
    content: str
class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    # Drop any client-supplied system messages so the mapped prompt always wins
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
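# For illustration: given a client request for model "gpt-4" with one user
# message, build_payload (assuming "gpt-4" is mapped in model_map.json) yields
# roughly:
#
#   {
#       "model": "pixtral-large-latest",
#       "messages": [
#           {"role": "system", "content": "<prompt mapped to gpt-4>"},
#           {"role": "user", "content": "Hello"}
#       ],
#       "stream": false,
#       ...
#   }
#
# Unmapped model names still work; they fall back to the default system prompt.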
# Streaming chunk handler with model replacement and UTF-8 fix
def stream_generator(requested_model: str, payload: dict, headers: dict):
    # timeout added as a safeguard against a hung upstream; the value is arbitrary
    with requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=120) as r:
        # Force UTF-8 so iter_lines(decode_unicode=True) does not mis-guess the charset
        r.encoding = "utf-8"
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("data:"):
                # Strip the "data:" prefix whether or not a space follows the colon
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content)
                    # Re-badge the backend model as the model the client requested
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield "data: " + json.dumps(json_obj, ensure_ascii=False) + "\n\n"
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
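# For illustration, an upstream chunk in the usual OpenAI-style SSE shape, e.g.
#   data: {"id": "...", "model": "pixtral-large-latest", "choices": [...]}
# is re-emitted to the client as
#   data: {"id": "...", "model": "<requested model>", "choices": [...]}
# (field names other than "model" are assumptions about the backend's format).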
# Main API endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
    except (json.JSONDecodeError, ValidationError) as e:
        # Malformed input is a client error, not a server error
        return JSONResponse(content={"error": str(e)}, status_code=422)
    try:
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }
        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream; charset=utf-8",
                headers={"Content-Type": "text/event-stream; charset=utf-8"}
            )
        else:
            # Note: requests is blocking, so inside this async endpoint the
            # event loop is tied up for the duration of the upstream call
            response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
            data = response.json()
            if data.get("model") == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(
                content=data,
                media_type="application/json; charset=utf-8",
                headers={"Content-Type": "application/json; charset=utf-8"}
            )
    except Exception:
        logger.exception("Error in /v1/chat/completions")
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )
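# Minimal local entrypoint, included as a hedged sketch: Hugging Face Spaces
# normally launch the app via their own uvicorn command, and port 7860 is only
# the conventional Spaces port (an assumption, not part of the original file).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (assumed usage; any model name is accepted, unmapped names
# simply get the default system prompt):
#   curl -N http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'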