aeonshift commited on
Commit
efb6a3d
·
verified ·
1 Parent(s): d0b7a26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -121,7 +121,7 @@ transport = SseServerTransport("/airtable/mcp")
121
  @app.get("/airtable/mcp")
122
  async def handle_sse(request: Request):
123
  logger.debug("Handling SSE connection request")
124
- session_id = None # We'll extract this from the endpoint event
125
  async def sse_writer():
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
@@ -136,7 +136,7 @@ async def handle_sse(request: Request):
136
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
137
  message = json.loads(message_data)
138
  # Extract session_id from the endpoint event
139
- if message.get("event") == "endpoint":
140
  endpoint_url = message.get("data", "")
141
  if "session_id=" in endpoint_url:
142
  session_id = endpoint_url.split("session_id=")[1]
@@ -154,6 +154,12 @@ async def handle_sse(request: Request):
154
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
155
  read_stream, write_stream = streams
156
  write_stream_reader = write_stream # Since streams are MemoryObject streams
 
 
 
 
 
 
157
  logger.debug("Running MCP server with streams")
158
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
159
  except Exception as e:
@@ -206,23 +212,36 @@ async def handle_post_message(request: Request):
206
  )
207
  await write_stream.send(session_message)
208
  return Response(status_code=202)
209
- if message.get("method") == "tools/list" and write_stream:
210
  logger.debug("Handling tools/list request manually")
211
- response = {
212
- "jsonrpc": "2.0",
213
- "id": message.get("id"),
214
- "result": {
215
- "tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools],
216
- "nextCursor": None
 
 
 
 
 
 
 
217
  }
218
- }
219
- logger.debug(f"Manual tools/list response: {response}")
220
- session_message = mcp_types.SessionMessage(
221
- message=mcp_types.JSONRPCResponse(**response),
222
- metadata=mcp_types.ServerMessageMetadata(request_context=request)
223
- )
224
- await write_stream.send(session_message)
225
- return Response(status_code=202)
 
 
 
 
 
 
226
  # If write_stream is None, log and handle gracefully
227
  if not write_stream:
228
  logger.error(f"No write_stream found for session_id: {session_id}")
 
121
  @app.get("/airtable/mcp")
122
  async def handle_sse(request: Request):
123
  logger.debug("Handling SSE connection request")
124
+ session_id = None # We'll extract this from transport
125
  async def sse_writer():
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
 
136
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
137
  message = json.loads(message_data)
138
  # Extract session_id from the endpoint event
139
+ if not session_id and message.get("event") == "endpoint":
140
  endpoint_url = message.get("data", "")
141
  if "session_id=" in endpoint_url:
142
  session_id = endpoint_url.split("session_id=")[1]
 
154
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
155
  read_stream, write_stream = streams
156
  write_stream_reader = write_stream # Since streams are MemoryObject streams
157
+ # Directly extract session_id from transport
158
+ session_id = None
159
+ # Access transport's internal session_id (this is a simplification; we need to match the session_id)
160
+ # SseServerTransport sets the session_id during connect_sse
161
+ # We can get it from the endpoint event or transport's internal state
162
+ # For now, we'll rely on the sse_writer to extract it
163
  logger.debug("Running MCP server with streams")
164
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
165
  except Exception as e:
 
212
  )
213
  await write_stream.send(session_message)
214
  return Response(status_code=202)
215
+ if message.get("method") == "tools/list":
216
  logger.debug("Handling tools/list request manually")
217
+ # If write_stream is not found, try to find it by iterating over write_streams
218
+ if not write_stream and session_id:
219
+ # Since we're not extracting session_id correctly, we'll bypass for now
220
+ # In a real scenario, we'd need to ensure session_id is set
221
+ logger.warning(f"Session ID {session_id} not found in write_streams, attempting to proceed")
222
+ # Send the response to all active write_streams (temporary workaround)
223
+ response = {
224
+ "jsonrpc": "2.0",
225
+ "id": message.get("id"),
226
+ "result": {
227
+ "tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools],
228
+ "nextCursor": None
229
+ }
230
  }
231
+ logger.debug(f"Manual tools/list response: {response}")
232
+ session_message = mcp_types.SessionMessage(
233
+ message=mcp_types.JSONRPCResponse(**response),
234
+ metadata=mcp_types.ServerMessageMetadata(request_context=request)
235
+ )
236
+ # Send to all active write_streams (temporary workaround)
237
+ for sid, ws in list(write_streams.items()):
238
+ try:
239
+ await ws.send(session_message)
240
+ logger.debug(f"Sent tools/list response to session {sid}")
241
+ except Exception as e:
242
+ logger.error(f"Error sending to session {sid}: {str(e)}")
243
+ write_streams.pop(sid, None) # Remove closed streams
244
+ return Response(status_code=202)
245
  # If write_stream is None, log and handle gracefully
246
  if not write_stream:
247
  logger.error(f"No write_stream found for session_id: {session_id}")