# strexp/lime/tests/test_scikit_image.py
# Last change: commit d61b9c7 by markytools ("added strexp").

import unittest

import numpy as np
from skimage.data import chelsea
from skimage.segmentation import felzenszwalb
from skimage.segmentation import quickshift
from skimage.segmentation import slic
from skimage.util import img_as_float

from lime.wrappers.scikit_image import BaseWrapper
from lime.wrappers.scikit_image import SegmentationAlgorithm


class TestBaseWrapper(unittest.TestCase):
    """Unit tests for ``BaseWrapper`` (target function + keyword params holder)."""

    def test_base_wrapper(self):
        """The constructor stores ``target_fn`` and the keyword ``target_params``."""
        obj_with_params = BaseWrapper(a=10, b='message')
        obj_without_params = BaseWrapper()

        def foo_fn():
            return 'bar'

        obj_with_fn = BaseWrapper(foo_fn)
        self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'})
        self.assertEqual(obj_without_params.target_params, {})
        self.assertEqual(obj_with_fn.target_fn(), 'bar')

    def test__check_params(self):
        """``_check_params`` validates parameter names against ``target_fn``."""
        def bar_fn(a):
            return str(a)

        class Pipo():
            def __init__(self):
                self.name = 'pipo'

            def __call__(self, message):
                return message

        pipo = Pipo()
        obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message')
        obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message')
        obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid')

        # target_fn is neither a function/method nor a callable object
        with self.assertRaises(AttributeError):
            obj_with_invalid_fn._check_params('fn_name')

        # parameter is not in target_fn's signature. Each call needs its own
        # context: the first raise would otherwise skip the second call.
        with self.assertRaises(ValueError):
            obj_with_valid_fn._check_params(['c'])
        with self.assertRaises(ValueError):
            obj_with_valid_callable_fn._check_params(['e'])

        # parameter is in target_fn's signature: no exception expected
        try:
            obj_with_valid_fn._check_params(['a'])
            obj_with_valid_callable_fn._check_params(['message'])
        except Exception as exc:
            self.fail(
                "_check_params() raised an unexpected exception: {!r}".format(exc))

        # params is neither a dict nor a list
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params(None)
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params('param_name')

    def test_set_params(self):
        """``set_params`` stores valid params and rejects unknown ones."""
        class Pipo():
            def __init__(self):
                self.name = 'pipo'

            def __call__(self, message):
                return message

        pipo = Pipo()
        obj = BaseWrapper(pipo)

        # argument is set accordingly
        obj.set_params(message='OK')
        self.assertEqual(obj.target_params, {'message': 'OK'})
        self.assertEqual(obj.target_fn(**obj.target_params), 'OK')

        # invalid argument: must raise and leave target_params untouched.
        # The state check sits outside the raise-context so that a missing
        # exception makes the test fail instead of passing silently.
        obj = BaseWrapper(Pipo())
        with self.assertRaises(ValueError):
            obj.set_params(invalid='KO')
        self.assertEqual(obj.target_params, {})

    def test_filter_params(self):
        """``filter_params`` keeps only names accepted by ``fn``, plus overrides."""
        # valid arguments are kept and invalid ones dismissed
        def baz_fn(a, b, c=True):
            if c:
                return a + b
            else:
                return a

        obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000)
        self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100})

        # target_params is overridden using the 'override' argument
        self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}),
                         {'a': 10, 'b': 100, 'c': False})


class TestSegmentationAlgorithm(unittest.TestCase):
    """``SegmentationAlgorithm`` must reproduce the wrapped skimage function."""

    def setUp(self):
        # Sample RGB image, subsampled by 2 on both spatial axes to keep the
        # segmentation calls cheap.
        self.img = img_as_float(chelsea()[::2, ::2])

    def _assert_same_segments(self, algo_type, segmentation_fn, **params):
        """Assert the wrapper and a direct ``segmentation_fn`` call agree.

        Args:
            algo_type: name passed to ``SegmentationAlgorithm``.
            segmentation_fn: the skimage function the wrapper should call.
            **params: keyword arguments given to both the wrapper and the
                direct call.
        """
        wrapped = SegmentationAlgorithm(algo_type, **params)
        # same segments; assert_array_equal reports the mismatch on failure
        np.testing.assert_array_equal(wrapped(self.img),
                                      segmentation_fn(self.img, **params))

    def test_instanciate_segmentation_algorithm(self):
        self._assert_same_segments('quickshift', quickshift, kernel_size=3,
                                   max_dist=6, ratio=0.5, random_seed=133)

    def test_instanciate_slic(self):
        self._assert_same_segments('slic', slic, n_segments=100,
                                   compactness=10, sigma=1)

    def test_instanciate_felzenszwalb(self):
        self._assert_same_segments('felzenszwalb', felzenszwalb, scale=100,
                                   sigma=0.5, min_size=50)


if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined above.
    unittest.main()