#!/usr/bin/env python3
"""HTTP front-end that executes code in a Jupyter kernel.

The server exposes a single ``POST /execute`` endpoint whose JSON body must
contain a ``code`` field. The code is sent to a kernel started through the
Jupyter server REST API (``/api/kernels``), and the collected output is
returned as JSON of the form ``{"text": str, "images": list[str]}``.
"""

import asyncio
import logging
import os
import re
from uuid import uuid4

import tornado
import tornado.ioloop
import tornado.iostream
import tornado.web
import tornado.websocket
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tornado.escape import json_decode, json_encode, url_escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.ioloop import PeriodicCallback
from tornado.websocket import websocket_connect

logging.basicConfig(level=logging.INFO)

# ECMA-48 control sequence (CSI): introducer (ESC '[' or the 8-bit C1 CSI),
# then parameter bytes 0x30-0x3F, intermediate bytes 0x20-0x2F and one final
# byte 0x40-0x7E. This covers SGR colour codes with any number of parameters
# (including 24-bit colours) as well as cursor-movement / erase sequences.
_ANSI_CSI_PATTERN = re.compile(r'(?:\x1B\[|\x9B)[0-?]*[ -/]*[@-~]')


def strip_ansi(o: str) -> str:
    """Remove ANSI (ECMA-48 CSI) escape sequences from `o`.

    Control-sequence syntax is defined in ECMA-48, section 5.4:
    http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf
    Adapted from
    https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py

    >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
    'Lorem ipsum'
    >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
    'Lorem Ipsum sit\\namet.'
    >>> strip_ansi("")
    ''
    >>> strip_ansi("\\x1b[0m")
    ''
    >>> strip_ansi("Lorem")
    'Lorem'
    >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
    'Lorem ipsum'
    >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
    'Lorem dolor sit ipsum'
    >>> strip_ansi('\\x1b[38;2;255;0;0mLorem\\x1b[m')
    'Lorem'
    >>> strip_ansi('Lorem\\x1b[2K\\x1b[1A ipsum')
    'Lorem ipsum'
    """
    return _ANSI_CSI_PATTERN.sub('', o)


