|
from __future__ import annotations |
|
|
|
import binascii |
|
import json |
|
import warnings |
|
from typing import TYPE_CHECKING, Any |
|
|
|
from .algorithms import ( |
|
Algorithm, |
|
get_default_algorithms, |
|
has_crypto, |
|
requires_cryptography, |
|
) |
|
from .api_jwk import PyJWK |
|
from .exceptions import ( |
|
DecodeError, |
|
InvalidAlgorithmError, |
|
InvalidSignatureError, |
|
InvalidTokenError, |
|
) |
|
from .utils import base64url_decode, base64url_encode |
|
from .warnings import RemovedInPyjwt3Warning |
|
|
|
if TYPE_CHECKING: |
|
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys |
|
|
|
|
|
class PyJWS:
    """Encode, decode and verify compact-serialized JSON Web Signatures.

    Every instance owns a registry that maps ``alg`` names to
    :class:`Algorithm` objects, together with a dict of default decoding
    options (currently only ``verify_signature``).
    """

    # Value written to the ``typ`` header by :meth:`encode` unless the caller
    # overrides it through ``headers``.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
    ) -> None:
        """Build the algorithm registry and the default options.

        :param algorithms: names of the algorithms this instance accepts;
            ``None`` means every algorithm returned by
            ``get_default_algorithms()``.
        :param options: overrides merged on top of the default options.
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Drop the handlers of every default algorithm that was not allowed.
        for name in set(self._algorithms) - self._valid_algs:
            del self._algorithms[name]

        if options is None:
            options = {}
        self.options = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return a fresh dict holding the default decoding options."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an :class:`Algorithm`.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        :raises KeyError: if the algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the 'alg' parameter.
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        :raises NotImplementedError: if no handler is registered under
            ``alg_name``.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            # Give a more helpful hint when the name is one of the algorithms
            # that are only available with the optional crypto backend.
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | str | bytes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return the compact serialization as ``str``.

        :param payload: raw bytes to sign.
        :param key: signing key, normalised by the algorithm's
            ``prepare_key``.
        :param algorithm: ``alg`` name; ``None`` selects ``"none"``. A truthy
            ``"alg"`` entry in ``headers`` takes precedence.
        :param headers: extra header fields merged over the defaults
            (``typ`` and ``alg``).
        :param json_encoder: ``json.JSONEncoder`` subclass used to serialise
            the header.
        :param is_payload_detached: sign the raw payload and leave the
            payload segment of the token empty. Also enabled by
            ``headers["b64"] is False``.
        :param sort_headers: forwarded to ``json.dumps`` as ``sort_keys``.
        :returns: ``header.payload.signature`` with base64url segments.
        :raises NotImplementedError: if the algorithm has no handler.
        """
        algorithm_: str = algorithm if algorithm is not None else "none"

        if headers:
            # Values supplied in the header win over the keyword arguments.
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers_alg

            if headers.get("b64") is False:
                is_payload_detached = True

        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)

        # A falsy "typ" (e.g. "" or None) removes the field altogether.
        if not header["typ"]:
            del header["typ"]

        # "b64" is only ever emitted as False; any other value is dropped.
        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()
        header_segment = base64url_encode(json_header)

        # A detached payload is signed as-is instead of base64url-encoded.
        if is_payload_detached:
            signed_payload = payload
        else:
            signed_payload = base64url_encode(payload)

        signing_input = b".".join([header_segment, signed_payload])

        alg_obj = self.get_algorithm_by_name(algorithm_)
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        # The detached payload is covered by the signature but is not carried
        # inside the token itself.
        token_payload = b"" if is_payload_detached else signed_payload
        encoded_string = b".".join(
            [header_segment, token_payload, base64url_encode(signature)]
        )

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Parse a token and, unless disabled, verify its signature.

        :param jwt: the compact-serialized token.
        :param key: verification key, or a :class:`PyJWK`.
        :param algorithms: allowed ``alg`` values. Required when the
            signature is verified and ``key`` is not a :class:`PyJWK`.
        :param options: per-call overrides of ``self.options``.
        :param detached_payload: the payload of a token whose header has
            ``b64`` set to ``False``.
        :param kwargs: ignored; passing any emits ``RemovedInPyjwt3Warning``.
        :returns: dict with the keys ``"payload"``, ``"header"`` and
            ``"signature"``.
        :raises DecodeError: on a malformed token or missing required
            arguments.
        :raises InvalidAlgorithmError: if the ``alg`` header is missing, not
            allowed or unsupported.
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller, not to this module.
                stacklevel=2,
            )
        if options is None:
            options = {}
        merged_options = {**self.options, **options}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input as "<header segment>.<raw payload>".
            # The header segment is everything before the FIRST dot, exactly
            # as _load() splits it; taking everything before the last dot
            # would drag part of a dotted payload segment into the input.
            header_segment = signing_input.split(b".", 1)[0]
            signing_input = b".".join([header_segment, payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> Any:
        """Return only the payload of :meth:`decode_complete`.

        Takes the same arguments and raises the same exceptions as
        :meth:`decode_complete`. Extra ``kwargs`` are ignored and emit
        ``RemovedInPyjwt3Warning``.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller, not to this module.
                stacklevel=2,
            )
        # kwargs are deliberately not forwarded, so the warning fires once.
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete

        :raises DecodeError: if the token cannot be parsed.
        :raises InvalidTokenError: if a ``kid`` header is not a string.
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split a compact token and decode its three segments.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``signing_input`` is the token text before its last dot and the
            other items are the decoded segments.
        :raises DecodeError: if the token has the wrong type, too few
            segments, a segment that is not valid base64url, or a header
            that is not a JSON object.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            # The signature is whatever follows the last dot; the header is
            # whatever precedes the first dot of the remainder.
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: list[str] | None = None,
    ) -> None:
        """Check ``signature`` over ``signing_input`` with ``key``.

        :raises InvalidAlgorithmError: if the header has no usable ``alg``,
            the value is not in ``algorithms``, or no handler exists for it.
        :raises InvalidSignatureError: if verification fails.
        """
        # Without an explicit allow-list, a PyJWK restricts the token to the
        # algorithm the key itself declares.
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            # The internal KeyError carries no information for the caller.
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            # A PyJWK brings its own algorithm object and ready-to-use key.
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate the header fields this class knows how to check."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a string."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")
|
|
|
|
|
# Shared PyJWS instance, created with the default algorithms and options,
# that backs the module-level convenience functions below.
_jws_global_obj = PyJWS()

# Token operations, exposed as bound methods of the shared instance.
encode = _jws_global_obj.encode
decode = _jws_global_obj.decode
decode_complete = _jws_global_obj.decode_complete
get_unverified_header = _jws_global_obj.get_unverified_header

# Algorithm registry management for the shared instance.
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
|
|