aeonshift commited on
Commit
f9046d5
·
verified ·
1 Parent(s): f1bc20d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -27
app.py CHANGED
@@ -126,29 +126,49 @@ async def handle_sse(request: Request):
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
128
  async with sse_stream_writer, write_stream_reader:
 
 
129
  endpoint_data = "/airtable/mcp?session_id={session_id}"
130
  await sse_stream_writer.send(
131
  {"event": "endpoint", "data": endpoint_data}
132
  )
133
  logger.debug(f"Sent endpoint event: {endpoint_data}")
134
- async for event in sse_stream_writer.stream:
135
- # Log the raw SSE event to understand its format
136
- logger.debug(f"Raw SSE event: {event}")
 
 
 
 
 
137
  # Extract session_id from the endpoint event
138
- if event.get("event") == "endpoint":
139
- endpoint_url = event.get("data", "")
140
  if "session_id=" in endpoint_url:
141
  session_id = endpoint_url.split("session_id=")[1]
142
  write_streams[session_id] = write_stream
143
  logger.debug(f"Extracted session_id: {session_id}")
144
- # Forward the event to the client
145
- await sse_stream_writer.send(event)
 
 
 
 
146
 
147
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
148
  try:
149
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
150
  read_stream, write_stream = streams
151
  write_stream_reader = write_stream # Since streams are MemoryObject streams
 
 
 
 
 
 
 
 
 
152
  logger.debug("Running MCP server with streams")
153
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
154
  except Exception as e:
@@ -195,7 +215,6 @@ async def handle_post_message(request: Request):
195
  }
196
  }
197
  logger.debug(f"Manual initialize response: {response}")
198
- # Send the response directly as an SSE event
199
  response_data = json.dumps(response)
200
  await write_stream.send({
201
  "event": "message",
@@ -204,31 +223,43 @@ async def handle_post_message(request: Request):
204
  return Response(status_code=202)
205
  if message.get("method") == "tools/list":
206
  logger.debug("Handling tools/list request manually")
207
- # If write_stream is not found, try to find it by iterating over write_streams
208
- if not write_stream and session_id:
209
- logger.warning(f"Session ID {session_id} not found in write_streams, attempting to proceed")
210
- response = {
211
- "jsonrpc": "2.0",
212
- "id": message.get("id"),
213
- "result": {
214
- "tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools],
215
- "nextCursor": None
216
- }
217
  }
218
- logger.debug(f"Manual tools/list response: {response}")
219
- response_data = json.dumps(response)
220
- # Send to all active write_streams (temporary workaround)
221
- for sid, ws in list(write_streams.items()):
222
- try:
 
 
 
223
  await ws.send({
224
  "event": "message",
225
  "data": response_data
226
  })
227
  logger.debug(f"Sent tools/list response to session {sid}")
228
- except Exception as e:
229
- logger.error(f"Error sending to session {sid}: {str(e)}")
230
- write_streams.pop(sid, None) # Remove closed streams
231
- return Response(status_code=202)
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  # If write_stream is None, log and handle gracefully
233
  if not write_stream:
234
  logger.error(f"No write_stream found for session_id: {session_id}")
 
126
  nonlocal session_id
127
  logger.debug("Starting SSE writer")
128
  async with sse_stream_writer, write_stream_reader:
129
+ # Since we can't iterate over sse_stream_writer.stream directly in this context,
130
+ # we'll rely on the session_id being set by SseServerTransport
131
  endpoint_data = "/airtable/mcp?session_id={session_id}"
132
  await sse_stream_writer.send(
133
  {"event": "endpoint", "data": endpoint_data}
134
  )
135
  logger.debug(f"Sent endpoint event: {endpoint_data}")
136
+ # The session_id is set by SseServerTransport and should be available after connect_sse
137
+ # We need to extract it from the endpoint_data after substitution
138
+ # SseServerTransport substitutes {session_id} with the actual session_id
139
+ # We'll extract it from the logs or directly from transport if possible
140
+ async for session_message in write_stream_reader:
141
+ logger.debug(f"Sending message via SSE: {session_message}")
142
+ message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
143
+ message = json.loads(message_data)
144
  # Extract session_id from the endpoint event
145
+ if not session_id and message.get("event") == "endpoint":
146
+ endpoint_url = message.get("data", "")
147
  if "session_id=" in endpoint_url:
148
  session_id = endpoint_url.split("session_id=")[1]
149
  write_streams[session_id] = write_stream
150
  logger.debug(f"Extracted session_id: {session_id}")
151
+ await sse_stream_writer.send(
152
+ {
153
+ "event": "message",
154
+ "data": message_data
155
+ }
156
+ )
157
 
158
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
159
  try:
160
  async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
161
  read_stream, write_stream = streams
162
  write_stream_reader = write_stream # Since streams are MemoryObject streams
163
+ # Extract session_id directly from transport's internal state
164
+ # SseServerTransport creates a session_id during connect_sse
165
+ # We can access it by inspecting transport's state
166
+ # Unfortunately, transport's internals are not directly accessible
167
+ # We'll use a temporary workaround by storing the write_stream with a placeholder
168
+ # and update it once we extract the session_id
169
+ placeholder_id = str(id(write_stream))
170
+ write_streams[placeholder_id] = write_stream
171
+ logger.debug(f"Stored write_stream with placeholder_id: {placeholder_id}")
172
  logger.debug("Running MCP server with streams")
173
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
174
  except Exception as e:
 
215
  }
216
  }
217
  logger.debug(f"Manual initialize response: {response}")
 
218
  response_data = json.dumps(response)
219
  await write_stream.send({
220
  "event": "message",
 
223
  return Response(status_code=202)
224
  if message.get("method") == "tools/list":
225
  logger.debug("Handling tools/list request manually")
226
+ response = {
227
+ "jsonrpc": "2.0",
228
+ "id": message.get("id"),
229
+ "result": {
230
+ "tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools],
231
+ "nextCursor": None
 
 
 
 
232
  }
233
+ }
234
+ logger.debug(f"Manual tools/list response: {response}")
235
+ response_data = json.dumps(response)
236
+ # Send to all active write_streams (temporary workaround)
237
+ sent = False
238
+ for sid, ws in list(write_streams.items()):
239
+ try:
240
+ if sid == session_id:
241
  await ws.send({
242
  "event": "message",
243
  "data": response_data
244
  })
245
  logger.debug(f"Sent tools/list response to session {sid}")
246
+ sent = True
247
+ elif sid.startswith("placeholder_"):
248
+ # Update the placeholder with the real session_id
249
+ write_streams[session_id] = ws
250
+ write_streams.pop(sid, None)
251
+ await ws.send({
252
+ "event": "message",
253
+ "data": response_data
254
+ })
255
+ logger.debug(f"Updated placeholder {sid} to session_id {session_id} and sent tools/list response")
256
+ sent = True
257
+ except Exception as e:
258
+ logger.error(f"Error sending to session {sid}: {str(e)}")
259
+ write_streams.pop(sid, None) # Remove closed streams
260
+ if not sent:
261
+ logger.warning(f"Failed to send tools/list response: no active write_streams found")
262
+ return Response(status_code=202)
263
  # If write_stream is None, log and handle gracefully
264
  if not write_stream:
265
  logger.error(f"No write_stream found for session_id: {session_id}")