aeonshift commited on
Commit
dbcccf3
·
verified ·
1 Parent(s): 45eef8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -114,9 +114,6 @@ mcp_server.tools = tools # Set tools as attribute for Deep Agent to discover
114
  # Store write streams for each session ID
115
  write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
116
 
117
- # Store session IDs for each connection
118
- session_ids: Dict[str, str] = {}
119
-
120
  # Initialize SseServerTransport
121
  transport = SseServerTransport("/airtable/mcp")
122
 
@@ -127,8 +124,6 @@ async def handle_sse(request: Request):
127
  async def sse_writer():
128
  logger.debug("Starting SSE writer")
129
  async with sse_stream_writer, write_stream_reader:
130
- # Extract session_id dynamically
131
- session_id = None
132
  endpoint_data = "/airtable/mcp?session_id={session_id}"
133
  await sse_stream_writer.send(
134
  {"event": "endpoint", "data": endpoint_data}
@@ -144,25 +139,19 @@ async def handle_sse(request: Request):
144
  }
145
  )
146
  # Extract session_id from the endpoint event
147
- if not session_id and session_message.message.method == "endpoint":
148
- try:
149
- message = json.loads(message_data)
150
- endpoint_url = message.get("data", "")
151
- if "session_id=" in endpoint_url:
152
- session_id = endpoint_url.split("session_id=")[1]
153
- session_ids[id(write_stream)] = session_id
154
- write_streams[session_id] = write_stream
155
- logger.debug(f"Extracted session_id: {session_id}")
156
- except Exception as e:
157
- logger.error(f"Error extracting session_id: {str(e)}")
158
 
159
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
160
  try:
161
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
162
  read_stream, write_stream = streams
163
  write_stream_reader = write_stream # Since streams are MemoryObject streams
164
- # Use a unique identifier for the stream until session_id is extracted
165
- stream_id = id(write_stream)
166
  logger.debug("Running MCP server with streams")
167
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
168
  except Exception as e:
@@ -232,10 +221,15 @@ async def handle_post_message(request: Request):
232
  )
233
  await write_stream.send(session_message)
234
  return Response(status_code=202)
 
 
 
 
235
  await transport.handle_post_message(request.scope, request.receive, request._send)
236
  logger.debug("POST message handled successfully")
237
  except Exception as e:
238
  logger.error(f"Error handling POST message: {str(e)}")
 
239
  return Response(status_code=202)
240
 
241
  # Health check endpoint
 
114
  # Store write streams for each session ID
115
  write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
116
 
 
 
 
117
  # Initialize SseServerTransport
118
  transport = SseServerTransport("/airtable/mcp")
119
 
 
124
  async def sse_writer():
125
  logger.debug("Starting SSE writer")
126
  async with sse_stream_writer, write_stream_reader:
 
 
127
  endpoint_data = "/airtable/mcp?session_id={session_id}"
128
  await sse_stream_writer.send(
129
  {"event": "endpoint", "data": endpoint_data}
 
139
  }
140
  )
141
  # Extract session_id from the endpoint event
142
+ message = json.loads(message_data)
143
+ if message.get("event") == "endpoint":
144
+ endpoint_url = message.get("data", "")
145
+ if "session_id=" in endpoint_url:
146
+ session_id = endpoint_url.split("session_id=")[1]
147
+ write_streams[session_id] = write_stream
148
+ logger.debug(f"Extracted session_id: {session_id}")
 
 
 
 
149
 
150
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
151
  try:
152
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
153
  read_stream, write_stream = streams
154
  write_stream_reader = write_stream # Since streams are MemoryObject streams
 
 
155
  logger.debug("Running MCP server with streams")
156
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
157
  except Exception as e:
 
221
  )
222
  await write_stream.send(session_message)
223
  return Response(status_code=202)
224
+ # If write_stream is None, log and handle gracefully
225
+ if not write_stream:
226
+ logger.error(f"No write_stream found for session_id: {session_id}")
227
+ return Response(status_code=202)
228
  await transport.handle_post_message(request.scope, request.receive, request._send)
229
  logger.debug("POST message handled successfully")
230
  except Exception as e:
231
  logger.error(f"Error handling POST message: {str(e)}")
232
+ return Response(status_code=202)
233
  return Response(status_code=202)
234
 
235
  # Health check endpoint