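"""Flask front-end that fans code-execution requests out to a pool of
Dockerized Jupyter kernel containers, one container per instance."""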
from flask import Flask, request, jsonify
import os
import uuid
import time
import docker
import requests
import atexit
import socket
import argparse
import logging
from pydantic import BaseModel, Field, ValidationError
current_dir = os.path.dirname(os.path.abspath(__file__))
app = Flask(__name__)
app.logger.setLevel(logging.INFO)
# Parse command-line arguments.
def parse_args():
parser = argparse.ArgumentParser(description="Jupyter server.")
parser.add_argument('--n_instances', type=int, help="Number of Jupyter instances.")
parser.add_argument('--n_cpus', type=int, default=2, help="Number of CPUs per Jupyter instance.")
parser.add_argument('--mem', type=str, default="2g", help="Amount of memory per Jupyter instance.")
parser.add_argument('--execution_timeout', type=int, default=10, help="Timeout period for a code execution.")
parser.add_argument('--port', type=int, default=5001, help="Port of main server")
return parser.parse_args()
def get_unused_port(start=50000, end=65535, exclusion=()):
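    """Return the first TCP port in [start, end] that can be bound, skipping any port in `exclusion`."""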
for port in range(start, end + 1):
if port in exclusion:
continue
try:
sock = socket.socket()
sock.bind(("", port))
sock.listen(1)
sock.close()
return port
except OSError:
continue
    raise IOError(f"No free ports available in range {start}-{end}")
def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10):
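    """Build the kernel image and start `n_instances` containers, each with its own
    CPU set, memory limit, and host port. Blocks until every container answers on
    its /health endpoint, raising TimeoutError after 60 seconds."""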
docker_client = docker.from_env()
app.logger.info("Buidling docker image...")
image, logs = docker_client.images.build(path=current_dir, tag='jupyter-kernel:latest')
app.logger.info("Building docker image complete.")
containers = []
port_exclusion = []
for i in range(n_instances):
free_port = get_unused_port(exclusion=port_exclusion)
        port_exclusion.append(free_port)  # containers take a while to bind their port, so never hand out the same one twice
app.logger.info(f"Starting container {i} on port {free_port}...")
container = docker_client.containers.run(
"jupyter-kernel:latest",
detach=True,
mem_limit=mem,
            cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}",  # pin instance i to its own block of n_cpus cores
remove=True,
ports={'5000/tcp': free_port},
environment={"EXECUTION_TIMEOUT": execution_timeout},
)
containers.append({"container": container, "port": free_port})
start_time = time.time()
containers_ready = []
while len(containers_ready) < n_instances:
app.logger.info("Pinging Jupyter containers to check readiness.")
if time.time() - start_time > 60:
            raise TimeoutError("Containers took too long to start up.")
for i in range(n_instances):
if i in containers_ready:
continue
url = f"http://localhost:{containers[i]['port']}/health"
try:
                # TODO: dedicated health endpoint
                response = requests.get(url, timeout=1)
if response.status_code == 200:
containers_ready.append(i)
            except requests.exceptions.RequestException:
                # container not accepting connections yet; retry on the next pass
                pass
time.sleep(0.5)
app.logger.info("Containers ready!")
return containers
def shutdown_cleanup():
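    """atexit hook: stop and remove every kernel container started by this server."""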
app.logger.info("Shutting down. Stopping and removing all containers...")
for instance in app.containers:
try:
instance['container'].stop()
instance['container'].remove()
except Exception as e:
app.logger.info(f"Error stopping/removing container: {str(e)}")
app.logger.info("All containers stopped and removed.")
class ServerRequest(BaseModel):
code: str = Field(..., example="print('Hello World!')")
instance_id: int = Field(0, example=0)
restart: bool = Field(False, example=False)
@app.route('/execute', methods=['POST'])
def execute_code():
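    """Validate the request, optionally restart the target kernel, then forward the
    code to the matching container and relay its JSON response."""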
try:
        payload = ServerRequest(**request.json)  # avoid shadowing the builtin `input`
    except ValidationError as e:
        return jsonify(e.errors()), 400
    if not 0 <= payload.instance_id < len(app.containers):
        return jsonify({'result': 'error', 'output': 'invalid instance_id'}), 400
    port = app.containers[payload.instance_id]["port"]
    app.logger.info(f"Received request for instance {payload.instance_id} (port={port}).")
    try:
        if payload.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {payload.instance_id} restarted.")
            else:
                app.logger.info(f"Error when restarting kernel of instance {payload.instance_id}: {response.json()}.")
        response = requests.post(f'http://localhost:{port}/execute', json={'code': payload.code})
result = response.json()
return result
except Exception as e:
app.logger.info(f"Error in execute_code: {str(e)}")
return jsonify({
'result': 'error',
'output': str(e)
}), 500
def init_app(app, args=None):
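    """Start the kernel containers and attach them to the Flask app. Without CLI
    args (e.g. when run under Gunicorn), configuration is read from environment
    variables."""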
if args is None:
# When run through Gunicorn, use environment variables
args = argparse.Namespace(
n_instances=int(os.getenv('N_INSTANCES', 1)),
n_cpus=int(os.getenv('N_CPUS', 1)),
mem=os.getenv('MEM', '1g'),
execution_timeout=int(os.getenv('EXECUTION_TIMEOUT', 60))
)
app.containers = create_kernel_containers(
args.n_instances,
n_cpus=args.n_cpus,
mem=args.mem,
execution_timeout=args.execution_timeout
)
return app, args
atexit.register(shutdown_cleanup)
if __name__ == '__main__':
args = parse_args()
app, args = init_app(app, args=args)
    # don't use debug=True --> the reloader runs the module twice and would start double the containers
app.run(debug=False, host='0.0.0.0', port=args.port)
else:
app, args = init_app(app)
# TODO:
# How to mount data into the container at runtime? Idea: mount a (read-only)
# folder into the container at startup and copy the data in there. Before
# starting the kernel we could cp the necessary data into the working directory.
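
# Example usage (illustrative; assumes this file is saved as server.py):
#
#   python server.py --n_instances 2 --n_cpus 2 --mem 2g --port 5001
#
#   curl -X POST http://localhost:5001/execute \
#        -H 'Content-Type: application/json' \
#        -d '{"code": "print(1 + 1)", "instance_id": 0, "restart": false}'
#
# Under Gunicorn, configuration comes from environment variables instead; keep a
# single worker so only one container pool is started:
#
#   N_INSTANCES=2 N_CPUS=2 MEM=2g gunicorn -w 1 -b 0.0.0.0:5001 server:app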