aeonshift commited on
Commit
e47e562
·
verified ·
1 Parent(s): 170adce

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +10 -7
server.py CHANGED
@@ -18,8 +18,9 @@ app = FastAPI()
18
  server = Server(name="airtable-mcp")
19
  transport = SseServerTransport("/airtable/mcp")
20
 
21
- # Store write streams for each session ID
22
  write_streams = {}
 
23
 
24
  # Configure environment variables (for logging purposes)
25
  token = os.getenv("AIRTABLE_API_TOKEN")
@@ -37,7 +38,6 @@ async def handle_sse(request: Request):
37
  endpoint_data = f"/airtable/mcp?session_id={{session_id}}"
38
  await sse_stream_writer.send({"event": "endpoint", "data": endpoint_data})
39
  async for session_message in write_stream_reader:
40
- # Handle both dict and SessionMessage objects
41
  if isinstance(session_message, dict):
42
  event = session_message.get("event", "message")
43
  data = session_message.get("data", json.dumps(session_message))
@@ -52,6 +52,7 @@ async def handle_sse(request: Request):
52
  placeholder_id = f"placeholder_{id(write_stream)}"
53
  if placeholder_id in write_streams:
54
  write_streams[session_id] = write_streams.pop(placeholder_id)
 
55
  logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
56
  await sse_stream_writer.send({"event": event, "data": data})
57
 
@@ -78,16 +79,18 @@ async def handle_post(request: Request):
78
 
79
  # Try to find the write_stream, including checking for placeholder IDs
80
  write_stream = write_streams.get(session_id)
 
81
  if not write_stream:
82
  for sid, ws in list(write_streams.items()):
83
  if sid.startswith("placeholder_"):
84
  write_streams[session_id] = ws
85
  write_streams.pop(sid)
86
  write_stream = ws
 
87
  logger.info(f"Associated placeholder {sid} with session_id {session_id}")
88
  break
89
 
90
- if message.get("method") == "tools/list" and write_stream:
91
  response = {
92
  "jsonrpc": "2.0",
93
  "id": message.get("id"),
@@ -100,12 +103,12 @@ async def handle_post(request: Request):
100
  }
101
  }
102
  response_data = json.dumps(response)
103
- await write_stream.send({"event": "message", "data": response_data})
104
- logger.info(f"Sent tools/list response for session {session_id}")
105
  return Response(status_code=202)
106
 
107
- if not write_stream:
108
- logger.error(f"No write_stream found for session_id: {session_id}")
109
  return Response(status_code=202)
110
 
111
  await transport.handle_post_message(request.scope, request.receive, request._send)
 
18
  server = Server(name="airtable-mcp")
19
  transport = SseServerTransport("/airtable/mcp")
20
 
21
+ # Store write streams and SSE writers for each session ID
22
  write_streams = {}
23
+ sse_writers = {}
24
 
25
  # Configure environment variables (for logging purposes)
26
  token = os.getenv("AIRTABLE_API_TOKEN")
 
38
  endpoint_data = f"/airtable/mcp?session_id={{session_id}}"
39
  await sse_stream_writer.send({"event": "endpoint", "data": endpoint_data})
40
  async for session_message in write_stream_reader:
 
41
  if isinstance(session_message, dict):
42
  event = session_message.get("event", "message")
43
  data = session_message.get("data", json.dumps(session_message))
 
52
  placeholder_id = f"placeholder_{id(write_stream)}"
53
  if placeholder_id in write_streams:
54
  write_streams[session_id] = write_streams.pop(placeholder_id)
55
+ sse_writers[session_id] = sse_stream_writer
56
  logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
57
  await sse_stream_writer.send({"event": event, "data": data})
58
 
 
79
 
80
  # Try to find the write_stream, including checking for placeholder IDs
81
  write_stream = write_streams.get(session_id)
82
+ sse_writer = sse_writers.get(session_id)
83
  if not write_stream:
84
  for sid, ws in list(write_streams.items()):
85
  if sid.startswith("placeholder_"):
86
  write_streams[session_id] = ws
87
  write_streams.pop(sid)
88
  write_stream = ws
89
+ sse_writers[session_id] = sse_writers.get(sid)
90
  logger.info(f"Associated placeholder {sid} with session_id {session_id}")
91
  break
92
 
93
+ if message.get("method") == "tools/list" and write_stream and sse_writer:
94
  response = {
95
  "jsonrpc": "2.0",
96
  "id": message.get("id"),
 
103
  }
104
  }
105
  response_data = json.dumps(response)
106
+ await sse_writer.send({"event": "message", "data": response_data})
107
+ logger.info(f"Sent tools/list response for session {session_id} via SSE")
108
  return Response(status_code=202)
109
 
110
+ if not write_stream or not sse_writer:
111
+ logger.error(f"No write_stream or SSE writer found for session_id: {session_id}")
112
  return Response(status_code=202)
113
 
114
  await transport.handle_post_message(request.scope, request.receive, request._send)