Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union | |
from captum.attr import ( | |
Deconvolution, | |
DeepLift, | |
FeatureAblation, | |
GuidedBackprop, | |
InputXGradient, | |
IntegratedGradients, | |
Occlusion, | |
Saliency, | |
) | |
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS | |
class NumberConfig(NamedTuple): | |
value: int = 1 | |
limit: Tuple[Optional[int], Optional[int]] = (None, None) | |
type: str = "number" | |
class StrEnumConfig(NamedTuple): | |
value: str | |
limit: List[str] | |
type: str = "enum" | |
class StrConfig(NamedTuple): | |
value: str | |
type: str = "string" | |
Config = Union[NumberConfig, StrEnumConfig, StrConfig] | |
SUPPORTED_ATTRIBUTION_METHODS = [ | |
Deconvolution, | |
DeepLift, | |
GuidedBackprop, | |
InputXGradient, | |
IntegratedGradients, | |
Saliency, | |
FeatureAblation, | |
Occlusion, | |
] | |
class ConfigParameters(NamedTuple): | |
params: Dict[str, Config] | |
help_info: Optional[str] = None # TODO fill out help for each method | |
post_process: Optional[Dict[str, Callable[[Any], Any]]] = None | |
ATTRIBUTION_NAMES_TO_METHODS = { | |
# mypy bug - treating it as a type instead of a class | |
cls.get_name(): cls # type: ignore | |
for cls in SUPPORTED_ATTRIBUTION_METHODS | |
} | |
def _str_to_tuple(s): | |
if isinstance(s, tuple): | |
return s | |
return tuple([int(i) for i in s.split()]) | |
ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = { | |
IntegratedGradients.get_name(): ConfigParameters( | |
params={ | |
"n_steps": NumberConfig(value=25, limit=(2, None)), | |
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), | |
}, | |
post_process={"n_steps": int}, | |
), | |
FeatureAblation.get_name(): ConfigParameters( | |
params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}, | |
), | |
Occlusion.get_name(): ConfigParameters( | |
params={ | |
"sliding_window_shapes": StrConfig(value=""), | |
"strides": StrConfig(value=""), | |
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)), | |
}, | |
post_process={ | |
"sliding_window_shapes": _str_to_tuple, | |
"strides": _str_to_tuple, | |
"perturbations_per_eval": int, | |
}, | |
), | |
} | |