strexp / captum /_utils /typing.py
markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
from torch import Tensor
from torch.nn import Module
if TYPE_CHECKING:
import sys
if sys.version_info >= (3, 8):
from typing import Literal # noqa: F401
else:
from typing_extensions import Literal # noqa: F401
else:
Literal = {True: bool, False: bool, (True, False): bool}
TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
TensorLikeList1D = List[float]
TensorLikeList2D = List[TensorLikeList1D]
TensorLikeList3D = List[TensorLikeList2D]
TensorLikeList4D = List[TensorLikeList3D]
TensorLikeList5D = List[TensorLikeList4D]
TensorLikeList = Union[
TensorLikeList1D,
TensorLikeList2D,
TensorLikeList3D,
TensorLikeList4D,
TensorLikeList5D,
]