File size: 4,103 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""ZAP Authenticator in a Python Thread.

.. versionadded:: 14.1
"""

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import asyncio
from threading import Event, Thread
from typing import Any, List, Optional

import zmq
import zmq.asyncio

from .base import Authenticator


class AuthenticationThread(Thread):
    """Background thread that services a zmq Authenticator.

    ThreadAuthenticator runs one of these in the background. The thread owns
    one end of an inproc control pipe and polls it together with the
    authenticator's ZAP socket on an asyncio loop private to the thread.
    """

    pipe: zmq.Socket
    # Declared for typing only; run() keeps its event loop in a local.
    loop: asyncio.AbstractEventLoop
    authenticator: Authenticator
    poller: Optional[zmq.asyncio.Poller] = None

    def __init__(
        self,
        authenticator: Authenticator,
        pipe: zmq.Socket,
    ) -> None:
        """Prepare (but do not start) the thread.

        :param authenticator: the Authenticator whose ``zap_socket`` is
            polled and whose ``handle_zap_message`` is awaited.
        :param pipe: this thread's end of the control pipe; it is closed
            by :meth:`run` when the thread finishes.
        """
        super().__init__(daemon=True)
        self.authenticator = authenticator
        self.log = authenticator.log
        self.pipe = pipe

        # Signalled by _run() once both sockets are registered with the poller.
        self.started = Event()

    def run(self) -> None:
        """Thread entry point: drive :meth:`_run` on a fresh event loop.

        Whatever way the loop ends, this thread's pipe end is closed and the
        event loop is closed afterwards.
        """
        event_loop = asyncio.new_event_loop()
        try:
            event_loop.run_until_complete(self._run())
        finally:
            pipe_end = self.pipe
            if pipe_end:
                pipe_end.close()
                self.pipe = None  # type: ignore
            event_loop.close()

    async def _run(self):
        """Poll the control pipe and the ZAP socket until told to terminate.

        Sets :attr:`started` once both sockets are registered. A pipe
        message for which :meth:`_handle_pipe_message` returns True ends the
        loop; ZAP requests are forwarded to the authenticator's
        ``handle_zap_message``.
        """
        self.poller = zmq.asyncio.Poller()
        for watched in (self.pipe, self.authenticator.zap_socket):
            self.poller.register(watched, zmq.POLLIN)
        self.started.set()

        while True:
            ready = dict(await self.poller.poll())

            if self.pipe in ready:
                should_exit = self._handle_pipe_message(self.pipe.recv_multipart())
                if should_exit:
                    return

            zap = self.authenticator.zap_socket
            if zap in ready:
                request = zap.recv_multipart()
                await self.authenticator.handle_zap_message(request)

    def _handle_pipe_message(self, msg: List[bytes]) -> bool:
        """Interpret one multipart message received on the control pipe.

        :param msg: the received frames; only the first frame (the command)
            is inspected.
        :returns: True for ``b'TERMINATE'`` (the caller should stop polling).
            For any other command, logs an error, replies ``b'ERROR'`` on the
            pipe and returns False.
        """
        command = msg[0]
        self.log.debug("auth received API command %r", command)

        if command != b'TERMINATE':
            self.log.error("Invalid auth command from API: %r", command)
            self.pipe.send(b'ERROR')
            return False

        return True


class ThreadAuthenticator(Authenticator):
    """Run ZAP authentication in a background thread.

    :meth:`start` launches an :class:`AuthenticationThread` that services the
    authenticator's ZAP socket. :meth:`stop` asks that thread to exit over an
    inproc PAIR pipe, joins it, and then stops the base Authenticator.
    """

    # Main-thread end of the inproc PAIR control pipe; None until start()
    # and again after stop().
    pipe: "zmq.Socket"
    # inproc address of the control pipe, derived from id(self) in __init__.
    pipe_endpoint: str = ''
    # The background thread; None until start() and again after stop().
    thread: AuthenticationThread

    def __init__(
        self,
        context: Optional["zmq.Context"] = None,
        encoding: str = 'utf-8',
        log: Any = None,
    ):
        """Create the authenticator.

        The control pipe and background thread are not created until
        :meth:`start` is called.

        :param context: forwarded to :class:`Authenticator`.
        :param encoding: forwarded to :class:`Authenticator`.
        :param log: forwarded to :class:`Authenticator`.
        """
        super().__init__(context=context, encoding=encoding, log=log)
        self.pipe = None  # type: ignore
        # id(self) keeps the endpoint distinct between live instances.
        self.pipe_endpoint = f"inproc://{id(self)}.inproc"
        self.thread = None  # type: ignore

    def start(self) -> None:
        """Start the authentication thread.

        Starts the base Authenticator, creates the control pipe (this object
        binds, the thread's end connects) and launches the thread.

        :raises RuntimeError: if the thread does not signal readiness within
            10 seconds.
        """
        # start the Authenticator
        super().start()

        # create a socket pair to communicate with auth thread.
        self.pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket)
        # Short linger so close() does not wait long on undelivered commands.
        self.pipe.linger = 1
        self.pipe.bind(self.pipe_endpoint)
        thread_pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket)
        thread_pipe.linger = 1
        thread_pipe.connect(self.pipe_endpoint)
        self.thread = AuthenticationThread(authenticator=self, pipe=thread_pipe)
        self.thread.start()
        if not self.thread.started.wait(timeout=10):
            raise RuntimeError("Authenticator thread failed to start")

    def stop(self) -> None:
        """Stop the authentication thread and close the control pipe.

        The thread/pipe teardown is skipped when there is no pipe (never
        started, or already stopped). The base Authenticator's ``stop()`` is
        called in every case.
        """
        if self.pipe:
            if self.is_alive():
                try:
                    # NOBLOCK: a PAIR socket blocks on send while it has no
                    # connected peer. If the thread closed its end between
                    # is_alive() and here, a blocking send would never return.
                    self.pipe.send(b'TERMINATE', zmq.NOBLOCK)
                except zmq.Again:
                    # The thread's pipe end is (most likely) already closed,
                    # i.e. it has left its polling loop on its own.
                    pass
                self.thread.join()
            # A thread that already died is not sent TERMINATE at all: its pipe
            # end was closed by its run(), so a send here could block forever.
            self.thread = None  # type: ignore
            self.pipe.close()
            self.pipe = None  # type: ignore
        super().stop()

    def is_alive(self) -> bool:
        """Is the ZAP thread currently running?"""
        return bool(self.thread and self.thread.is_alive())

    def __del__(self) -> None:
        """Stop the thread and release sockets when the object is collected."""
        # If __init__ raised before `pipe` was assigned, there is nothing of
        # ours to tear down, and stop() would fail with AttributeError.
        if not hasattr(self, 'pipe'):
            return
        self.stop()


__all__ = ['ThreadAuthenticator']