markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from captum.attr import (
Deconvolution,
DeepLift,
FeatureAblation,
GuidedBackprop,
InputXGradient,
IntegratedGradients,
Occlusion,
Saliency,
)
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
class NumberConfig(NamedTuple):
    """Config record for an integer-valued argument.

    Fields:
        value: The integer carried by this record; defaults to ``1``.
        limit: Two-element tuple of optional ints, defaulting to
            ``(None, None)``. Presumably ``(lower, upper)`` bounds for
            ``value`` -- confirm against the consumer of this record.
        type: Tag naming this kind of config; defaults to ``"number"``.
    """

    value: int = 1
    limit: Tuple[Optional[int], Optional[int]] = (None, None)
    type: str = "number"
class StrEnumConfig(NamedTuple):
    """Config record for a string argument paired with a list of strings.

    Fields:
        value: The string carried by this record (required).
        limit: List of strings (required). Presumably the set of choices
            ``value`` may take -- confirm against the consumer.
        type: Tag naming this kind of config; defaults to ``"enum"``.
    """

    value: str
    limit: List[str]
    type: str = "enum"
class StrConfig(NamedTuple):
    """Config record for a free-form string argument.

    Fields:
        value: The string carried by this record (required).
        type: Tag naming this kind of config; defaults to ``"string"``.
    """

    value: str
    type: str = "string"
# Type alias: any one of the per-argument config records.
Config = Union[
    NumberConfig,
    StrEnumConfig,
    StrConfig,
]
# Attribution classes made available by this module. The order is kept as-is
# because ATTRIBUTION_NAMES_TO_METHODS is built by iterating this list.
SUPPORTED_ATTRIBUTION_METHODS = [
    Deconvolution,
    DeepLift,
    GuidedBackprop,
    InputXGradient,
    IntegratedGradients,
    Saliency,
    FeatureAblation,
    Occlusion,
]
class ConfigParameters(NamedTuple):
    """Bundle of configurable arguments for one attribution method.

    Fields:
        params: Maps an argument name to its ``Config`` record (required).
        help_info: Optional help text; defaults to ``None``.
        post_process: Optional map from an argument name to a one-argument
            callable (for example ``int``); defaults to ``None``. How and
            when the callable is applied is decided by the consumer of
            this record.
    """

    params: Dict[str, Config]
    help_info: Optional[str] = None  # TODO fill out help for each method
    post_process: Optional[Dict[str, Callable[[Any], Any]]] = None
# Lookup table from each supported class's ``get_name()`` result to the class.
ATTRIBUTION_NAMES_TO_METHODS = {
    # mypy bug - treating it as a type instead of a class
    method_cls.get_name(): method_cls  # type: ignore
    for method_cls in SUPPORTED_ATTRIBUTION_METHODS
}
def _str_to_tuple(s):
if isinstance(s, tuple):
return s
return tuple([int(i) for i in s.split()])
# Per-method argument configuration, keyed by each class's ``get_name()``.
# Only methods that expose configurable arguments have an entry here.
# ``post_process`` pairs an argument name with a converter callable; every
# NumberConfig argument below is paired with ``int``.
ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
    IntegratedGradients.get_name(): ConfigParameters(
        params={
            "n_steps": NumberConfig(value=25, limit=(2, None)),
            "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
        },
        post_process={"n_steps": int},
    ),
    FeatureAblation.get_name(): ConfigParameters(
        params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
        # Same converter as the identical argument under Occlusion, so the
        # numeric argument is handled consistently across methods.
        post_process={"perturbations_per_eval": int},
    ),
    Occlusion.get_name(): ConfigParameters(
        params={
            "sliding_window_shapes": StrConfig(value=""),
            "strides": StrConfig(value=""),
            "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
        },
        post_process={
            # Text fields are parsed into integer tuples by _str_to_tuple.
            "sliding_window_shapes": _str_to_tuple,
            "strides": _str_to_tuple,
            "perturbations_per_eval": int,
        },
    ),
}