aeonshift committed
Commit fef2e57 · verified · Parent: 6d033fe

Update server.py

Files changed (1)
  1. server.py +13 -26
server.py CHANGED
@@ -7,7 +7,6 @@ from fastapi import FastAPI, Request
 from fastapi.responses import Response
 from sse_starlette import EventSourceResponse
 from mcp.server.lowlevel import Server
-from mcp.server.sse import SseServerTransport
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -17,7 +16,6 @@ app = FastAPI()
 
 # Define the MCP server
 server = Server(name="airtable-mcp")
-transport = SseServerTransport("/airtable/mcp")
 
 # Store write streams for each session ID
 write_streams = {}
@@ -34,39 +32,33 @@ async def handle_sse(request: Request):
         session_id = None
         async def sse_writer():
             nonlocal session_id
-            async with sse_stream_writer, write_stream_reader:
+            async with sse_stream_writer, read_stream_reader:
                 endpoint_data = f"/airtable/mcp?session_id={{session_id}}"
                 await sse_stream_writer.send({"event": "endpoint", "data": endpoint_data})
-                async for session_message in write_stream_reader:
-                    if isinstance(session_message, dict):
-                        event = session_message.get("event", "message")
-                        data = session_message.get("data", json.dumps(session_message))
-                    else:
-                        event = "message"
-                        data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
-                    message = json.loads(data) if isinstance(data, str) else data
-                    if not session_id and event == "endpoint":
-                        endpoint_url = data
+                async for message in read_stream_reader:
+                    message_data = json.loads(message) if isinstance(message, str) else message
+                    if message_data.get("event") == "endpoint":
+                        endpoint_url = message_data.get("data", "")
                         if "session_id=" in endpoint_url:
                             session_id = endpoint_url.split("session_id=")[1]
                             placeholder_id = f"placeholder_{id(write_stream)}"
                             if placeholder_id in write_streams:
                                 write_streams[session_id] = write_streams.pop(placeholder_id)
                                 logger.info(f"Updated placeholder {placeholder_id} to session_id {session_id}")
-                    await sse_stream_writer.send({"event": event, "data": data})
+                    await sse_stream_writer.send(message_data)
                 # Keep-alive loop to maintain the SSE connection
                 while True:
                     await sse_stream_writer.send({"event": "ping", "data": "keep-alive"})
                     await asyncio.sleep(15)  # Send keep-alive every 15 seconds
 
         sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(0)
-        async with transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
-            write_stream_reader = write_stream
-            placeholder_id = f"placeholder_{id(write_stream)}"
-            write_streams[placeholder_id] = write_stream
-            logger.info("Starting MCP server with streams")
-            await server.run(read_stream, write_stream, server.create_initialization_options())
-            logger.info("MCP server running")
+        read_stream, write_stream = anyio.create_memory_object_stream(0)
+        placeholder_id = f"placeholder_{id(write_stream)}"
+        write_streams[placeholder_id] = write_stream
+        read_stream_reader = read_stream
+        logger.info("Starting MCP server with streams")
+        await server.run(read_stream, write_stream, server.create_initialization_options())
+        logger.info("MCP server running")
         return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
     except Exception as e:
         logger.error(f"Error in handle_sse: {str(e)}")
@@ -141,11 +133,6 @@ async def handle_post(request: Request):
             logger.error(f"No write_stream found for session_id: {session_id}")
             return Response(status_code=202)
 
-        try:
-            await transport.handle_post_message(request.scope, request.receive, request._send)
-            logger.info(f"Handled post message for session {session_id} via transport")
-        except Exception as e:
-            logger.error(f"Failed to handle post message via transport for session {session_id}: {str(e)}")
         return Response(status_code=202)
     except Exception as e:
         logger.error(f"Error handling POST message: {str(e)}")