Update main.py
main.py CHANGED
@@ -56,28 +56,28 @@ def build_payload(chat: ChatRequest):
         "frequency_penalty": chat.frequency_penalty
     }
 
-#
+# Stream generator without forcing UTF-8
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
-        for line in r.iter_lines(decode_unicode=True):
+        for line in r.iter_lines(decode_unicode=False):  # Keep as bytes
             if not line:
                 continue
-            if line.startswith("data:"):
+            if line.startswith(b"data:"):
                 content = line[6:].strip()
-                if content == "[DONE]":
-                    yield "data: [DONE]\n\n"
+                if content == b"[DONE]":
+                    yield b"data: [DONE]\n\n"
                     continue
                 try:
-                    json_obj = json.loads(content)
+                    json_obj = json.loads(content.decode("utf-8"))
                     if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    yield "data: " + json.dumps(json_obj) + "\n\n"
+                    yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
                 except json.JSONDecodeError:
                     logger.warning("Invalid JSON in stream chunk: %s", content)
             else:
                 logger.debug("Non-data stream line skipped: %s", line)
 
-# Main
+# Main endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
     try:
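The core of this hunk: with decode_unicode=True, iter_lines decodes each chunk using r.encoding, and requests falls back to ISO-8859-1 for text/* responses that omit a charset, so multi-byte UTF-8 from the backend gets mangled before it ever reaches the client. Keeping the stream as bytes and decoding only the JSON payload avoids the guess. A self-contained sketch of the same byte-level rewrite (BACKEND_MODEL here is a stand-in for the constant defined near the top of main.py):

import json

# Stand-in for the BACKEND_MODEL constant defined elsewhere in main.py.
BACKEND_MODEL = "backend-model"

def rewrite_sse_line(line: bytes, requested_model: str) -> bytes | None:
    """Byte-level rewrite of one SSE line, mirroring stream_generator."""
    if not line.startswith(b"data:"):
        return None  # the proxy logs and skips these
    content = line[6:].strip()
    if content == b"[DONE]":
        return b"data: [DONE]\n\n"
    json_obj = json.loads(content.decode("utf-8"))
    if json_obj.get("model") == BACKEND_MODEL:
        json_obj["model"] = requested_model
    return f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")

# Multi-byte characters survive because nothing ever decodes them as Latin-1:
chunk = 'data: {"model": "backend-model", "text": "héllo 世界"}'.encode("utf-8")
print(rewrite_sse_line(chunk, "requested-model").decode("utf-8"))
# -> data: {"model": "requested-model", "text": "h\u00e9llo \u4e16\u754c"}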
@@ -93,19 +93,15 @@ async def proxy_chat(request: Request):
         if chat_request.stream:
             return StreamingResponse(
                 stream_generator(chat_request.model, payload, headers),
-                media_type="text/event-stream; charset=utf-8",
-                headers={"Content-Type": "text/event-stream; charset=utf-8"}
+                media_type="text/event-stream"
             )
         else:
             response = requests.post(API_URL, headers=headers, json=payload)
+            response.raise_for_status()  # Raise error for bad responses
             data = response.json()
             if "model" in data and data["model"] == BACKEND_MODEL:
                 data["model"] = chat_request.model
-            return JSONResponse(
-                content=data,
-                media_type="application/json; charset=utf-8",
-                headers={"Content-Type": "application/json; charset=utf-8"}
-            )
+            return JSONResponse(content=data)
 
     except Exception as e:
         logger.error("Error in /v1/chat/completions: %s", str(e))
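Two smaller fixes ride along in this hunk: raise_for_status() turns upstream 4xx/5xx responses into requests.HTTPError instead of parsing an error body as if it were a completion, and JSONResponse(content=data) drops the redundant charset overrides, since FastAPI's JSONResponse already serializes to UTF-8 JSON on its own. A sketch of the error-path behavior in isolation (the URL is a placeholder, not the project's API_URL):

import requests

# Placeholder URL standing in for API_URL; not the project's real backend.
resp = requests.post("https://upstream.example/v1/chat/completions",
                     json={"model": "backend-model", "messages": []})
try:
    resp.raise_for_status()   # raises requests.HTTPError on 4xx/5xx
    data = resp.json()        # only parsed when the upstream succeeded
except requests.HTTPError as exc:
    # In main.py this propagates to the endpoint's broad `except Exception`,
    # which logs the failure instead of relaying the error body.
    print("upstream failed:", exc.response.status_code)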
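One way to check the end-to-end result is a byte-mode client streaming from the proxy and watching the rewritten model name; this sketch assumes the app is served at localhost:8000 and that "requested-model" is a name the proxy accepts (both placeholders):

import json
import requests

url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "requested-model",
    "stream": True,
    "messages": [{"role": "user", "content": "héllo"}],
}

with requests.post(url, json=payload, stream=True) as r:
    for raw in r.iter_lines(decode_unicode=False):  # stay in bytes, like the proxy
        if not raw.startswith(b"data:"):
            continue
        body = raw[6:].strip()
        if body == b"[DONE]":
            break
        chunk = json.loads(body.decode("utf-8"))
        # Should print the requested name, never BACKEND_MODEL.
        print(chunk.get("model"))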