Update main.py
Browse files
main.py
CHANGED
@@ -17,7 +17,7 @@ API_URL = "https://api.typegpt.net/v1/chat/completions"
|
|
17 |
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
|
18 |
BACKEND_MODEL = "pixtral-large-latest"
|
19 |
|
20 |
-
# Load model
|
21 |
with open("model_map.json", "r") as f:
|
22 |
MODEL_PROMPTS = json.load(f)
|
23 |
|
@@ -37,24 +37,32 @@ class ChatRequest(BaseModel):
|
|
37 |
presence_penalty: Optional[float] = 0.0
|
38 |
frequency_penalty: Optional[float] = 0.0
|
39 |
|
40 |
-
#
|
41 |
def build_payload(chat: ChatRequest):
|
|
|
42 |
system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
return {
|
44 |
"model": BACKEND_MODEL,
|
|
|
45 |
"stream": chat.stream,
|
46 |
"temperature": chat.temperature,
|
47 |
"top_p": chat.top_p,
|
48 |
"n": chat.n,
|
49 |
"stop": chat.stop,
|
50 |
"presence_penalty": chat.presence_penalty,
|
51 |
-
"frequency_penalty": chat.frequency_penalty
|
52 |
-
"messages": [{"role": "system", "content": system_prompt}] + [
|
53 |
-
{"role": msg.role, "content": msg.content} for msg in chat.messages
|
54 |
-
]
|
55 |
}
|
56 |
|
57 |
-
#
|
58 |
def stream_generator(requested_model: str, payload: dict, headers: dict):
|
59 |
with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
|
60 |
for line in r.iter_lines(decode_unicode=True):
|
@@ -75,13 +83,14 @@ def stream_generator(requested_model: str, payload: dict, headers: dict):
|
|
75 |
else:
|
76 |
logger.debug("Non-data stream line skipped: %s", line)
|
77 |
|
78 |
-
#
|
79 |
@app.post("/v1/chat/completions")
|
80 |
async def proxy_chat(request: Request):
|
81 |
try:
|
82 |
body = await request.json()
|
83 |
chat_request = ChatRequest(**body)
|
84 |
payload = build_payload(chat_request)
|
|
|
85 |
headers = {
|
86 |
"Authorization": f"Bearer {API_KEY}",
|
87 |
"Content-Type": "application/json"
|
@@ -100,5 +109,5 @@ async def proxy_chat(request: Request):
|
|
100 |
return JSONResponse(content=data)
|
101 |
|
102 |
except Exception as e:
|
103 |
-
logger.error("Error in
|
104 |
return JSONResponse(content={"error": "Internal server error."}, status_code=500)
|
|
|
17 |
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
|
18 |
BACKEND_MODEL = "pixtral-large-latest"
|
19 |
|
20 |
+
# Load model -> system prompt mappings
|
21 |
with open("model_map.json", "r") as f:
|
22 |
MODEL_PROMPTS = json.load(f)
|
23 |
|
|
|
37 |
presence_penalty: Optional[float] = 0.0
|
38 |
frequency_penalty: Optional[float] = 0.0
|
39 |
|
40 |
+
# Construct payload with enforced system prompt
|
41 |
def build_payload(chat: ChatRequest):
|
42 |
+
# Use internal system prompt
|
43 |
system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
|
44 |
+
|
45 |
+
# Remove any user-provided system messages
|
46 |
+
filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
|
47 |
+
|
48 |
+
# Insert enforced system prompt
|
49 |
+
payload_messages = [{"role": "system", "content": system_prompt}] + [
|
50 |
+
{"role": msg.role, "content": msg.content} for msg in filtered_messages
|
51 |
+
]
|
52 |
+
|
53 |
return {
|
54 |
"model": BACKEND_MODEL,
|
55 |
+
"messages": payload_messages,
|
56 |
"stream": chat.stream,
|
57 |
"temperature": chat.temperature,
|
58 |
"top_p": chat.top_p,
|
59 |
"n": chat.n,
|
60 |
"stop": chat.stop,
|
61 |
"presence_penalty": chat.presence_penalty,
|
62 |
+
"frequency_penalty": chat.frequency_penalty
|
|
|
|
|
|
|
63 |
}
|
64 |
|
65 |
+
# Streaming response handler
|
66 |
def stream_generator(requested_model: str, payload: dict, headers: dict):
|
67 |
with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
|
68 |
for line in r.iter_lines(decode_unicode=True):
|
|
|
83 |
else:
|
84 |
logger.debug("Non-data stream line skipped: %s", line)
|
85 |
|
86 |
+
# Proxy endpoint
|
87 |
@app.post("/v1/chat/completions")
|
88 |
async def proxy_chat(request: Request):
|
89 |
try:
|
90 |
body = await request.json()
|
91 |
chat_request = ChatRequest(**body)
|
92 |
payload = build_payload(chat_request)
|
93 |
+
|
94 |
headers = {
|
95 |
"Authorization": f"Bearer {API_KEY}",
|
96 |
"Content-Type": "application/json"
|
|
|
109 |
return JSONResponse(content=data)
|
110 |
|
111 |
except Exception as e:
|
112 |
+
logger.error("Error in /v1/chat/completions: %s", str(e))
|
113 |
return JSONResponse(content={"error": "Internal server error."}, status_code=500)
|