aeonshift commited on
Commit
170adce
·
verified ·
1 Parent(s): e851628

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +8 -10
server.py CHANGED
@@ -39,23 +39,21 @@ async def handle_sse(request: Request):
39
  async for session_message in write_stream_reader:
40
  # Handle both dict and SessionMessage objects
41
  if isinstance(session_message, dict):
42
- message_data = json.dumps(session_message)
 
43
  else:
44
- message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
45
- message = json.loads(message_data)
46
- if not session_id and message.get("event") == "endpoint":
47
- endpoint_url = message.get("data", "")
 
48
  if "session_id=" in endpoint_url:
49
  session_id = endpoint_url.split("session_id=")[1]
50
  placeholder_id = f"placeholder_{id(write_stream)}"
51
  if placeholder_id in write_streams:
52
  write_streams[session_id] = write_streams.pop(placeholder_id)
53
  logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
54
- # Send the event as a raw SSE event
55
- await sse_stream_writer.send({
56
- "event": message.get("event", "message"),
57
- "data": message.get("data", message_data)
58
- })
59
 
60
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
61
  async with transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
 
39
  async for session_message in write_stream_reader:
40
  # Handle both dict and SessionMessage objects
41
  if isinstance(session_message, dict):
42
+ event = session_message.get("event", "message")
43
+ data = session_message.get("data", json.dumps(session_message))
44
  else:
45
+ event = "message"
46
+ data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
47
+ message = json.loads(data) if isinstance(data, str) else data
48
+ if not session_id and event == "endpoint":
49
+ endpoint_url = data
50
  if "session_id=" in endpoint_url:
51
  session_id = endpoint_url.split("session_id=")[1]
52
  placeholder_id = f"placeholder_{id(write_stream)}"
53
  if placeholder_id in write_streams:
54
  write_streams[session_id] = write_streams.pop(placeholder_id)
55
  logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
56
+ await sse_stream_writer.send({"event": event, "data": data})
 
 
 
 
57
 
58
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
59
  async with transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):