rkihacker committed on
Commit 36f72ba · verified · 1 Parent(s): a7b9a59

Update main.py

Files changed (1):
  main.py +8 -12
main.py CHANGED
@@ -18,7 +18,7 @@ API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"
 
 # Load model -> system prompt mappings
-with open("model_map.json", "r") as f:
+with open("model_map.json", "r", encoding="utf-8") as f:
     MODEL_PROMPTS = json.load(f)
 
 # Request schema
@@ -39,17 +39,12 @@ class ChatRequest(BaseModel):
 
 # Construct payload with enforced system prompt
 def build_payload(chat: ChatRequest):
-    # Use internal system prompt
     system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
-
-    # Remove any user-provided system messages
+    # Strip user system messages
     filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
-
-    # Insert enforced system prompt
     payload_messages = [{"role": "system", "content": system_prompt}] + [
         {"role": msg.role, "content": msg.content} for msg in filtered_messages
     ]
-
     return {
         "model": BACKEND_MODEL,
         "messages": payload_messages,
@@ -62,7 +57,7 @@ def build_payload(chat: ChatRequest):
         "frequency_penalty": chat.frequency_penalty
     }
 
-# Streaming response handler
+# Properly streamed UTF-8 chunks with model rewrite
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
         for line in r.iter_lines(decode_unicode=True):
@@ -77,13 +72,13 @@ def stream_generator(requested_model: str, payload: dict, headers: dict):
                     json_obj = json.loads(content)
                     if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    yield f"data: {json.dumps(json_obj)}\n\n"
+                    yield f"data: {json.dumps(json_obj, ensure_ascii=False)}\n\n"
                 except json.JSONDecodeError:
                     logger.warning("Invalid JSON in stream chunk: %s", content)
             else:
                 logger.debug("Non-data stream line skipped: %s", line)
 
-# Proxy endpoint
+# Main endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
     try:
@@ -99,14 +94,15 @@ async def proxy_chat(request: Request):
         if chat_request.stream:
             return StreamingResponse(
                 stream_generator(chat_request.model, payload, headers),
-                media_type="text/event-stream"
+                media_type="text/event-stream; charset=utf-8",
+                headers={"Content-Type": "text/event-stream; charset=utf-8"}
             )
         else:
             response = requests.post(API_URL, headers=headers, json=payload)
             data = response.json()
             if "model" in data and data["model"] == BACKEND_MODEL:
                 data["model"] = chat_request.model
-            return JSONResponse(content=data)
+            return JSONResponse(content=data, media_type="application/json; charset=utf-8")
 
     except Exception as e:
         logger.error("Error in /v1/chat/completions: %s", str(e))
 
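The thread running through the rest of the commit is UTF-8 handling on the way out: streamed chunks are re-serialized with ensure_ascii=False, and both the streaming and non-streaming responses now declare an explicit charset. A minimal standalone sketch of what the ensure_ascii flag changes:

import json

chunk = {"content": "héllo, wörld ✓"}

# Default behaviour: non-ASCII characters are escaped to \uXXXX,
# so the SSE payload stays pure ASCII but grows and reads poorly.
print(json.dumps(chunk))                      # {"content": "h\u00e9llo, w\u00f6rld \u2713"}

# ensure_ascii=False emits the characters as literal UTF-8 text,
# matching the charset the proxy now declares on its responses.
print(json.dumps(chunk, ensure_ascii=False))  # {"content": "héllo, wörld ✓"}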