aeonshift commited on
Commit
0713a48
·
verified ·
1 Parent(s): 9a21c26

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +17 -1
server.py CHANGED
@@ -46,6 +46,7 @@ async def handle_sse(request: Request):
46
  placeholder_id = f"placeholder_{id(write_stream)}"
47
  if placeholder_id in write_streams:
48
  write_streams[session_id] = write_streams.pop(placeholder_id)
 
49
  await sse_stream_writer.send({"event": "message", "data": message_data})
50
 
51
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
@@ -67,7 +68,19 @@ async def handle_post(request: Request):
67
  body = await request.body()
68
  message = json.loads(body.decode())
69
  session_id = request.query_params.get("session_id")
70
- write_stream = write_streams.get(session_id) if session_id else None
 
 
 
 
 
 
 
 
 
 
 
 
71
  if message.get("method") == "tools/list" and write_stream:
72
  response = {
73
  "jsonrpc": "2.0",
@@ -82,10 +95,13 @@ async def handle_post(request: Request):
82
  }
83
  response_data = json.dumps(response)
84
  await write_stream.send({"event": "message", "data": response_data})
 
85
  return Response(status_code=202)
 
86
  if not write_stream:
87
  logger.error(f"No write_stream found for session_id: {session_id}")
88
  return Response(status_code=202)
 
89
  await transport.handle_post_message(request.scope, request.receive, request._send)
90
  return Response(status_code=202)
91
  except Exception as e:
 
46
  placeholder_id = f"placeholder_{id(write_stream)}"
47
  if placeholder_id in write_streams:
48
  write_streams[session_id] = write_streams.pop(placeholder_id)
49
+ logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
50
  await sse_stream_writer.send({"event": "message", "data": message_data})
51
 
52
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
 
68
  body = await request.body()
69
  message = json.loads(body.decode())
70
  session_id = request.query_params.get("session_id")
71
+ logger.info(f"Received POST with session_id: {session_id}, message: {message}")
72
+
73
+ # Try to find the write_stream, including checking for placeholder IDs
74
+ write_stream = write_streams.get(session_id)
75
+ if not write_stream:
76
+ for sid, ws in list(write_streams.items()):
77
+ if sid.startswith("placeholder_"):
78
+ write_streams[session_id] = ws
79
+ write_streams.pop(sid)
80
+ write_stream = ws
81
+ logger.info(f"Associated placeholder {sid} with session_id {session_id}")
82
+ break
83
+
84
  if message.get("method") == "tools/list" and write_stream:
85
  response = {
86
  "jsonrpc": "2.0",
 
95
  }
96
  response_data = json.dumps(response)
97
  await write_stream.send({"event": "message", "data": response_data})
98
+ logger.info(f"Sent tools/list response for session {session_id}")
99
  return Response(status_code=202)
100
+
101
  if not write_stream:
102
  logger.error(f"No write_stream found for session_id: {session_id}")
103
  return Response(status_code=202)
104
+
105
  await transport.handle_post_message(request.scope, request.receive, request._send)
106
  return Response(status_code=202)
107
  except Exception as e: