"""Minimal OpenAI-compatible proxy: maps each exposed model name to a system
prompt and forwards every chat completion to a single backend model."""

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, ValidationError
from typing import List, Optional, Union
import requests
import json
import logging

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# Upstream OpenAI-compatible endpoint. Every request is forwarded to this single
# backend model regardless of the model name the client asks for.
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "gpt-4o-mini"

# model_map.json maps an exposed model name to the system prompt injected for it.
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
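
# Illustrative model_map.json (the model names below are hypothetical; the real
# file defines whatever names this proxy should expose):
#
#   {
#     "my-coder": "You are an expert programming assistant.",
#     "my-translator": "You translate user messages into English."
#   }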

class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

def build_payload(chat: ChatRequest) -> dict:
    """Rebuild the request for the backend: swap in the mapped system prompt
    and force the real backend model name."""
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    # Drop any client-supplied system messages so the mapped prompt always wins.
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty,
    }

def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's SSE stream, rewriting the model name on each chunk."""
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=False):
            if not line:
                continue
            if line.startswith(b"data:"):
                # Drop the "data:" prefix (and optional space) to get the JSON chunk.
                content = line[5:].strip()
                if content == b"[DONE]":
                    yield b"data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content.decode("utf-8"))
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
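
# A typical upstream SSE line looks roughly like this (illustrative only; the
# exact fields depend on the backend API):
#
#   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk",
#          "model": "gpt-4o-mini", "choices": [{"index": 0,
#          "delta": {"content": "Hello"}, "finish_reason": null}]}
#
# stream_generator() re-emits each chunk with "model" replaced by the name the
# client originally requested.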

@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)

        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )
        else:
            response = requests.post(API_URL, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            # Present the response as if it came from the requested model.
            if data.get("model") == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(content=data)

    except ValidationError as e:
        # Malformed request body: report a client error rather than a server error.
        logger.warning("Invalid request body: %s", e)
        return JSONResponse(
            content={"error": "Invalid request body."},
            status_code=422
        )
    except Exception as e:
        logger.error("Error in /v1/chat/completions: %s", str(e))
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )
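
# Example usage (sketch, assuming this file is main.py and uvicorn's default
# host/port; "my-coder" is a hypothetical key from model_map.json):
#
#   uvicorn main:app --port 8000
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-coder", "messages": [{"role": "user", "content": "Hi"}]}'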