File size: 16,337 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
"""Base implementation of 0MQ authentication."""

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import logging
import os
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

import zmq
from zmq.error import _check_version
from zmq.utils import z85

from .certs import load_certificates

# Sentinel value for the `location` argument of `Authenticator.configure_curve`:
# when given, every CURVE client key is accepted without a certificate lookup.
CURVE_ALLOW_ANY = '*'
# ZAP version frame: incoming requests are rejected unless their first frame
# equals it, and every reply starts with it.
VERSION = b'1.0'


class Authenticator:
    """Implementation of ZAP authentication for zmq connections.

    This authenticator class does not register with an event loop. As a result,
    you will need to manually call `handle_zap_message`::

        auth = zmq.Authenticator()
        auth.allow("127.0.0.1")
        auth.start()
        while True:
            await auth.handle_zap_message(auth.zap_socket.recv_multipart())

    Alternatively, you can register `auth.zap_socket` with a poller.

    Since many users will want to run ZAP in a way that does not block the
    main thread, other authentication classes (such as :mod:`zmq.auth.thread`)
    are provided.

    Note:

    - libzmq provides five levels of security: default NULL (which the Authenticator does
      not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see.
    - until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
    - GSSAPI requires no configuration.
    """

    context: "zmq.Context"
    encoding: str
    allow_any: bool
    credentials_providers: Dict[str, Any]
    zap_socket: "zmq.Socket"
    _allowed: Set[str]
    _denied: Set[str]
    passwords: Dict[str, Dict[str, str]]
    certs: Dict[str, Dict[bytes, Any]]
    log: Any

    def __init__(
        self,
        context: Optional["zmq.Context"] = None,
        encoding: str = 'utf-8',
        log: Any = None,
    ):
        """Create an authenticator with empty policies and no ZAP socket yet.

        Parameters
        ----------
        context:
            Context used later to create the ZAP socket. A falsy value
            selects ``zmq.Context.instance()``.
        encoding:
            Text encoding used to decode request frames and encode user ids.
        log:
            Logger to report to. A falsy value selects the ``'zmq.auth'`` logger.
        """
        # Version gate for the "security" feature, checked against (4, 0).
        _check_version((4, 0), "security")

        self.context = context if context else zmq.Context.instance()
        self.encoding = encoding
        self.log = log if log else logging.getLogger('zmq.auth')

        # The socket only exists between start() and stop().
        self.zap_socket = None  # type: ignore

        # Address policies (at most one of the two sets is ever populated).
        self._allowed = set()
        self._denied = set()

        # PLAIN: {domain: {username: password}}
        self.passwords = {}

        # CURVE: a global allow-any switch, per-domain callback providers,
        # and {domain: {public key: ...}} as loaded from a certificate location.
        self.allow_any = False
        self.credentials_providers = {}
        self.certs = {}

    def start(self) -> None:
        """Create the ZAP REP socket and bind it to the ZAP inproc endpoint."""
        sock = self.context.socket(zmq.REP, socket_class=zmq.Socket)
        # Publish the socket on the instance before configuring it, so it is
        # reachable (e.g. for stop()) even if binding fails.
        self.zap_socket = sock
        sock.linger = 1
        sock.bind("inproc://zeromq.zap.01")
        self.log.debug("Starting")

    def stop(self) -> None:
        """Close the ZAP socket"""
        if self.zap_socket:
            self.zap_socket.close()
        self.zap_socket = None  # type: ignore

    def allow(self, *addresses: str) -> None:
        """Allow IP address(es).

        Connections from addresses not explicitly allowed will be rejected.

        - For NULL, all clients from this address will be accepted.
        - For real auth setups, they will be allowed to continue with authentication.

        allow is mutually exclusive with deny.
        """
        if self._denied:
            raise ValueError("Only use allow or deny, not both")
        self.log.debug("Allowing %s", ','.join(addresses))
        self._allowed.update(addresses)

    def deny(self, *addresses: str) -> None:
        """Deny IP address(es).

        Addresses not explicitly denied will be allowed to continue with authentication.

        deny is mutually exclusive with allow.
        """
        if self._allowed:
            raise ValueError("Only use a allow or deny, not both")
        self.log.debug("Denying %s", ','.join(addresses))
        self._denied.update(addresses)

    def configure_plain(
        self, domain: str = '*', passwords: Optional[Dict[str, str]] = None
    ) -> None:
        """Configure PLAIN authentication for a given domain.

        PLAIN authentication uses a plain-text password file.
        To cover all domains, use "*".
        You can modify the password file at any time; it is reloaded automatically.
        """
        if passwords:
            self.passwords[domain] = passwords
        self.log.debug("Configure plain: %s", domain)

    def configure_curve(
        self, domain: str = '*', location: Union[str, os.PathLike] = "."
    ) -> None:
        """Set up certificate-based CURVE authentication for a domain.

        ``location`` is normally a directory holding the public certificates
        (i.e. public keys) of the clients to accept. Its content is read when
        this method is called, so call it again whenever certificates are
        added to or removed from the directory.

        Passing ``CURVE_ALLOW_ANY`` as ``location`` accepts every client key
        without checking; that switch is global, not per-domain.

        ``"*"`` is the domain looked up for requests that carry no domain.

        A failure to load the certificates is logged, not raised.
        """
        self.log.debug("Configure curve: %s[%s]", domain, location)

        # Allow-any mode: no certificates to load.
        if location == CURVE_ALLOW_ANY:
            self.allow_any = True
            return

        # Otherwise location is a certificate directory for this domain.
        self.allow_any = False
        try:
            self.certs[domain] = load_certificates(location)
        except Exception as e:
            self.log.error("Failed to load CURVE certs from %s: %s", location, e)

    def configure_curve_callback(
        self, domain: str = '*', credentials_provider: Any = None
    ) -> None:
        """Configure CURVE authentication for a given domain.

        CURVE authentication using a callback function validating
        the client public key according to a custom mechanism, e.g. checking the
        key against records in a db. credentials_provider is an object of a class which
        implements a callback method accepting two parameters (domain and key), e.g.::

            class CredentialsProvider(object):

                def __init__(self):
                    ...e.g. db connection

                def callback(self, domain, key):
                    valid = ...lookup key and/or domain in db
                    if valid:
                        logging.info('Authorizing: {0}, {1}'.format(domain, key))
                        return True
                    else:
                        logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
                        return False

        To cover all domains, use "*".
        """

        self.allow_any = False

        if credentials_provider is not None:
            self.credentials_providers[domain] = credentials_provider
        else:
            self.log.error("None credentials_provider provided for domain:%s", domain)

    def curve_user_id(self, client_public_key: bytes) -> str:
        """Map a CURVE client's public key to the User-Id sent in the reply.

        The default mapping is the z85 encoding of the key, as text.
        Override this method to define a custom public key -> user-id mapping.

        It is only called after the key has been authenticated successfully.

        Parameters
        ----------
        client_public_key: bytes
            The client public key used for the given message

        Returns
        -------
        user_id: str
            The user ID as text
        """
        encoded_key = z85.encode(client_public_key)
        return encoded_key.decode('ascii')

    def configure_gssapi(
        self, domain: str = '*', location: Optional[str] = None
    ) -> None:
        """Configure GSSAPI authentication

        Currently this is a no-op because there is nothing to configure with GSSAPI.
        """

    async def handle_zap_message(self, msg: List[bytes]):
        """Perform ZAP authentication"""
        if len(msg) < 6:
            self.log.error("Invalid ZAP message, not enough frames: %r", msg)
            if len(msg) < 2:
                self.log.error("Not enough information to reply")
            else:
                self._send_zap_reply(msg[1], b"400", b"Not enough frames")
            return

        version, request_id, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]

        domain = domain.decode(self.encoding, 'replace')
        address = address.decode(self.encoding, 'replace')

        if version != VERSION:
            self.log.error("Invalid ZAP version: %r", msg)
            self._send_zap_reply(request_id, b"400", b"Invalid version")
            return

        self.log.debug(
            "version: %r, request_id: %r, domain: %r,"
            " address: %r, identity: %r, mechanism: %r",
            version,
            request_id,
            domain,
            address,
            identity,
            mechanism,
        )

        # Is address is explicitly allowed or _denied?
        allowed = False
        denied = False
        reason = b"NO ACCESS"

        if self._allowed:
            if address in self._allowed:
                allowed = True
                self.log.debug("PASSED (allowed) address=%s", address)
            else:
                denied = True
                reason = b"Address not allowed"
                self.log.debug("DENIED (not allowed) address=%s", address)

        elif self._denied:
            if address in self._denied:
                denied = True
                reason = b"Address denied"
                self.log.debug("DENIED (denied) address=%s", address)
            else:
                allowed = True
                self.log.debug("PASSED (not denied) address=%s", address)

        # Perform authentication mechanism-specific checks if necessary
        username = "anonymous"
        if not denied:
            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't denied
                self.log.debug("ALLOWED (NULL)")
                allowed = True

            elif mechanism == b'PLAIN':
                # For PLAIN, even a _alloweded address must authenticate
                if len(credentials) != 2:
                    self.log.error("Invalid PLAIN credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                username, password = (
                    c.decode(self.encoding, 'replace') for c in credentials
                )
                allowed, reason = self._authenticate_plain(domain, username, password)

            elif mechanism == b'CURVE':
                # For CURVE, even a _alloweded address must authenticate
                if len(credentials) != 1:
                    self.log.error("Invalid CURVE credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                key = credentials[0]
                allowed, reason = await self._authenticate_curve(domain, key)
                if allowed:
                    username = self.curve_user_id(key)

            elif mechanism == b'GSSAPI':
                if len(credentials) != 1:
                    self.log.error("Invalid GSSAPI credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                # use principal as user-id for now
                principal = credentials[0]
                username = principal.decode("utf8")
                allowed, reason = self._authenticate_gssapi(domain, principal)

        if allowed:
            self._send_zap_reply(request_id, b"200", b"OK", username)
        else:
            self._send_zap_reply(request_id, b"400", reason)

    def _authenticate_plain(
        self, domain: str, username: str, password: str
    ) -> Tuple[bool, bytes]:
        """PLAIN ZAP authentication"""
        allowed = False
        reason = b""
        if self.passwords:
            # If no domain is not specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.passwords:
                if username in self.passwords[domain]:
                    if password == self.passwords[domain][username]:
                        allowed = True
                    else:
                        reason = b"Invalid password"
                else:
                    reason = b"Invalid username"
            else:
                reason = b"Invalid domain"

            if allowed:
                self.log.debug(
                    "ALLOWED (PLAIN) domain=%s username=%s password=%s",
                    domain,
                    username,
                    password,
                )
            else:
                self.log.debug("DENIED %s", reason)

        else:
            reason = b"No passwords defined"
            self.log.debug("DENIED (PLAIN) %s", reason)

        return allowed, reason

    async def _authenticate_curve(
        self, domain: str, client_key: bytes
    ) -> Tuple[bool, bytes]:
        """CURVE ZAP authentication"""
        allowed = False
        reason = b""
        if self.allow_any:
            allowed = True
            reason = b"OK"
            self.log.debug("ALLOWED (CURVE allow any client)")
        elif self.credentials_providers != {}:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.credentials_providers:
                z85_client_key = z85.encode(client_key)
                # Callback to check if key is Allowed
                r = self.credentials_providers[domain].callback(domain, z85_client_key)
                if isinstance(r, Awaitable):
                    r = await r
                if r:
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug(
                    "%s (CURVE auth_callback) domain=%s client_key=%s",
                    status,
                    domain,
                    z85_client_key,
                )
            else:
                reason = b"Unknown domain"
        else:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.certs:
                # The certs dict stores keys in z85 format, convert binary key to z85 bytes
                z85_client_key = z85.encode(client_key)
                if self.certs[domain].get(z85_client_key):
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug(
                    "%s (CURVE) domain=%s client_key=%s",
                    status,
                    domain,
                    z85_client_key,
                )
            else:
                reason = b"Unknown domain"

        return allowed, reason

    def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]:
        """Nothing to do for GSSAPI, which has already been handled by an external service."""
        self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
        return True, b'OK'

    def _send_zap_reply(
        self,
        request_id: bytes,
        status_code: bytes,
        status_text: bytes,
        user_id: str = 'anonymous',
    ) -> None:
        """Send the ZAP reply that concludes one authentication request.

        The reply frames are: version, request id, status code, status text,
        user id, metadata. The user id is only sent with status ``b'200'``;
        for any other status that frame is empty. The metadata frame is
        always empty (not currently used).
        """
        if status_code == b'200':
            user_id_frame = user_id
        else:
            user_id_frame = b''
        # Text user ids go on the wire in the configured encoding.
        if isinstance(user_id_frame, str):
            user_id_frame = user_id_frame.encode(self.encoding, 'replace')

        self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
        self.zap_socket.send_multipart(
            [VERSION, request_id, status_code, status_text, user_id_frame, b'']
        )


# Names exported by `from <this module> import *`.
__all__ = ['Authenticator', 'CURVE_ALLOW_ANY']