# strexp/lime/wrappers/scikit_image.py
# Repository page header: last commit d61b9c7 "added strexp" (markytools).
import types
from lime.utils.generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift
class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrapper

    Args:
        target_fn: callable function or class instance
        target_params: dict, parameters to pass to the target_fn

    'target_params' takes parameters required to instantiate the
    desired Scikit-Image class/model
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'

        Args :
            parameters: dict, parameters to be checked

        Raises :
            ValueError: if any parameter is not a valid argument for the
                target function
            TypeError: if 'parameters' is a string or is not iterable, or if
                no target function is defined and this object is not
                callable itself
        """
        # Resolve the callable whose signature the parameters are checked
        # against: the wrapper itself when no target_fn was given, the
        # function/method directly, or the __call__ of any other object.
        if self.target_fn is None:
            if not callable(self):
                raise TypeError('invalid argument: tested object is not '
                                'callable, please provide a valid target_fn')
            valid_fn = self.__call__
        elif isinstance(self.target_fn,
                        (types.FunctionType, types.MethodType)):
            valid_fn = self.target_fn
        else:
            valid_fn = self.target_fn.__call__

        # A string is iterable, but iterating it would check single
        # characters, so it is rejected explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for p in parameters:
            if not has_arg(valid_fn, p):
                raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        The given parameters replace the stored 'target_params' entirely.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises :
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and return those in `fn`'s arguments.

        Args:
            fn : arbitrary function
            override: dict, values to override target_params

        Returns:
            result : dict, dictionary containing variables
                in both target_params and fn's arguments, updated with
                every entry of `override` (override entries are not
                filtered against fn's arguments).
        """
        result = {name: value
                  for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override or {})
        return result
class SegmentationAlgorithm(BaseWrapper):
    """ Define the image segmentation function based on Scikit-Image
    implementation and a set of provided parameters

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in Scikit-Image documentation)

    Raises:
        ValueError: if algo_type is not one of the supported algorithms
    """

    def __init__(self, algo_type, **target_params):
        self.algo_type = algo_type
        algorithms = {
            'quickshift': quickshift,
            'felzenszwalb': felzenszwalb,
            'slic': slic,
        }
        # Fail at construction for an unknown name; otherwise the base class
        # is never initialised and the instance has no target_fn or
        # target_params.
        if algo_type not in algorithms:
            raise ValueError(
                '{!r} is not a supported segmentation algorithm, '
                'expected one of: {}'.format(
                    algo_type, ', '.join(sorted(algorithms))))
        target_fn = algorithms[algo_type]
        BaseWrapper.__init__(self, target_fn, **target_params)
        # Keep only the keyword arguments the chosen function accepts.
        kwargs = self.filter_params(target_fn)
        self.set_params(**kwargs)

    def __call__(self, *args):
        """Runs the selected segmentation function.

        Only the first positional argument is forwarded to the target
        function, together with the stored target_params; any further
        positional arguments are ignored.
        """
        return self.target_fn(args[0], **self.target_params)