# AZAA / main.py
# OpenAI-compatible chat-completions proxy: maps "virtual" model names to
# system prompts and forwards every request to one real backend model.
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Union
import requests
import json
# FastAPI application exposing an OpenAI-compatible /v1/chat/completions proxy.
app = FastAPI()

# Upstream OpenAI-compatible endpoint that actually serves the completions.
API_URL = "https://api.typegpt.net/v1/chat/completions"
# SECURITY NOTE(review): hard-coded API credential committed to source —
# this should be read from an environment variable/secret store, and the
# key should be rotated, since anyone with read access to the repo can use it.
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
# The single real model every request is forwarded to, regardless of the
# model name the client asked for.
BACKEND_MODEL = "pixtral-large-latest"
# Load virtual model -> system prompt mappings
# Keys are the model names clients may request; values are the system
# prompts injected for that virtual model. Read once at import time, so the
# app fails fast if model_map.json is missing or invalid.
with open("model_map.json", "r") as f:
    MODEL_PROMPTS = json.load(f)
class Message(BaseModel):
    """A single chat message in the OpenAI chat-completions format."""

    # Speaker role, e.g. "system" / "user" / "assistant" — not validated here.
    role: str
    # Plain-text message content.
    content: str
class ChatRequest(BaseModel):
    """Request body for the proxy endpoint, mirroring the OpenAI
    chat-completions schema; defaults match the OpenAI API defaults."""

    # Virtual model name; looked up in MODEL_PROMPTS to pick a system prompt.
    model: str
    messages: List[Message]
    # When True the response is relayed as a server-sent-event stream.
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    # Either a single stop string or a list of them (OpenAI allows both).
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
def build_payload(chat: ChatRequest):
    """Translate an incoming ChatRequest into the upstream request body.

    The client-requested (virtual) model selects a system prompt from
    MODEL_PROMPTS (falling back to a generic assistant prompt); the upstream
    call always targets BACKEND_MODEL. All sampling parameters are forwarded
    unchanged.
    """
    prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")

    # System prompt first, then the client's conversation in order.
    forwarded = [{"role": "system", "content": prompt}]
    for message in chat.messages:
        forwarded.append({"role": message.role, "content": message.content})

    payload = {"model": BACKEND_MODEL, "messages": forwarded}
    for field_name in (
        "stream",
        "temperature",
        "top_p",
        "n",
        "stop",
        "presence_penalty",
        "frequency_penalty",
    ):
        payload[field_name] = getattr(chat, field_name)
    return payload
def stream_generator(requested_model, payload, headers):
    """Relay the upstream SSE stream, rewriting the backend model name.

    Fixes over the previous version:
    - Upstream lines already carry the SSE ``data: `` prefix; re-wrapping
      them produced malformed ``data: data: {...}`` events. The prefix is
      now stripped before re-emitting.
    - Upstream's own ``[DONE]`` sentinel was re-wrapped AND followed by a
      second ``[DONE]``; now exactly one terminator is emitted.
    - A request timeout prevents a hung upstream from stalling the worker
      indefinitely.

    Yields SSE-framed strings (``data: ...\\n\\n``) ending with
    ``data: [DONE]\\n\\n``.
    """
    # timeout = (connect, read); the read timeout applies between chunks,
    # so long-lived streams that keep sending data are unaffected.
    with requests.post(
        API_URL, headers=headers, json=payload, stream=True, timeout=(10, 120)
    ) as r:
        for raw_line in r.iter_lines():
            if not raw_line:
                # Blank lines are SSE event separators — nothing to forward.
                continue
            decoded = raw_line.decode("utf-8")
            # Strip the upstream SSE prefix so we don't double-wrap it.
            if decoded.startswith("data: "):
                decoded = decoded[len("data: "):]
            elif decoded.startswith("data:"):
                decoded = decoded[len("data:"):]
            if decoded.strip() == "[DONE]":
                # Upstream finished; emit our single terminator below.
                break
            # Hide the real backend model from the client.
            if BACKEND_MODEL in decoded:
                decoded = decoded.replace(BACKEND_MODEL, requested_model)
            yield f"data: {decoded}\n\n"
    yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """Proxy an OpenAI-style chat-completions call to the backend.

    Validates the JSON body against ChatRequest, swaps the virtual model
    for the real backend model, forwards the request, and rewrites the
    model name in the response so clients see the name they asked for.

    Fixes over the previous version:
    - ``requests.post`` now has a timeout, so a hung upstream cannot stall
      the worker forever.
    - A non-JSON upstream body (e.g. an HTML error page) no longer raises
      an unhandled exception (500); it is surfaced as a structured 502.
    - The upstream status code is propagated instead of always returning 200.
    """
    body = await request.json()
    # Raises pydantic's validation error (-> 4xx/500 via FastAPI) on bad input.
    chat_request = ChatRequest(**body)
    payload = build_payload(chat_request)
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

    if chat_request.stream:
        return StreamingResponse(
            stream_generator(chat_request.model, payload, headers),
            media_type="text/event-stream",
        )

    response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    try:
        data = response.json()
    except ValueError:
        # Upstream returned something that isn't JSON — report it as a
        # bad-gateway error instead of crashing the proxy.
        return JSONResponse(
            content={
                "error": "Upstream returned a non-JSON response",
                "status": response.status_code,
            },
            status_code=502,
        )
    # Replace model in final result so the client sees its requested name.
    if "model" in data and data["model"] == BACKEND_MODEL:
        data["model"] = chat_request.model
    return JSONResponse(content=data, status_code=response.status_code)