Spaces:
Build error
Build error
import types | |
from lime.utils.generic_utils import has_arg | |
from skimage.segmentation import felzenszwalb, slic, quickshift | |
class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrappers.

    Args:
        target_fn: callable function or class instance to wrap. When None,
            the wrapper object itself is used as the target and must be
            callable (i.e. a subclass must define ``__call__``).
        **target_params: keyword arguments stored for ``target_fn``, e.g. the
            parameters required to instantiate the desired Scikit-Image
            class/model.
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks that every name in `parameters` is accepted by the target.

        The target checked against is ``target_fn`` itself when it is a plain
        function or method, ``target_fn.__call__`` for any other object, and
        this wrapper's own ``__call__`` when ``target_fn`` is None.

        Args:
            parameters: dict (its keys are checked) or other non-string
                iterable of parameter names.

        Raises:
            ValueError: if any parameter is not a valid argument for the
                target function.
            TypeError: if ``target_fn`` is None and this object is not
                callable, if `parameters` is a string, or if `parameters`
                is not iterable.
        """
        if self.target_fn is None:
            if not callable(self):
                raise TypeError('invalid argument: tested object is not '
                                'callable, please provide a valid target_fn')
            target = self.__call__
        elif isinstance(self.target_fn,
                        (types.FunctionType, types.MethodType)):
            target = self.target_fn
        else:
            target = self.target_fn.__call__

        # A str is iterable, but iterating it would check single characters
        # instead of parameter names, so reject it explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')
        for name in parameters:
            if not has_arg(target, name):
                raise ValueError('{} is not a valid parameter'.format(name))

    def set_params(self, **params):
        """Validates `params` and replaces the stored target parameters.

        Any previously stored parameter that is not present in `params` is
        discarded; this is a replacement, not an update.

        Args:
            **params: parameter names mapped to their values.

        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function.
            TypeError: if no callable target is available.
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and returns those in `fn`'s arguments.

        Args:
            fn: arbitrary function.
            override: dict, values added on top of the filtered parameters;
                they take precedence and are not filtered against `fn`.

        Returns:
            result: dict, the entries of `target_params` whose names are
                arguments of `fn`, updated with `override`.
        """
        result = {name: value
                  for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override or {})
        return result
class SegmentationAlgorithm(BaseWrapper):
    """Image segmentation function based on a Scikit-Image implementation
    and a set of provided parameters.

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in the Scikit-Image documentation). Keyword arguments
            not accepted by the chosen algorithm are dropped.

    Raises:
        ValueError: if `algo_type` is not one of the supported algorithms.
    """

    def __init__(self, algo_type, **target_params):
        # Built per call so the scikit-image names are only looked up at
        # construction time, not when the class is defined.
        algorithms = {
            'quickshift': quickshift,
            'felzenszwalb': felzenszwalb,
            'slic': slic,
        }
        self.algo_type = algo_type
        if algo_type not in algorithms:
            # Without this check the base class is never initialised and the
            # instance fails later with an obscure AttributeError.
            raise ValueError(
                "unsupported algo_type {!r}: expected one of "
                "'quickshift', 'slic', 'felzenszwalb'".format(algo_type))
        target_fn = algorithms[algo_type]
        BaseWrapper.__init__(self, target_fn, **target_params)
        # Keep only the keyword arguments the chosen algorithm accepts, so
        # that __call__ can forward target_params verbatim.
        kwargs = self.filter_params(target_fn)
        self.set_params(**kwargs)

    def __call__(self, *args):
        """Segments ``args[0]`` with the configured algorithm.

        Only the first positional argument is forwarded; it is passed to the
        algorithm together with the stored ``target_params``.
        """
        return self.target_fn(args[0], **self.target_params)