aeonshift commited on
Commit
6926355
·
verified ·
1 Parent(s): f9046d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -37
app.py CHANGED
@@ -121,22 +121,16 @@ transport = SseServerTransport("/airtable/mcp")
121
  @app.get("/airtable/mcp")
122
  async def handle_sse(request: Request):
123
  logger.debug("Handling SSE connection request")
124
- session_id = None # We'll extract this from the endpoint event
125
  async def sse_writer():
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
128
  async with sse_stream_writer, write_stream_reader:
129
- # Since we can't iterate over sse_stream_writer.stream directly in this context,
130
- # we'll rely on the session_id being set by SseServerTransport
131
  endpoint_data = "/airtable/mcp?session_id={session_id}"
132
  await sse_stream_writer.send(
133
  {"event": "endpoint", "data": endpoint_data}
134
  )
135
  logger.debug(f"Sent endpoint event: {endpoint_data}")
136
- # The session_id is set by SseServerTransport and should be available after connect_sse
137
- # We need to extract it from the endpoint_data after substitution
138
- # SseServerTransport substitutes {session_id} with the actual session_id
139
- # We'll extract it from the logs or directly from transport if possible
140
  async for session_message in write_stream_reader:
141
  logger.debug(f"Sending message via SSE: {session_message}")
142
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
@@ -146,8 +140,11 @@ async def handle_sse(request: Request):
146
  endpoint_url = message.get("data", "")
147
  if "session_id=" in endpoint_url:
148
  session_id = endpoint_url.split("session_id=")[1]
149
- write_streams[session_id] = write_stream
150
- logger.debug(f"Extracted session_id: {session_id}")
 
 
 
151
  await sse_stream_writer.send(
152
  {
153
  "event": "message",
@@ -160,19 +157,19 @@ async def handle_sse(request: Request):
160
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
161
  read_stream, write_stream = streams
162
  write_stream_reader = write_stream # Since streams are MemoryObject streams
163
- # Extract session_id directly from transport's internal state
164
- # SseServerTransport creates a session_id during connect_sse
165
- # We can access it by inspecting transport's state
166
- # Unfortunately, transport's internals are not directly accessible
167
- # We'll use a temporary workaround by storing the write_stream with a placeholder
168
- # and update it once we extract the session_id
169
- placeholder_id = str(id(write_stream))
170
  write_streams[placeholder_id] = write_stream
171
  logger.debug(f"Stored write_stream with placeholder_id: {placeholder_id}")
172
  logger.debug("Running MCP server with streams")
173
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
174
  except Exception as e:
175
  logger.error(f"Error in handle_sse: {str(e)}")
 
 
 
 
 
176
  raise
177
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
178
 
@@ -233,30 +230,37 @@ async def handle_post_message(request: Request):
233
  }
234
  logger.debug(f"Manual tools/list response: {response}")
235
  response_data = json.dumps(response)
236
- # Send to all active write_streams (temporary workaround)
237
  sent = False
238
- for sid, ws in list(write_streams.items()):
 
 
239
  try:
240
- if sid == session_id:
241
- await ws.send({
242
- "event": "message",
243
- "data": response_data
244
- })
245
- logger.debug(f"Sent tools/list response to session {sid}")
246
- sent = True
247
- elif sid.startswith("placeholder_"):
248
- # Update the placeholder with the real session_id
249
- write_streams[session_id] = ws
250
- write_streams.pop(sid, None)
251
- await ws.send({
252
- "event": "message",
253
- "data": response_data
254
- })
255
- logger.debug(f"Updated placeholder {sid} to session_id {session_id} and sent tools/list response")
256
- sent = True
257
  except Exception as e:
258
- logger.error(f"Error sending to session {sid}: {str(e)}")
259
- write_streams.pop(sid, None) # Remove closed streams
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  if not sent:
261
  logger.warning(f"Failed to send tools/list response: no active write_streams found")
262
  return Response(status_code=202)
 
121
  @app.get("/airtable/mcp")
122
  async def handle_sse(request: Request):
123
  logger.debug("Handling SSE connection request")
124
+ session_id = None # We'll extract this later
125
  async def sse_writer():
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
128
  async with sse_stream_writer, write_stream_reader:
 
 
129
  endpoint_data = "/airtable/mcp?session_id={session_id}"
130
  await sse_stream_writer.send(
131
  {"event": "endpoint", "data": endpoint_data}
132
  )
133
  logger.debug(f"Sent endpoint event: {endpoint_data}")
 
 
 
 
134
  async for session_message in write_stream_reader:
135
  logger.debug(f"Sending message via SSE: {session_message}")
136
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
 
140
  endpoint_url = message.get("data", "")
141
  if "session_id=" in endpoint_url:
142
  session_id = endpoint_url.split("session_id=")[1]
143
+ # Update write_streams with the real session_id
144
+ placeholder_id = f"placeholder_{id(write_stream)}"
145
+ if placeholder_id in write_streams:
146
+ write_streams[session_id] = write_streams.pop(placeholder_id)
147
+ logger.debug(f"Updated placeholder {placeholder_id} to session_id {session_id}")
148
  await sse_stream_writer.send(
149
  {
150
  "event": "message",
 
157
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
158
  read_stream, write_stream = streams
159
  write_stream_reader = write_stream # Since streams are MemoryObject streams
160
+ # Store the write_stream with a placeholder ID
161
+ placeholder_id = f"placeholder_{id(write_stream)}"
 
 
 
 
 
162
  write_streams[placeholder_id] = write_stream
163
  logger.debug(f"Stored write_stream with placeholder_id: {placeholder_id}")
164
  logger.debug("Running MCP server with streams")
165
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
166
  except Exception as e:
167
  logger.error(f"Error in handle_sse: {str(e)}")
168
+ # Clean up write_streams on error
169
+ placeholder_id = f"placeholder_{id(write_stream)}"
170
+ write_streams.pop(placeholder_id, None)
171
+ if session_id:
172
+ write_streams.pop(session_id, None)
173
  raise
174
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
175
 
 
230
  }
231
  logger.debug(f"Manual tools/list response: {response}")
232
  response_data = json.dumps(response)
 
233
  sent = False
234
+ # First, try the session_id directly
235
+ if session_id in write_streams:
236
+ write_stream = write_streams[session_id]
237
  try:
238
+ await write_stream.send({
239
+ "event": "message",
240
+ "data": response_data
241
+ })
242
+ logger.debug(f"Sent tools/list response to session {session_id}")
243
+ sent = True
 
 
 
 
 
 
 
 
 
 
 
244
  except Exception as e:
245
+ logger.error(f"Error sending to session {session_id}: {str(e)}")
246
+ write_streams.pop(session_id, None)
247
+ # If not found, look for a placeholder ID and update it
248
+ if not sent:
249
+ for sid, ws in list(write_streams.items()):
250
+ if sid.startswith("placeholder_"):
251
+ try:
252
+ write_streams[session_id] = ws
253
+ write_streams.pop(sid, None)
254
+ await ws.send({
255
+ "event": "message",
256
+ "data": response_data
257
+ })
258
+ logger.debug(f"Updated placeholder {sid} to session_id {session_id} and sent tools/list response")
259
+ sent = True
260
+ break
261
+ except Exception as e:
262
+ logger.error(f"Error sending to placeholder {sid}: {str(e)}")
263
+ write_streams.pop(sid, None)
264
  if not sent:
265
  logger.warning(f"Failed to send tools/list response: no active write_streams found")
266
  return Response(status_code=202)