File size: 3,838 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""A Websocket Handler for emitting Jupyter server events.

.. versionadded:: 2.0
"""

from __future__ import annotations

import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional, cast

from jupyter_core.utils import ensure_async
from tornado import web, websocket

from jupyter_server.auth.decorator import authorized, ws_authenticated
from jupyter_server.base.handlers import JupyterHandler

from ...base.handlers import APIHandler

AUTH_RESOURCE = "events"


if TYPE_CHECKING:
    import jupyter_events.logger


class SubscribeWebsocket(
    JupyterHandler,
    websocket.WebSocketHandler,
):
    """Websocket handler for subscribing to events.

    While the socket is open, every event emitted on ``self.event_logger`` is
    forwarded to the client as a JSON text message. The listener is removed
    again when the socket closes.
    """

    auth_resource = AUTH_RESOURCE

    async def pre_get(self):
        """Authorize the current user to subscribe to server events.

        Raises:
            web.HTTPError: 403 if the authorizer denies ``execute`` on ``events``.
        """
        user = self.current_user
        # Named so it does not shadow the module-level ``authorized`` decorator.
        is_allowed = await ensure_async(
            self.authorizer.is_authorized(self, user, "execute", "events")
        )
        if not is_allowed:
            raise web.HTTPError(403)

    @ws_authenticated
    async def get(self, *args, **kwargs):
        """Authorize the request, then perform the websocket handshake."""
        await ensure_async(self.pre_get())
        res = super().get(*args, **kwargs)
        # The base ``get`` may hand back an awaitable; wait for it when it does.
        if res is not None:
            await res

    async def event_listener(
        self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict[str, Any]
    ) -> None:
        """Send one emitted event to the client.

        The message is the JSON encoding of ``data`` with an added
        ``schema_id`` key identifying the event's schema.
        """
        # A dict literal (rather than ``dict(schema_id=..., **data)``) cannot raise
        # TypeError when ``data`` has its own "schema_id" field or a non-string key.
        # Re-assigning keeps the emitter's schema id authoritative and first.
        capsule = {"schema_id": schema_id, **data}
        capsule["schema_id"] = schema_id
        self.write_message(json.dumps(capsule))

    def open(self):
        """Start forwarding events from the server's event logger to this socket."""
        self.event_logger.add_listener(listener=self.event_listener)

    def on_close(self):
        """Stop forwarding events once the socket has closed."""
        self.event_logger.remove_listener(listener=self.event_listener)


def validate_model(data: dict[str, Any]) -> None:
    """Validate that the JSON request body has the fields needed to emit an event.

    Args:
        data: The decoded JSON request body.

    Raises:
        web.HTTPError: 400 if the body is not a JSON object, or if any of
            ``schema_id``, ``version`` or ``data`` is missing.
    """
    # A JSON body can be any JSON value. A list or string would otherwise slip
    # through (or break) the membership checks below and fail later as a 500.
    if not isinstance(data, dict):
        raise web.HTTPError(400, "The JSON request body must be an object.")
    # Fixed order (not a set, whose string iteration order varies between runs)
    # so the reported missing key is deterministic.
    for key in ("schema_id", "version", "data"):
        if key not in data:
            raise web.HTTPError(400, f"Missing `{key}` in the JSON request body.")


def get_timestamp(data: dict[str, Any]) -> Optional[datetime]:
    """Parse the optional ``timestamp`` field of the JSON request body.

    The legacy form ``YYYY-MM-DDTHH:MM:SS<offset>Z`` (for example
    ``2022-05-26T13:50:00+05:00Z``) is tried first, so every value accepted
    before still parses identically. Standard ISO 8601 strings that carry a
    UTC offset, such as ``2022-05-26T13:50:00+05:00`` or
    ``2022-05-26T08:50:00Z``, are accepted as well.

    Args:
        data: The decoded JSON request body.

    Returns:
        A timezone-aware ``datetime``, or ``None`` if ``timestamp`` is absent.

    Raises:
        web.HTTPError: 400 if ``timestamp`` is present but is not a string in
            one of the formats above, or has no UTC offset.
    """
    try:
        if "timestamp" not in data:
            return None
        raw = data["timestamp"]
        try:
            return datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S%zZ")
        except ValueError:
            pass
        # datetime.fromisoformat only understands a trailing "Z" from Python 3.11
        # on, so spell UTC as an explicit offset first.
        if raw.endswith("Z"):
            raw = raw[:-1] + "+00:00"
        timestamp = datetime.fromisoformat(raw)
        if timestamp.tzinfo is None:
            # A naive value carries no UTC offset; reject it.
            raise ValueError("timestamp has no UTC offset")
    except Exception as e:
        raise web.HTTPError(
            400,
            """Failed to parse timestamp from JSON request body,
            an ISO format datetime string with UTC offset is expected,
            for example, 2022-05-26T13:50:00+05:00Z""",
        ) from e

    return timestamp


class EventHandler(APIHandler):
    """REST API handler that emits an event described by a POSTed JSON body."""

    auth_resource = AUTH_RESOURCE

    @web.authenticated
    @authorized
    async def post(self):
        """Emit an event.

        The JSON body must contain ``schema_id``, ``version`` and ``data``.
        An optional ``timestamp`` overrides the emission time. ``version`` is
        required by ``validate_model`` but is not passed to ``emit`` here.
        Responds with 204 (no content) on success.

        Raises:
            web.HTTPError: 400 if the body is missing, or if validation or
                timestamp parsing rejects it. 500 wrapping any other error
                raised while validating or emitting.
        """
        payload = self.get_json_body()
        if payload is None:
            raise web.HTTPError(400, "No JSON data provided")

        try:
            validate_model(payload)
            self.event_logger.emit(
                schema_id=cast(str, payload.get("schema_id")),
                data=cast("Dict[str, Any]", payload.get("data")),
                timestamp_override=get_timestamp(payload),
            )
        except web.HTTPError:
            # Already a well-formed client error; propagate unchanged.
            raise
        except Exception as e:
            raise web.HTTPError(500, str(e)) from e

        # Kept outside the try so a failure while writing the response is not
        # turned into a second (500) error response.
        self.set_status(204)
        self.finish()


# Routes provided by this module, as (URL regex, handler class) pairs.
# Presumably collected by the server when it registers its default handlers.
default_handlers = [
    (r"/api/events", EventHandler),  # POST emits an event
    (r"/api/events/subscribe", SubscribeWebsocket),  # websocket event stream
]