rkihacker committed on
Commit
06ea63c
·
verified ·
1 Parent(s): df657b2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -15
main.py CHANGED
@@ -56,28 +56,28 @@ def build_payload(chat: ChatRequest):
56
  "frequency_penalty": chat.frequency_penalty
57
  }
58
 
59
# Streaming chunk handler with model replacement and UTF-8 fix
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Proxy the backend's Server-Sent-Events stream to the client.

    Each ``data:`` chunk is parsed, the backend model id is replaced with the
    model name the client originally requested, and the chunk is re-emitted
    as SSE text.

    Args:
        requested_model: Model name to report back to the client.
        payload: JSON body forwarded to the backend chat endpoint.
        headers: HTTP headers for the backend request.

    Yields:
        SSE-formatted ``str`` chunks (``"data: ...\n\n"``).
    """
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                # Blank lines are SSE event separators / keep-alives.
                continue
            if line.startswith("data:"):
                # "data:" is 5 characters; slicing at 5 and stripping also
                # tolerates a missing space after the colon. The previous
                # [6:] slice dropped the first character of such payloads,
                # turning valid JSON into a logged-and-lost chunk.
                content = line[5:].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content)
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    # ensure_ascii=False keeps non-ASCII text readable
                    # instead of \uXXXX-escaping it.
                    yield "data: " + json.dumps(json_obj, ensure_ascii=False) + "\n\n"
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
79
 
80
- # Main API endpoint
81
  @app.post("/v1/chat/completions")
82
  async def proxy_chat(request: Request):
83
  try:
@@ -93,19 +93,15 @@ async def proxy_chat(request: Request):
93
  if chat_request.stream:
94
  return StreamingResponse(
95
  stream_generator(chat_request.model, payload, headers),
96
- media_type="text/event-stream; charset=utf-8",
97
- headers={"Content-Type": "text/event-stream; charset=utf-8"}
98
  )
99
  else:
100
  response = requests.post(API_URL, headers=headers, json=payload)
 
101
  data = response.json()
102
  if "model" in data and data["model"] == BACKEND_MODEL:
103
  data["model"] = chat_request.model
104
- return JSONResponse(
105
- content=data,
106
- media_type="application/json; charset=utf-8",
107
- headers={"Content-Type": "application/json; charset=utf-8"}
108
- )
109
 
110
  except Exception as e:
111
  logger.error("Error in /v1/chat/completions: %s", str(e))
 
56
  "frequency_penalty": chat.frequency_penalty
57
  }
58
 
59
# Stream generator without forcing UTF-8
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Proxy the backend's Server-Sent-Events stream as raw bytes.

    Lines are kept as ``bytes`` end-to-end so the backend's own encoding
    passes through untouched; JSON is only decoded transiently to swap the
    backend model id for the model name the client requested.

    Args:
        requested_model: Model name to report back to the client.
        payload: JSON body forwarded to the backend chat endpoint.
        headers: HTTP headers for the backend request.

    Yields:
        SSE-formatted ``bytes`` chunks (``b"data: ...\n\n"``).
    """
    prefix = b"data:"
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=False):  # Keep as bytes
            if not line:
                # Blank lines are SSE event separators / keep-alives.
                continue
            if line.startswith(prefix):
                # len(prefix) == 5; the previous [6:] slice silently dropped
                # the first byte whenever the backend omitted the space after
                # "data:", corrupting the JSON payload.
                content = line[len(prefix):].strip()
                if content == b"[DONE]":
                    yield b"data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content.decode("utf-8"))
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # UnicodeDecodeError: backend sent non-UTF-8 bytes — log
                    # and keep streaming instead of killing the response.
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
79
 
80
+ # Main endpoint
81
  @app.post("/v1/chat/completions")
82
  async def proxy_chat(request: Request):
83
  try:
 
93
  if chat_request.stream:
94
  return StreamingResponse(
95
  stream_generator(chat_request.model, payload, headers),
96
+ media_type="text/event-stream"
 
97
  )
98
  else:
99
  response = requests.post(API_URL, headers=headers, json=payload)
100
+ response.raise_for_status() # Raise error for bad responses
101
  data = response.json()
102
  if "model" in data and data["model"] == BACKEND_MODEL:
103
  data["model"] = chat_request.model
104
+ return JSONResponse(content=data)
 
 
 
 
105
 
106
  except Exception as e:
107
  logger.error("Error in /v1/chat/completions: %s", str(e))