File size: 3,656 Bytes
ddb9253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import annotations

import sys
from collections.abc import Callable, Iterable, Mapping
from contextlib import AbstractContextManager
from types import TracebackType
from typing import TYPE_CHECKING, Any

if sys.version_info < (3, 11):
    from ._exceptions import BaseExceptionGroup

if TYPE_CHECKING:
    # Signature of a catch() handler. It is invoked with the matched portion of
    # the exception as produced by ``BaseExceptionGroup.split`` in
    # ``_Catcher.handle_exception`` (i.e. always an exception group), and its
    # return value is discarded.
    _Handler = Callable[[BaseExceptionGroup[Any]], Any]


class _Catcher:
    def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]):
        self._handler_map = handler_map

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        etype: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        if exc is not None:
            unhandled = self.handle_exception(exc)
            if unhandled is exc:
                return False
            elif unhandled is None:
                return True
            else:
                raise unhandled from None

        return False

    def handle_exception(self, exc: BaseException) -> BaseException | None:
        excgroup: BaseExceptionGroup | None
        if isinstance(exc, BaseExceptionGroup):
            excgroup = exc
        else:
            excgroup = BaseExceptionGroup("", [exc])

        new_exceptions: list[BaseException] = []
        for exc_types, handler in self._handler_map.items():
            matched, excgroup = excgroup.split(exc_types)
            if matched:
                try:
                    handler(matched)
                except BaseException as new_exc:
                    new_exceptions.append(new_exc)

            if not excgroup:
                break

        if new_exceptions:
            if len(new_exceptions) == 1:
                return new_exceptions[0]

            if excgroup:
                new_exceptions.append(excgroup)

            return BaseExceptionGroup("", new_exceptions)
        elif (
            excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
        ):
            return exc
        else:
            return excgroup


def catch(
    __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler]
) -> AbstractContextManager[None]:
    if not isinstance(__handlers, Mapping):
        raise TypeError("the argument must be a mapping")

    handler_map: dict[
        tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]]
    ] = {}
    for type_or_iterable, handler in __handlers.items():
        iterable: tuple[type[BaseException]]
        if isinstance(type_or_iterable, type) and issubclass(
            type_or_iterable, BaseException
        ):
            iterable = (type_or_iterable,)
        elif isinstance(type_or_iterable, Iterable):
            iterable = tuple(type_or_iterable)
        else:
            raise TypeError(
                "each key must be either an exception classes or an iterable thereof"
            )

        if not callable(handler):
            raise TypeError("handlers must be callable")

        for exc_type in iterable:
            if not isinstance(exc_type, type) or not issubclass(
                exc_type, BaseException
            ):
                raise TypeError(
                    "each key must be either an exception classes or an iterable "
                    "thereof"
                )

            if issubclass(exc_type, BaseExceptionGroup):
                raise TypeError(
                    "catching ExceptionGroup with catch() is not allowed. "
                    "Use except instead."
                )

        handler_map[iterable] = handler

    return _Catcher(handler_map)