import types

from lime.utils.generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift


class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrapper

    Args:
        target_fn: callable function or class instance
        target_params: dict, parameters to pass to the target_fn

    'target_params' takes parameters required to instantiate the
    desired Scikit-Image class/model
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'

        Args:
            parameters: iterable of parameter names (typically a dict
                keyed by parameter name) to be checked

        Raises:
            TypeError: if no target function is set and this object is
                not callable itself, or if 'parameters' is a string
            ValueError: if any parameter is not a valid argument of the
                target function
        """
        # Resolve the callable whose signature the names are checked against.
        if self.target_fn is None:
            # No explicit target: fall back to the wrapper's own __call__.
            if not callable(self):
                raise TypeError('invalid argument: tested object is not '
                                'callable, please provide a valid target_fn')
            fn = self.__call__
        elif isinstance(self.target_fn,
                        (types.FunctionType, types.MethodType)):
            fn = self.target_fn
        else:
            # Any other object (e.g. a class instance) is inspected
            # through its __call__ attribute.
            fn = self.target_fn.__call__

        # A string is iterable, but iterating it would check single
        # characters, so it is rejected explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for p in parameters:
            if not has_arg(fn, p):
                raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        The given parameters replace the stored 'target_params' entirely;
        they are not merged with the previous ones.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and return those in `fn`'s arguments.

        Args:
            fn: arbitrary function
            override: dict, values to override target_params. These are
                added to the result as-is, without being checked
                against `fn`'s arguments.

        Returns:
            result: dict, dictionary containing variables
                in both target_params and fn's arguments, updated with
                the entries of 'override'.
        """
        result = {name: value
                  for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override or {})
        return result


class SegmentationAlgorithm(BaseWrapper):
    """Define the image segmentation function based on Scikit-Image
    implementation and a set of provided parameters

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in Scikit-Image documentation)

    Raises:
        ValueError: if 'algo_type' is not one of the supported algorithms
    """

    def __init__(self, algo_type, **target_params):
        self.algo_type = algo_type

        # Built here rather than at class level so the segmentation
        # functions are looked up when the instance is created.
        algorithms = {
            'quickshift': quickshift,
            'felzenszwalb': felzenszwalb,
            'slic': slic,
        }
        try:
            segmentation_fn = algorithms[algo_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values passed as algo_type.
            # Failing here avoids a half-initialised instance that would
            # only break later, when called.
            raise ValueError(
                'unknown algo_type {!r}, expected one of: {}'.format(
                    algo_type, ', '.join(sorted(algorithms))))

        BaseWrapper.__init__(self, segmentation_fn, **target_params)
        # Keep only the parameters the chosen function accepts.
        self.set_params(**self.filter_params(segmentation_fn))

    def __call__(self, *args):
        """Runs the segmentation function on the first positional argument.

        Only args[0] is forwarded to the target function, together with
        the stored 'target_params'; further positional arguments are
        ignored.
        """
        return self.target_fn(args[0], **self.target_params)