File size: 6,153 Bytes
df8330d
fea1872
df8330d
 
 
 
 
 
 
 
 
 
fea1872
df8330d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea1872
df8330d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea1872
 
 
 
 
 
 
 
 
df8330d
 
 
 
 
 
fea1872
 
 
 
 
 
 
 
df8330d
 
fea1872
 
df8330d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from flask import Flask, request, jsonify
import os
import uuid
import time
import docker
import requests
import atexit
import socket 
import argparse
import logging
from pydantic import BaseModel, Field, ValidationError

current_dir = os.path.dirname(os.path.abspath(__file__))

app = Flask(__name__)
app.logger.setLevel(logging.INFO)


# CLI function to parse arguments
def parse_args():
    """Parse the command-line options for the Jupyter gateway server."""
    parser = argparse.ArgumentParser(description="Jupyter server.")
    # Table-driven registration keeps flag definitions easy to scan.
    options = [
        ('--n_instances', dict(type=int, help="Number of Jupyter instances.")),
        ('--n_cpus', dict(type=int, default=2, help="Number of CPUs per Jupyter instance.")),
        ('--mem', dict(type=str, default="2g", help="Amount of memory per Jupyter instance.")),
        ('--execution_timeout', dict(type=int, default=10, help="Timeout period for a code execution.")),
        ('--port', dict(type=int, default=5001, help="Port of main server")),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()


def get_unused_port(start=50000, end=65535, exclusion=None):
    """Return the first TCP port in [start, end] that can be bound.

    Args:
        start: First port number to try.
        end: Last port number to try (inclusive).
        exclusion: Optional iterable of ports to skip (e.g. ports already
            handed to containers that haven't bound them yet).

    Returns:
        A port number that was successfully bound and released.

    Raises:
        IOError: If no port in the range is available.
    """
    # `exclusion=None` avoids the shared-mutable-default pitfall of the
    # previous `exclusion=[]`; a set gives O(1) membership tests.
    excluded = set(exclusion) if exclusion else set()
    for port in range(start, end + 1):
        if port in excluded:
            continue
        try:
            # Context manager guarantees the probe socket is closed even
            # when bind()/listen() raises.
            with socket.socket() as sock:
                sock.bind(("", port))
                sock.listen(1)
            return port
        except OSError:
            continue
    raise IOError("No free ports available in range {}-{}".format(start, end))


def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10):
    """Build the kernel image and launch one kernel container per instance.

    Args:
        n_instances: Number of kernel containers to start.
        n_cpus: CPU cores reserved per container (pinned as contiguous blocks).
        mem: Memory limit per container, docker syntax (e.g. "2g").
        execution_timeout: Per-execution timeout passed into each kernel.

    Returns:
        List of dicts: {"container": <docker Container>, "port": <host port>}.

    Raises:
        TimeoutError: If the containers do not all report healthy within 60s.
    """
    docker_client = docker.from_env()
    app.logger.info("Building docker image...")
    image, logs = docker_client.images.build(path=current_dir, tag='jupyter-kernel:latest')
    app.logger.info("Building docker image complete.")

    containers = []
    port_exclusion = []
    for i in range(n_instances):

        free_port = get_unused_port(exclusion=port_exclusion)
        # Containers take a while to bind their port, so never hand out the
        # same port twice even if it still probes as free.
        port_exclusion.append(free_port)
        app.logger.info(f"Starting container {i} on port {free_port}...")
        container = docker_client.containers.run(
            "jupyter-kernel:latest",
            detach=True,
            mem_limit=mem,
            # Pin instance i to its own contiguous block of n_cpus cores.
            cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}",
            remove=True,
            ports={'5000/tcp': free_port},
            environment={"EXECUTION_TIMEOUT": execution_timeout},
        )

        containers.append({"container": container, "port": free_port})

    start_time = time.time()

    containers_ready = []

    while len(containers_ready) < n_instances:
        app.logger.info("Pinging Jupyter containers to check readiness.")
        if time.time() - start_time > 60:
            raise TimeoutError("Container took too long to startup.")
        for i in range(n_instances):
            if i in containers_ready:
                continue
            url = f"http://localhost:{containers[i]['port']}/health"
            try:
                # Short timeout so one hung container can't stall the
                # whole readiness poll (previously no timeout at all).
                response = requests.get(url, timeout=2)
                if response.status_code == 200:
                    containers_ready.append(i)
            except requests.exceptions.RequestException:
                # Connection refused / timed out while the container is
                # still booting — expected; retry on the next pass.
                pass
        time.sleep(0.5)
    app.logger.info("Containers ready!")
    return containers

