Spaces:
Build error
Build error
File size: 4,058 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import unittest
from lime.wrappers.scikit_image import BaseWrapper
from lime.wrappers.scikit_image import SegmentationAlgorithm
from skimage.segmentation import quickshift
from skimage.data import chelsea
from skimage.util import img_as_float
import numpy as np
class TestBaseWrapper(unittest.TestCase):
    """Unit tests for ``lime.wrappers.scikit_image.BaseWrapper``."""

    def test_base_wrapper(self):
        """The constructor stores the target function and keyword params."""
        obj_with_params = BaseWrapper(a=10, b='message')
        obj_without_params = BaseWrapper()

        def foo_fn():
            return 'bar'

        obj_with_fn = BaseWrapper(foo_fn)
        self.assertEqual(obj_with_params.target_params,
                         {'a': 10, 'b': 'message'})
        self.assertEqual(obj_without_params.target_params, {})
        self.assertEqual(obj_with_fn.target_fn(), 'bar')

    def test__check_params(self):
        """_check_params accepts names target_fn takes and rejects others."""
        def bar_fn(a):
            return str(a)

        class Pipo():
            def __init__(self):
                self.name = 'pipo'

            def __call__(self, message):
                return message

        pipo = Pipo()
        obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message')
        obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message')
        obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid')

        # target_fn is not a callable or function/method
        with self.assertRaises(AttributeError):
            obj_with_invalid_fn._check_params('fn_name')

        # Parameters that are not in target_fn's signature are rejected.
        # Each call gets its own context manager: inside a single ``with``
        # the first raise would skip the second call entirely.
        with self.assertRaises(ValueError):
            obj_with_valid_fn._check_params(['c'])
        with self.assertRaises(ValueError):
            obj_with_valid_callable_fn._check_params(['e'])

        # Parameters that are in target_fn's signature are accepted.
        try:
            obj_with_valid_fn._check_params(['a'])
            obj_with_valid_callable_fn._check_params(['message'])
        except Exception as exc:
            self.fail("_check_params() raised an unexpected exception: "
                      "{!r}".format(exc))

        # params is not a dict or list
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params(None)
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params('param_name')

    def test_set_params(self):
        """set_params stores valid params and rejects unknown ones."""
        class Pipo():
            def __init__(self):
                self.name = 'pipo'

            def __call__(self, message):
                return message

        pipo = Pipo()
        obj = BaseWrapper(pipo)

        # argument is set accordingly
        obj.set_params(message='OK')
        self.assertEqual(obj.target_params, {'message': 'OK'})
        self.assertEqual(obj.target_fn(**obj.target_params), 'OK')

        # An invalid argument must be rejected, leaving target_params
        # untouched. Asserting only inside an ``except`` branch would let the
        # test pass silently if set_params accepted the bad argument.
        obj = BaseWrapper(Pipo())
        with self.assertRaises(ValueError):
            obj.set_params(invalid='KO')
        self.assertEqual(obj.target_params, {})

    def test_filter_params(self):
        """filter_params keeps only names fn accepts and applies overrides."""
        # right arguments are kept and wrong ones dismissed
        def baz_fn(a, b, c=True):
            if c:
                return a + b
            else:
                return a

        obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000)
        self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100})

        # target_params is overridden using the 'override' argument
        self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}),
                         {'a': 10, 'b': 100, 'c': False})
class TestSegmentationAlgorithm(unittest.TestCase):
    """SegmentationAlgorithm must reproduce the skimage function it wraps."""

    def _assert_same_segments(self, algo_type, algo_fn, **params):
        """Assert SegmentationAlgorithm(algo_type) matches algo_fn directly.

        Both are called on the same downsampled ``chelsea`` image with the
        same keyword ``params``; the resulting label arrays must be equal.
        """
        img = img_as_float(chelsea()[::2, ::2])
        wrapped = SegmentationAlgorithm(algo_type, **params)
        # assert_array_equal reports the mismatching elements on failure,
        # unlike assertTrue(np.array_equal(...)).
        np.testing.assert_array_equal(wrapped(img), algo_fn(img, **params))

    def test_instanciate_segmentation_algorithm(self):
        """Wrapped quickshift gives the same segments as quickshift."""
        self._assert_same_segments('quickshift', quickshift,
                                   kernel_size=3, max_dist=6, ratio=0.5,
                                   random_seed=133)

    def test_instanciate_slic(self):
        """Wrapped slic gives the same segments as slic."""
        # Previously a bare ``pass``, which always reported success.
        from skimage.segmentation import slic
        self._assert_same_segments('slic', slic,
                                   n_segments=50, compactness=10, sigma=1)

    def test_instanciate_felzenszwalb(self):
        """Wrapped felzenszwalb gives the same segments as felzenszwalb."""
        # Previously a bare ``pass``, which always reported success.
        from skimage.segmentation import felzenszwalb
        self._assert_same_segments('felzenszwalb', felzenszwalb,
                                   scale=100, sigma=0.5, min_size=50)
if __name__ == '__main__':
    # Allow running this file directly as a script; unittest's CLI runner
    # discovers and runs every TestCase defined in this module.
    unittest.main()
|