from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json
app = FastAPI()

API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"

# Load virtual model -> system prompt mappings
with open("model_map.json", "r") as f:
    MODEL_PROMPTS = json.load(f)
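# Illustrative model_map.json shape (the real file ships alongside this
# module; the keys and prompts below are assumptions, not its contents):
# {
#     "virtual-model-name": "System prompt injected for that model",
#     "another-model": "Another system prompt"
# }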
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

def build_payload(chat: ChatRequest):
    # Prepend the virtual model's system prompt, then forward the client's
    # messages to the fixed backend model.
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in chat.messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
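# Worked example (virtual model name and prompt are illustrative): a request
# for model "gpt-4o" whose model_map.json entry is "You are GPT-4o." becomes:
#   {"model": "pixtral-large-latest",
#    "messages": [{"role": "system", "content": "You are GPT-4o."},
#                 {"role": "user", "content": "Hi"}],
#    "stream": false, ...}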
def stream_generator(requested_model: str, payload: dict, headers: dict):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if line and line.startswith("data:"):
                # Strip the "data:" prefix (with or without a trailing space).
                content = line[5:].strip()
                if content == "[DONE]":
                    # Suppress the upstream terminator; we emit our own below.
                    continue
                try:
                    # Parse the chunk and mask the backend model name.
                    json_obj = json.loads(content)
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    fixed_line = f"data: {json.dumps(json_obj)}\n\n"
                except json.JSONDecodeError:
                    # Not JSON: pass the chunk through unchanged.
                    fixed_line = f"data: {content}\n\n"
                yield fixed_line
            elif line:
                yield f"data: {line}\n\n"
    yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    body = await request.json()
    chat_request = ChatRequest(**body)
    payload = build_payload(chat_request)
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    if chat_request.stream:
        return StreamingResponse(
            stream_generator(chat_request.model, payload, headers),
            media_type="text/event-stream"
        )
    else:
        # Note: requests is blocking; acceptable for a simple proxy, but an
        # async HTTP client would avoid stalling the event loop.
        response = requests.post(API_URL, headers=headers, json=payload)
        data = response.json()
        # Mask the backend model name in the non-streaming response too.
        if data.get("model") == BACKEND_MODEL:
            data["model"] = chat_request.model
        return JSONResponse(content=data)
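# A minimal way to run and exercise the proxy (assumes this file is main.py
# and that "gpt-4o" is a key in model_map.json — both are illustrative):
#
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}]}'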