aeonshift commited on
Commit
bda84dd
·
verified ·
1 Parent(s): 6a9c2dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -12,6 +12,7 @@ from sse_starlette import EventSourceResponse
12
  import anyio
13
  import asyncio
14
  import logging
 
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.DEBUG)
@@ -110,22 +111,29 @@ mcp_server = Server(name="airtable-mcp")
110
  mcp_server.tool_handlers = tool_handlers # Set as attribute
111
  mcp_server.tools = tools # Set tools as attribute for Deep Agent to discover
112
 
113
- # Store write_stream globally to access in POST handler
114
- write_stream_global = None
 
 
 
115
 
116
  # SSE endpoint for GET requests
117
  @app.get("/airtable/mcp")
118
  async def handle_sse(request: Request):
119
- global write_stream_global
120
  logger.debug("Handling SSE connection request")
121
  async def sse_writer():
122
  logger.debug("Starting SSE writer")
123
  async with sse_stream_writer, write_stream_reader:
 
 
124
  endpoint_data = "/airtable/mcp?session_id={session_id}"
125
  await sse_stream_writer.send(
126
  {"event": "endpoint", "data": endpoint_data}
127
  )
128
  logger.debug(f"Sent endpoint event: {endpoint_data}")
 
 
 
129
  async for session_message in write_stream_reader:
130
  logger.debug(f"Sending message via SSE: {session_message}")
131
  await sse_stream_writer.send(
@@ -138,12 +146,19 @@ async def handle_sse(request: Request):
138
  )
139
 
140
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
141
- async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
142
- read_stream, write_stream = streams
143
- write_stream_reader = write_stream # Since streams are MemoryObject streams
144
- write_stream_global = write_stream # Store write_stream for POST handler
145
- logger.debug("Running MCP server with streams")
146
- await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options(), tools=tools)
 
 
 
 
 
 
 
147
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
148
 
149
  # Message handling endpoint for POST requests
@@ -154,8 +169,9 @@ async def handle_post_message(request: Request):
154
  logger.debug(f"Received POST message body: {body}")
155
  try:
156
  message = json.loads(body.decode())
157
- # Handle initialize request manually to ensure capabilities include tools
158
- if message.get("method") == "initialize" and write_stream_global:
 
159
  logger.debug("Handling initialize request manually")
160
  response = {
161
  "jsonrpc": "2.0",
@@ -179,10 +195,9 @@ async def handle_post_message(request: Request):
179
  message=mcp_types.JSONRPCResponse(**response),
180
  metadata=mcp_types.ServerMessageMetadata(request_context=request)
181
  )
182
- await write_stream_global.send(session_message)
183
  return Response(status_code=202)
184
- # Handle tools/list request manually
185
- if message.get("method") == "tools/list" and write_stream_global:
186
  logger.debug("Handling tools/list request manually")
187
  response = {
188
  "jsonrpc": "2.0",
@@ -196,7 +211,7 @@ async def handle_post_message(request: Request):
196
  message=mcp_types.JSONRPCResponse(**response),
197
  metadata=mcp_types.ServerMessageMetadata(request_context=request)
198
  )
199
- await write_stream_global.send(session_message)
200
  return Response(status_code=202)
201
  await transport.handle_post_message(request.scope, request.receive, request._send)
202
  logger.debug("POST message handled successfully")
 
12
  import anyio
13
  import asyncio
14
  import logging
15
+ from typing import Dict
16
 
17
  # Set up logging
18
  logging.basicConfig(level=logging.DEBUG)
 
111
  mcp_server.tool_handlers = tool_handlers # Set as attribute
112
  mcp_server.tools = tools # Set tools as attribute for Deep Agent to discover
113
 
114
+ # Initialize SseServerTransport
115
+ transport = SseServerTransport("/airtable/mcp")
116
+
117
+ # Store write streams for each session ID
118
+ write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
119
 
120
  # SSE endpoint for GET requests
121
  @app.get("/airtable/mcp")
122
  async def handle_sse(request: Request):
 
123
  logger.debug("Handling SSE connection request")
124
  async def sse_writer():
125
  logger.debug("Starting SSE writer")
126
  async with sse_stream_writer, write_stream_reader:
127
+ # Extract session_id from endpoint_data
128
+ session_id = None
129
  endpoint_data = "/airtable/mcp?session_id={session_id}"
130
  await sse_stream_writer.send(
131
  {"event": "endpoint", "data": endpoint_data}
132
  )
133
  logger.debug(f"Sent endpoint event: {endpoint_data}")
134
+ # Extract session_id from the sent message (simplified)
135
+ # In a real implementation, we'd parse the session_id from transport
136
+ # For now, we'll set it after the first message
137
  async for session_message in write_stream_reader:
138
  logger.debug(f"Sending message via SSE: {session_message}")
139
  await sse_stream_writer.send(
 
146
  )
147
 
148
  sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
149
+ try:
150
+ async with transport.connect_sse(request.scope, request.receive, request._send) as streams:
151
+ read_stream, write_stream = streams
152
+ write_stream_reader = write_stream # Since streams are MemoryObject streams
153
+ # Extract session_id (this is a simplification; ideally, we'd get it from transport)
154
+ session_id = request.query_params.get("session_id", "unknown")
155
+ write_streams[session_id] = write_stream # Store write_stream for this session
156
+ logger.debug(f"Stored write_stream for session_id: {session_id}")
157
+ logger.debug("Running MCP server with streams")
158
+ await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
159
+ except Exception as e:
160
+ logger.error(f"Error in handle_sse: {str(e)}")
161
+ raise
162
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
163
 
164
  # Message handling endpoint for POST requests
 
169
  logger.debug(f"Received POST message body: {body}")
170
  try:
171
  message = json.loads(body.decode())
172
+ session_id = request.query_params.get("session_id")
173
+ write_stream = write_streams.get(session_id) if session_id else None
174
+ if message.get("method") == "initialize" and write_stream:
175
  logger.debug("Handling initialize request manually")
176
  response = {
177
  "jsonrpc": "2.0",
 
195
  message=mcp_types.JSONRPCResponse(**response),
196
  metadata=mcp_types.ServerMessageMetadata(request_context=request)
197
  )
198
+ await write_stream.send(session_message)
199
  return Response(status_code=202)
200
+ if message.get("method") == "tools/list" and write_stream:
 
201
  logger.debug("Handling tools/list request manually")
202
  response = {
203
  "jsonrpc": "2.0",
 
211
  message=mcp_types.JSONRPCResponse(**response),
212
  metadata=mcp_types.ServerMessageMetadata(request_context=request)
213
  )
214
+ await write_stream.send(session_message)
215
  return Response(status_code=202)
216
  await transport.handle_post_message(request.scope, request.receive, request._send)
217
  logger.debug("POST message handled successfully")