class JupyterKernel:
    """Client for one kernel on a Jupyter server.

    The kernel is started via the REST API (``POST /api/kernels``) and driven
    over its ``/api/kernels/<id>/channels`` websocket. A periodic ping keeps
    the websocket alive and reconnects it when it is found closed.
    """

    def __init__(self, url_suffix: str, convid: str, lang: str = 'python') -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            url_suffix: ``host:port`` of the Jupyter server; ``http://`` and
                ``ws://`` are prepended for REST and websocket calls.
            convid: conversation identifier (only used for logging here).
            lang: kernel name requested when a new kernel is started.
        """
        self.base_url = f'http://{url_suffix}'
        self.base_ws_url = f'ws://{url_suffix}'
        self.lang = lang
        self.kernel_id: str | None = None
        self.ws: tornado.websocket.WebSocketClientConnection | None = None
        self.convid = convid
        logging.info(
            f'Jupyter kernel created for conversation {convid} at {url_suffix}'
        )

        self.heartbeat_interval = 10000  # milliseconds (10 seconds)
        self.heartbeat_callback: PeriodicCallback | None = None
        self.initialized = False

    async def initialize(self) -> None:
        """Disable IPython colouring and run any pre-defined tool snippets."""
        await self.execute(r'%colors nocolor')
        # pre-defined tools
        self.tools_to_run: list[str] = [
            # TODO: You can add code for your pre-defined tools here
        ]
        for tool in self.tools_to_run:
            res = await self.execute(tool)
            logging.info(f'Tool [{tool}] initialized:\n{res}')
        self.initialized = True

    async def _send_heartbeat(self) -> None:
        """Ping the websocket, reconnecting if the connection is found closed."""
        if not self.ws:
            return
        try:
            self.ws.ping()
        except (
            tornado.iostream.StreamClosedError,
            tornado.websocket.WebSocketClosedError,
        ):
            # The socket dropped since the last heartbeat: re-establish it.
            try:
                await self._connect()
            except ConnectionRefusedError:
                logging.info(
                    'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
                )

    async def _connect(self) -> None:
        """(Re)open the kernel websocket, starting a kernel first if needed.

        Also (re)starts the heartbeat callback.

        Raises:
            ConnectionRefusedError: if no kernel could be started after
                5 attempts (chained to the last underlying error).
        """
        if self.ws:
            self.ws.close()
            self.ws = None

        client = AsyncHTTPClient()
        if not self.kernel_id:
            n_tries = 5
            last_error: Exception | None = None
            for attempt in range(n_tries):
                try:
                    response = await client.fetch(
                        f'{self.base_url}/api/kernels',
                        method='POST',
                        body=json_encode({'name': self.lang}),
                    )
                    kernel = json_decode(response.body)
                    self.kernel_id = kernel['id']
                    break
                except Exception as e:
                    # kernels are not ready yet; wait before the next attempt
                    last_error = e
                    if attempt < n_tries - 1:
                        await asyncio.sleep(1)
            else:
                raise ConnectionRefusedError(
                    'Failed to connect to kernel'
                ) from last_error

        ws_req = HTTPRequest(
            url=f'{self.base_ws_url}/api/kernels/{url_escape(self.kernel_id)}/channels'
        )
        self.ws = await websocket_connect(ws_req)
        logging.info('Connected to kernel websocket')

        # Setup heartbeat
        if self.heartbeat_callback:
            self.heartbeat_callback.stop()
        self.heartbeat_callback = PeriodicCallback(
            self._send_heartbeat, self.heartbeat_interval
        )
        self.heartbeat_callback.start()

    @retry(
        retry=retry_if_exception_type(ConnectionRefusedError),
        stop=stop_after_attempt(3),
        wait=wait_fixed(2),
    )  # type: ignore
    async def execute(
        self, code: str, timeout: int = 120
    ) -> dict[str, list[str] | str]:
        """Run `code` in the kernel and collect its output.

        Args:
            code: source to send in an ``execute_request``.
            timeout: seconds to wait for the execution to finish; on expiry
                the kernel is interrupted.

        Returns:
            ``{'text': str, 'images': list[str]}``. ``text`` is the ANSI-stripped
            concatenation of stream/result/traceback text, and ``images`` holds
            PNG outputs as ``data:image/png;base64,...`` URLs.

        ``ConnectionRefusedError`` from connecting is retried by tenacity
        (3 attempts, 2 s apart) before the code is ever sent.
        """
        if not self.ws or self.ws.stream.closed():
            await self._connect()

        msg_id = uuid4().hex
        assert self.ws is not None
        await self.ws.write_message(
            json_encode(
                {
                    'header': {
                        'username': '',
                        'version': '5.0',
                        'session': '',
                        'msg_id': msg_id,
                        'msg_type': 'execute_request',
                    },
                    'parent_header': {},
                    'channel': 'shell',
                    'content': {
                        'code': code,
                        'silent': False,
                        'store_history': False,
                        'user_expressions': {},
                        'allow_stdin': False,
                    },
                    'metadata': {},
                    'buffers': {},
                }
            )
        )
        logging.info(f'Sent execute_request {msg_id} to jupyter kernel')

        outputs: list[dict] = []

        async def wait_for_messages() -> bool:
            """Collect outputs replying to `msg_id` into `outputs`.

            Returns True once an ``error`` or ``execute_reply`` arrives, or
            False if the websocket closed before that.
            """
            execution_done = False
            while not execution_done:
                assert self.ws is not None
                msg = await self.ws.read_message()
                if msg is None:
                    # read_message() yields None once the connection is closed.
                    # Stop here instead of waiting out the whole timeout.
                    logging.warning(
                        f'Kernel websocket closed while waiting for {msg_id}'
                    )
                    outputs.append(
                        {
                            'type': 'text',
                            'content': '[Connection to kernel closed before execution completed.]',
                        }
                    )
                    return False

                msg_dict = json_decode(msg)
                msg_type = msg_dict['msg_type']
                parent_msg_id = msg_dict['parent_header'].get('msg_id', None)
                if parent_msg_id != msg_id:
                    # Output of another (e.g. earlier, timed-out) request.
                    continue

                if os.environ.get('DEBUG'):
                    logging.info(
                        f'MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict["content"]}'
                    )

                content = msg_dict['content']
                if msg_type == 'error':
                    traceback = '\n'.join(content['traceback'])
                    outputs.append({'type': 'text', 'content': traceback})
                    execution_done = True
                elif msg_type == 'stream':
                    outputs.append({'type': 'text', 'content': content['text']})
                elif msg_type in ['execute_result', 'display_data']:
                    data = content['data']
                    # Rich outputs need not carry a plain-text representation
                    # (e.g. an image-only display_data), so don't assume one.
                    if 'text/plain' in data:
                        outputs.append({'type': 'text', 'content': data['text/plain']})
                    if 'image/png' in data:
                        # Store image data in structured format
                        image_url = f'data:image/png;base64,{data["image/png"]}'
                        outputs.append({'type': 'image', 'content': image_url})
                elif msg_type == 'execute_reply':
                    execution_done = True
            return execution_done

        async def interrupt_kernel() -> None:
            """Ask the server to interrupt the kernel (no-op without a kernel)."""
            client = AsyncHTTPClient()
            if self.kernel_id is None:
                return
            interrupt_response = await client.fetch(
                f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
                method='POST',
                body=json_encode({'kernel_id': self.kernel_id}),
            )
            logging.info(f'Kernel interrupted: {interrupt_response}')

        try:
            execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
        except asyncio.TimeoutError:
            await interrupt_kernel()
            return {'text': f'[Execution timed out ({timeout} seconds).]', 'images': []}

        # Split structured outputs into text and image parts, preserving order.
        text_outputs = [o['content'] for o in outputs if o['type'] == 'text']
        image_outputs = [o['content'] for o in outputs if o['type'] == 'image']

        if not text_outputs and execution_done:
            text_content = '[Code executed successfully with no output]'
        else:
            text_content = ''.join(text_outputs)

        # Remove ANSI from text content
        text_content = strip_ansi(text_content)

        return {'text': text_content, 'images': image_outputs}

    async def shutdown_async(self) -> None:
        """Stop the heartbeat, delete the kernel on the server and close the websocket."""
        if self.heartbeat_callback:
            # Without this, the callback started by _connect keeps firing.
            self.heartbeat_callback.stop()
            self.heartbeat_callback = None
        if self.kernel_id:
            client = AsyncHTTPClient()
            await client.fetch(
                f'{self.base_url}/api/kernels/{self.kernel_id}',
                method='DELETE',
            )
            self.kernel_id = None
            if self.ws:
                self.ws.close()
                self.ws = None


class ExecuteHandler(tornado.web.RequestHandler):
    """``POST /execute``: run ``{"code": ...}`` in the shared kernel."""

    def initialize(self, jupyter_kernel: JupyterKernel) -> None:
        self.jupyter_kernel = jupyter_kernel

    async def post(self) -> None:
        """Execute the request's ``code`` and reply with the JSON output.

        Responds 400 when the body is not valid JSON or has no ``code``.
        """
        try:
            data = json_decode(self.request.body)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            self.set_status(400)
            self.write('Invalid JSON body')
            return

        code = data.get('code') if isinstance(data, dict) else None
        if not code:
            self.set_status(400)
            self.write('Missing code')
            return

        output = await self.jupyter_kernel.execute(code)

        # Set content type to JSON and return the structured output
        self.set_header('Content-Type', 'application/json')
        self.write(json_encode(output))


def make_app() -> tornado.web.Application:
    """Create the kernel client, initialize it synchronously and build the app.

    The gateway port and kernel/conversation id come from the
    ``JUPYTER_GATEWAY_PORT`` (default ``8888``) and
    ``JUPYTER_GATEWAY_KERNEL_ID`` (default ``default``) environment variables.
    """
    jupyter_kernel = JupyterKernel(
        f'localhost:{os.environ.get("JUPYTER_GATEWAY_PORT", "8888")}',
        os.environ.get('JUPYTER_GATEWAY_KERNEL_ID', 'default'),
    )
    # Runs on the same loop that IOLoop.current().start() drives below, so the
    # websocket and heartbeat created here stay usable afterwards.
    asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())

    return tornado.web.Application(
        [
            (r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
        ]
    )


if __name__ == '__main__':
    app = make_app()
    app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT'))
    tornado.ioloop.IOLoop.current().start()