def shutdown_cleanup():
    """Best-effort stop and removal of every kernel container at exit."""
    app.logger.info("Shutting down. Stopping and removing all containers...")
    for entry in app.containers:
        container = entry['container']
        try:
            container.stop()
            container.remove()
        except Exception as e:
            app.logger.info(f"Error stopping/removing container: {str(e)}")
    app.logger.info("All containers stopped and removed.")


class ServerRequest(BaseModel):
    """Request schema for the /execute endpoint."""
    # Python source code to run in the kernel.
    code: str = Field(..., example="print('Hello World!')")
    # Index into app.containers selecting which kernel instance to use.
    instance_id: int = Field(0, example=0)
    # When true, restart the instance's kernel before executing `code`.
    restart: bool = Field(False, example=False)


@app.route('/execute', methods=['POST'])
def execute_code():
    """Proxy a code-execution request to the selected kernel container.

    Expects a JSON body matching ServerRequest. Optionally restarts the
    target kernel first, then forwards the code to the container's
    /execute endpoint and relays its JSON result.
    """
    try:
        payload = ServerRequest(**request.json)  # renamed: `input` shadowed the builtin
    except ValidationError as e:
        return jsonify(e.errors()), 400

    # Previously an out-of-range instance_id raised an uncaught IndexError
    # (HTTP 500); reject it explicitly as a client error instead.
    if not 0 <= payload.instance_id < len(app.containers):
        return jsonify({
            'result': 'error',
            'output': f'Invalid instance_id {payload.instance_id}'
        }), 400

    port = app.containers[payload.instance_id]["port"]

    app.logger.info(f"Received request for instance {payload.instance_id} (port={port}).")

    try:
        if payload.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {payload.instance_id} restarted.")
            else:
                app.logger.info(f"Error when restarting kernel of instance {payload.instance_id}: {response.json()}.")

        # No timeout here on purpose: user code may legitimately run for a
        # while; the kernel container enforces EXECUTION_TIMEOUT itself.
        response = requests.post(f'http://localhost:{port}/execute', json={'code': payload.code})
        result = response.json()
        return result

    except Exception as e:
        app.logger.info(f"Error in execute_code: {str(e)}")
        return jsonify({
            'result': 'error',
            'output': str(e)
        }), 500


def init_app(app, args=None):
    """Launch the kernel containers and attach them to the Flask app.

    If `args` is None (e.g. when run under Gunicorn rather than the CLI),
    configuration is read from environment variables instead.
    """
    if args is None:
        # When run through Gunicorn, use environment variables
        env = os.getenv
        args = argparse.Namespace(
            n_instances=int(env('N_INSTANCES', 1)),
            n_cpus=int(env('N_CPUS', 1)),
            mem=env('MEM', '1g'),
            execution_timeout=int(env('EXECUTION_TIMEOUT', 60)),
        )

    app.containers = create_kernel_containers(
        args.n_instances,
        n_cpus=args.n_cpus,
        mem=args.mem,
        execution_timeout=args.execution_timeout,
    )
    return app, args

# Ensure containers are torn down however the process exits.
atexit.register(shutdown_cleanup)

if __name__ == '__main__':
    # Direct execution: configure from CLI flags.
    args = parse_args()
    app, args = init_app(app, args=args)
    # don't use debug=True --> it will run main twice and thus start double the containers
    app.run(debug=False, host='0.0.0.0', port=args.port) 
else:
    # Imported by a WSGI server (e.g. Gunicorn): configure from env vars.
    app, args = init_app(app)


# TODO:
# how to mount data at runtime into the container? idea: mount a (read only) 
# folder into the container at startup and copy the data in there. before starting 
# the kernel we could cp the necessary data into the pwd.