# Spaces: Build error  (stray header text, repeated twice in the original; not Python code)
import asyncio | |
import json | |
import logging | |
import uuid | |
from typing import Optional | |
import aiohttp | |
import websockets | |
from pydantic import BaseModel | |
from open_webui.env import SRC_LOG_LEVELS | |
# Module-level logger, named after this module; its level is read from the
# "MAIN" entry of SRC_LOG_LEVELS.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=SRC_LOG_LEVELS["MAIN"])
class ResultModel(BaseModel):
    """
    Execute Code Result Model

    Container for the outcome of one code execution. Every field is an
    optional string that defaults to the empty string.
    """

    # Text captured from the standard-output stream.
    stdout: Optional[str] = ""

    # Text captured from the standard-error stream, plus error details.
    stderr: Optional[str] = ""

    # Textual (or data-URI encoded) results produced by the execution.
    result: Optional[str] = ""
class JupyterCodeExecuter: | |
""" | |
Execute code in jupyter notebook | |
""" | |
def __init__( | |
self, | |
base_url: str, | |
code: str, | |
token: str = "", | |
password: str = "", | |
timeout: int = 60, | |
): | |
""" | |
:param base_url: Jupyter server URL (e.g., "http://localhost:8888") | |
:param code: Code to execute | |
:param token: Jupyter authentication token (optional) | |
:param password: Jupyter password (optional) | |
:param timeout: WebSocket timeout in seconds (default: 60s) | |
""" | |
self.base_url = base_url.rstrip("/") | |
self.code = code | |
self.token = token | |
self.password = password | |
self.timeout = timeout | |
self.kernel_id = "" | |
self.session = aiohttp.ClientSession(base_url=self.base_url) | |
self.params = {} | |
self.result = ResultModel() | |
async def __aenter__(self): | |
return self | |
async def __aexit__(self, exc_type, exc_val, exc_tb): | |
if self.kernel_id: | |
try: | |
async with self.session.delete( | |
f"/api/kernels/{self.kernel_id}", params=self.params | |
) as response: | |
response.raise_for_status() | |
except Exception as err: | |
logger.exception("close kernel failed, %s", err) | |
await self.session.close() | |
async def run(self) -> ResultModel: | |
try: | |
await self.sign_in() | |
await self.init_kernel() | |
await self.execute_code() | |
except Exception as err: | |
logger.exception("execute code failed, %s", err) | |
self.result.stderr = f"Error: {err}" | |
return self.result | |
async def sign_in(self) -> None: | |
# password authentication | |
if self.password and not self.token: | |
async with self.session.get("/login") as response: | |
response.raise_for_status() | |
xsrf_token = response.cookies["_xsrf"].value | |
if not xsrf_token: | |
raise ValueError("_xsrf token not found") | |
self.session.cookie_jar.update_cookies(response.cookies) | |
self.session.headers.update({"X-XSRFToken": xsrf_token}) | |
async with self.session.post( | |
"/login", | |
data={"_xsrf": xsrf_token, "password": self.password}, | |
allow_redirects=False, | |
) as response: | |
response.raise_for_status() | |
self.session.cookie_jar.update_cookies(response.cookies) | |
# token authentication | |
if self.token: | |
self.params.update({"token": self.token}) | |
async def init_kernel(self) -> None: | |
async with self.session.post( | |
url="/api/kernels", params=self.params | |
) as response: | |
response.raise_for_status() | |
kernel_data = await response.json() | |
self.kernel_id = kernel_data["id"] | |
def init_ws(self) -> (str, dict): | |
ws_base = self.base_url.replace("http", "ws") | |
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()]) | |
websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}" | |
ws_headers = {} | |
if self.password and not self.token: | |
ws_headers = { | |
"Cookie": "; ".join( | |
[ | |
f"{cookie.key}={cookie.value}" | |
for cookie in self.session.cookie_jar | |
] | |
), | |
**self.session.headers, | |
} | |
return websocket_url, ws_headers | |
async def execute_code(self) -> None: | |
# initialize ws | |
websocket_url, ws_headers = self.init_ws() | |
# execute | |
async with websockets.connect( | |
websocket_url, additional_headers=ws_headers | |
) as ws: | |
await self.execute_in_jupyter(ws) | |
async def execute_in_jupyter(self, ws) -> None: | |
# send message | |
msg_id = uuid.uuid4().hex | |
await ws.send( | |
json.dumps( | |
{ | |
"header": { | |
"msg_id": msg_id, | |
"msg_type": "execute_request", | |
"username": "user", | |
"session": uuid.uuid4().hex, | |
"date": "", | |
"version": "5.3", | |
}, | |
"parent_header": {}, | |
"metadata": {}, | |
"content": { | |
"code": self.code, | |
"silent": False, | |
"store_history": True, | |
"user_expressions": {}, | |
"allow_stdin": False, | |
"stop_on_error": True, | |
}, | |
"channel": "shell", | |
} | |
) | |
) | |
# parse message | |
stdout, stderr, result = "", "", [] | |
while True: | |
try: | |
# wait for message | |
message = await asyncio.wait_for(ws.recv(), self.timeout) | |
message_data = json.loads(message) | |
# msg id not match, skip | |
if message_data.get("parent_header", {}).get("msg_id") != msg_id: | |
continue | |
# check message type | |
msg_type = message_data.get("msg_type") | |
match msg_type: | |
case "stream": | |
if message_data["content"]["name"] == "stdout": | |
stdout += message_data["content"]["text"] | |
elif message_data["content"]["name"] == "stderr": | |
stderr += message_data["content"]["text"] | |
case "execute_result" | "display_data": | |
data = message_data["content"]["data"] | |
if "image/png" in data: | |
result.append(f"data:image/png;base64,{data['image/png']}") | |
elif "text/plain" in data: | |
result.append(data["text/plain"]) | |
case "error": | |
stderr += "\n".join(message_data["content"]["traceback"]) | |
case "status": | |
if message_data["content"]["execution_state"] == "idle": | |
break | |
except asyncio.TimeoutError: | |
stderr += "\nExecution timed out." | |
break | |
self.result.stdout = stdout.strip() | |
self.result.stderr = stderr.strip() | |
self.result.result = "\n".join(result).strip() if result else "" | |
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Execute ``code`` through a ``JupyterCodeExecuter`` and return the result
    as a plain dict.

    :param base_url: Jupyter server URL
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: timeout in seconds handed to the executer (default: 60s)
    :return: the executer's result model dumped to a dict
    """
    executer = JupyterCodeExecuter(
        base_url=base_url,
        code=code,
        token=token,
        password=password,
        timeout=timeout,
    )
    # The context manager's exit hook runs once the dict has been built.
    async with executer as active:
        outcome = await active.run()
        return outcome.model_dump()