File size: 3,950 Bytes
daa1bb4 f7f30cc 4d2b16a f7f30cc bcdcebb daa1bb4 89d8cc9 daa1bb4 bcdcebb ad95a9a 89d8cc9 ad95a9a bcdcebb 0c75fa7 bcdcebb ad95a9a 36f72ba daa1bb4 bcdcebb 89d8cc9 daa1bb4 bcdcebb daa1bb4 f7f30cc bcdcebb ad95a9a f7f30cc a7b9a59 f7f30cc bcdcebb a7b9a59 f7f30cc a7b9a59 bcdcebb 06ea63c 4d2b16a f7f30cc 06ea63c 89d8cc9 06ea63c 4d2b16a 06ea63c 89d8cc9 4d2b16a 06ea63c 89d8cc9 4d2b16a 06ea63c 4d2b16a 89d8cc9 f7f30cc 06ea63c f7f30cc 89d8cc9 a7b9a59 89d8cc9 06ea63c 89d8cc9 06ea63c 89d8cc9 06ea63c bcdcebb 89d8cc9 a7b9a59 ad95a9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import json
import logging
import os
from typing import List, Optional, Union

import requests
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# Configuration
API_URL = "https://api.typegpt.net/v1/chat/completions"
# SECURITY: supply the upstream key through the TYPEGPT_API_KEY environment
# variable. The literal below is kept only as a backward-compatible fallback;
# because it is committed to source it should be treated as exposed, rotated,
# and then removed from this file.
API_KEY = (
    os.environ.get("TYPEGPT_API_KEY")
    or "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
)
# Every request is sent to the backend under this model name, regardless of
# the model name the client asked for.
BACKEND_MODEL = "gpt-4o-mini"

# Load system prompt mappings: client-facing model name -> system prompt text.
# Read once at import time; a missing or malformed file fails startup.
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
# Request schema
class Message(BaseModel):
    """One entry of a chat conversation: who spoke and what was said."""

    # Speaker role; entries whose role is "system" are dropped by
    # build_payload before the request is forwarded.
    role: str
    # Text of the message.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint.

    Only ``model`` and ``messages`` are required; every other field falls
    back to the default declared below.
    """

    # Client-facing model name; used to look up the system prompt and echoed
    # back in responses in place of the backend model name.
    model: str
    messages: List[Message]
    # When true the endpoint answers with a server-sent-event stream.
    stream: Optional[bool] = False
    # Generation options, copied into the backend payload as given.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Union[str, List[str], None] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    """Translate a client request into the JSON body sent to the backend.

    The system prompt mapped to ``chat.model`` (or a generic default when the
    model has no mapping) is placed first, any client-supplied system
    messages are discarded, and the model name is replaced by
    ``BACKEND_MODEL``. All generation options are copied through unchanged.

    Args:
        chat: The validated client request.

    Returns:
        A dict ready to be serialised as the backend request body.
    """
    injected_prompt = {
        "role": "system",
        "content": MODEL_PROMPTS.get(chat.model, "You are a helpful assistant."),
    }

    outgoing = [injected_prompt]
    for entry in chat.messages:
        # Client-provided system messages must not compete with the injected one.
        if entry.role == "system":
            continue
        outgoing.append({"role": entry.role, "content": entry.content})

    payload = {"model": BACKEND_MODEL, "messages": outgoing}
    for option in (
        "stream",
        "temperature",
        "top_p",
        "n",
        "stop",
        "presence_penalty",
        "frequency_penalty",
    ):
        payload[option] = getattr(chat, option)
    return payload
# Stream generator without forcing UTF-8
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's event stream, relabelling the model name.

    Posts ``payload`` to ``API_URL`` with streaming enabled and re-emits each
    ``data:`` line as its own event. JSON chunks whose ``model`` equals
    ``BACKEND_MODEL`` have it replaced by ``requested_model``; the ``[DONE]``
    sentinel is passed through as is. Blank lines and non-``data:`` lines are
    skipped, and chunks that cannot be decoded as UTF-8 JSON are logged and
    dropped without ending the stream.

    Args:
        requested_model: Model name the client asked for.
        payload: JSON body to send to the backend.
        headers: HTTP headers for the backend request.

    Yields:
        bytes: One ``data: ...`` event per chunk, terminated by a blank line.
    """
    prefix = b"data:"
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=False):  # Keep as bytes
            if not line:
                continue
            if not line.startswith(prefix):
                logger.debug("Non-data stream line skipped: %s", line)
                continue

            # The space after "data:" is optional, so cut exactly the prefix
            # and let strip() remove the separator when there is one. Slicing
            # a fixed six bytes would eat the first byte of "data:{...}".
            content = line[len(prefix):].strip()
            if content == b"[DONE]":
                yield b"data: [DONE]\n\n"
                continue

            try:
                json_obj = json.loads(content.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                logger.warning("Invalid JSON in stream chunk: %s", content)
                continue

            # Only objects carry a "model" field; other JSON values are
            # forwarded untouched instead of raising on .get().
            if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                json_obj["model"] = requested_model
            yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """Proxy a chat-completions request to the backend.

    The request body is validated as a ``ChatRequest``, rewritten by
    ``build_payload`` and sent to ``API_URL``. Streaming requests are answered
    with an event stream produced by ``stream_generator``; otherwise the
    backend's JSON response is returned with the backend model name replaced
    by the one the client requested.

    Args:
        request: Incoming HTTP request whose body is the chat request JSON.

    Returns:
        A ``StreamingResponse`` for streaming requests, otherwise a
        ``JSONResponse``. A body that is not valid JSON or does not match
        ``ChatRequest`` yields status 400; any other failure yields status 500
        with a generic error message.
    """
    # Local import: keeps this handler self-contained without touching the
    # module-level import list.
    from fastapi.concurrency import run_in_threadpool

    # Client mistakes (bad JSON, wrong shape, failed validation) are reported
    # as a client error rather than as an internal server error.
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
    except (ValueError, TypeError) as e:
        logger.warning("Invalid request to /v1/chat/completions: %s", e)
        return JSONResponse(
            content={"error": "Invalid request body."},
            status_code=400
        )

    try:
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }
        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )

        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests while the backend responds.
        response = await run_in_threadpool(
            requests.post, API_URL, headers=headers, json=payload
        )
        response.raise_for_status()  # Raise error for bad responses
        data = response.json()
        if "model" in data and data["model"] == BACKEND_MODEL:
            data["model"] = chat_request.model
        return JSONResponse(content=data)
    except Exception:
        # Top-level boundary: log with traceback, hide details from the client.
        logger.exception("Error in /v1/chat/completions")
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )
|