Update main.py
main.py
CHANGED
@@ -18,7 +18,7 @@ API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"
 
 # Load model -> system prompt mappings
-with open("model_map.json", "r") as f:
+with open("model_map.json", "r", encoding="utf-8") as f:
     MODEL_PROMPTS = json.load(f)
 
 # Request schema
@@ -39,17 +39,12 @@ class ChatRequest(BaseModel):
 
 # Construct payload with enforced system prompt
 def build_payload(chat: ChatRequest):
-    # Use internal system prompt
     system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
-
-    # Remove any user-provided system messages
+    # Strip user system messages
     filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
-
-    # Insert enforced system prompt
     payload_messages = [{"role": "system", "content": system_prompt}] + [
         {"role": msg.role, "content": msg.content} for msg in filtered_messages
     ]
-
     return {
         "model": BACKEND_MODEL,
         "messages": payload_messages,
@@ -62,7 +57,7 @@ def build_payload(chat: ChatRequest):
         "frequency_penalty": chat.frequency_penalty
     }
 
-#
+# Properly streamed UTF-8 chunks with model rewrite
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
         for line in r.iter_lines(decode_unicode=True):
@@ -77,13 +72,13 @@ def stream_generator(requested_model: str, payload: dict, headers: dict):
                     json_obj = json.loads(content)
                     if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    yield f"data: {json.dumps(json_obj)}\n\n"
+                    yield f"data: {json.dumps(json_obj, ensure_ascii=False)}\n\n"
                 except json.JSONDecodeError:
                     logger.warning("Invalid JSON in stream chunk: %s", content)
             else:
                 logger.debug("Non-data stream line skipped: %s", line)
 
-#
+# Main endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
     try:
@@ -99,14 +94,15 @@ async def proxy_chat(request: Request):
         if chat_request.stream:
             return StreamingResponse(
                 stream_generator(chat_request.model, payload, headers),
-                media_type="text/event-stream"
+                media_type="text/event-stream; charset=utf-8",
+                headers={"Content-Type": "text/event-stream; charset=utf-8"}
             )
         else:
             response = requests.post(API_URL, headers=headers, json=payload)
             data = response.json()
             if "model" in data and data["model"] == BACKEND_MODEL:
                 data["model"] = chat_request.model
-            return JSONResponse(content=data)
+            return JSONResponse(content=data, media_type="application/json; charset=utf-8")
 
     except Exception as e:
         logger.error("Error in /v1/chat/completions: %s", str(e))
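
For a quick manual check of the non-streaming path after this change, a minimal client sketch (the host/port and the alias "my-model" are assumptions for illustration, and ChatRequest may require fields beyond those visible in this diff):

# Non-streaming check: the returned "model" field should echo the requested
# alias rather than BACKEND_MODEL ("pixtral-large-latest").
# http://localhost:8000 and "my-model" are assumptions, not from this commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # looked up in model_map.json for its enforced system prompt
        "messages": [{"role": "user", "content": "Héllo, does UTF-8 survive?"}],
        "stream": False,
    },
)
data = resp.json()
print(data.get("model"))  # expected: "my-model"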
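
And a sketch for the streaming path, which is what the ensure_ascii=False and charset=utf-8 changes target; the same host/port/alias assumptions apply:

# Streaming check: SSE "data:" lines should carry readable, unescaped UTF-8 text.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Répondez en français."}],
        "stream": True,
    },
    stream=True,
) as r:
    r.encoding = "utf-8"  # the proxy now also declares charset=utf-8
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            body = line[len("data: "):]
            if body.strip() == "[DONE]":
                break
            chunk = json.loads(body)
            print(chunk)  # "model" should read "my-model" in each chunk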