from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json
import logging
app = FastAPI()
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")
# Configuration
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"

# Load system prompt mappings
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
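# Assumed shape of model_map.json: each key is a public model name exposed by
# this proxy, each value is the system prompt injected for it. The names below
# are hypothetical examples, not values from the original file:
# {
#     "my-custom-model": "You are a terse assistant.",
#     "another-model": "Answer only in JSON."
# }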

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
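# Example request body this schema accepts ("my-custom-model" is a
# hypothetical model name, not one defined by the original):
# {"model": "my-custom-model",
#  "messages": [{"role": "user", "content": "Hi"}],
#  "stream": true}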

# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    # Map the requested model name to its system prompt, with a safe default.
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    # Drop client-supplied system messages so the injected prompt takes precedence.
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
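# For instance, a request for the hypothetical "my-custom-model" with one user
# turn is forwarded as BACKEND_MODEL with messages:
# [{"role": "system", "content": "<prompt from model_map.json>"},
#  {"role": "user", "content": "..."}]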

# Streaming chunk handler with model replacement and UTF-8 fix
def stream_generator(requested_model: str, payload: dict, headers: dict):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        # Force UTF-8 so multi-byte characters survive decode_unicode; requests
        # may otherwise guess a wrong charset from the response headers.
        r.encoding = "utf-8"
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            if line.startswith("data:"):
                # Strip the "data:" prefix plus any following space; both
                # "data: {...}" and "data:{...}" are valid SSE framing.
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content)
                    # Mask the backend model name with the one the client requested.
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield "data: " + json.dumps(json_obj, ensure_ascii=False) + "\n\n"
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
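# A typical backend chunk (assumed OpenAI-style SSE) arrives as:
#   data: {"id": "chatcmpl-...", "model": "pixtral-large-latest", "choices": [{"delta": {"content": "Hi"}}]}
# and is re-emitted with "model" swapped to the client's requested name.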

# Main API endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }
        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream; charset=utf-8"
            )
        else:
            # A finite timeout keeps a stalled backend from hanging the worker
            # (120 s is an assumed bound, not a value from the original).
            response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
            data = response.json()
            # Mask the backend model name in non-streaming responses too, and
            # forward the backend's status code instead of always returning 200.
            if data.get("model") == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(
                content=data,
                status_code=response.status_code,
                media_type="application/json; charset=utf-8"
            )
    except Exception as e:
        logger.error("Error in /v1/chat/completions: %s", e)
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )
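
# A minimal way to run the proxy locally (assumes uvicorn is installed and
# this module is saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example call ("my-custom-model" is a hypothetical key in model_map.json):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-custom-model", "messages": [{"role": "user", "content": "Hi"}]}'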