File size: 4,058 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import unittest

import numpy as np
from skimage.data import chelsea
from skimage.segmentation import felzenszwalb
from skimage.segmentation import quickshift
from skimage.segmentation import slic
from skimage.util import img_as_float

from lime.wrappers.scikit_image import BaseWrapper
from lime.wrappers.scikit_image import SegmentationAlgorithm


class _Pipo:
    """Callable test double: an object (not a plain function) whose
    ``__call__`` takes a single ``message`` argument and returns it."""

    def __init__(self):
        self.name = 'pipo'

    def __call__(self, message):
        return message


class TestBaseWrapper(unittest.TestCase):
    """Tests for how ``BaseWrapper`` stores, validates and filters the
    keyword parameters meant for its target function."""

    def test_base_wrapper(self):
        """Keyword arguments land in ``target_params`` and the positional
        argument becomes ``target_fn``."""

        def foo_fn():
            return 'bar'

        obj_with_params = BaseWrapper(a=10, b='message')
        obj_without_params = BaseWrapper()
        obj_with_fn = BaseWrapper(foo_fn)

        self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'})
        self.assertEqual(obj_without_params.target_params, {})
        self.assertEqual(obj_with_fn.target_fn(), 'bar')

    def test__check_params(self):
        """``_check_params`` accepts names that ``target_fn`` takes and
        rejects unknown names and non-list/dict inputs."""

        def bar_fn(a):
            return str(a)

        obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message')
        obj_with_valid_callable_fn = BaseWrapper(_Pipo(), c=10, d='message')
        obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid')

        # target_fn is neither a function/method nor a callable object
        with self.assertRaises(AttributeError):
            obj_with_invalid_fn._check_params('fn_name')

        # A parameter that target_fn does not accept is rejected, both for a
        # plain function and for a callable object. Each call needs its own
        # context: inside a shared one, the first raise would skip the second.
        with self.assertRaises(ValueError):
            obj_with_valid_fn._check_params(['c'])
        with self.assertRaises(ValueError):
            obj_with_valid_callable_fn._check_params(['e'])

        # Parameters that target_fn accepts pass; any exception raised here
        # makes the test fail with its full traceback.
        obj_with_valid_fn._check_params(['a'])
        obj_with_valid_callable_fn._check_params(['message'])

        # params must be a dict or list of names, not None or a bare string
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params(None)
        with self.assertRaises(TypeError):
            obj_with_valid_fn._check_params('param_name')

    def test_set_params(self):
        """``set_params`` stores valid arguments and rejects invalid ones
        without modifying ``target_params``."""
        obj = BaseWrapper(_Pipo())

        # A valid argument is stored and can be forwarded to target_fn.
        obj.set_params(message='OK')
        self.assertEqual(obj.target_params, {'message': 'OK'})
        self.assertEqual(obj.target_fn(**obj.target_params), 'OK')

        # An invalid argument must raise. Previously the try/except passed
        # silently when nothing was raised; now the raise is asserted, and
        # so is the state left behind.
        obj = BaseWrapper(_Pipo())
        with self.assertRaises(ValueError):
            obj.set_params(invalid='KO')
        self.assertEqual(obj.target_params, {})

    def test_filter_params(self):
        """``filter_params`` keeps only the stored params that ``fn``
        accepts, then applies ``override`` on top."""

        def baz_fn(a, b, c=True):
            if c:
                return a + b
            return a

        obj = BaseWrapper(baz_fn, a=10, b=100, d=1000)

        # right arguments are kept and wrong ones dismissed
        self.assertEqual(obj.filter_params(baz_fn), {'a': 10, 'b': 100})

        # target_params is overridden using the 'override' argument
        self.assertEqual(obj.filter_params(baz_fn, override={'c': False}),
                         {'a': 10, 'b': 100, 'c': False})


class TestSegmentationAlgorithm(unittest.TestCase):
    """Check that ``SegmentationAlgorithm`` gives exactly the same labels as
    calling the wrapped scikit-image segmentation function directly."""

    def setUp(self):
        # Take every second row and column of the sample image so each
        # segmentation run stays fast.
        self.img = img_as_float(chelsea()[::2, ::2])

    def _assert_same_segments(self, algo_type, segment_fn, **params):
        """Assert that the wrapper for ``algo_type`` matches ``segment_fn``.

        Builds ``SegmentationAlgorithm(algo_type, **params)``, runs it on the
        test image, and compares the result element-wise with
        ``segment_fn(img, **params)``.
        """
        wrapped = SegmentationAlgorithm(algo_type, **params)
        # assert_array_equal reports the differing elements on failure,
        # unlike assertTrue(np.array_equal(...)).
        np.testing.assert_array_equal(wrapped(self.img),
                                      segment_fn(self.img, **params))

    def test_instanciate_segmentation_algorithm(self):
        # wrapped quickshift provides the same segments
        self._assert_same_segments('quickshift', quickshift, kernel_size=3,
                                   max_dist=6, ratio=0.5, random_seed=133)

    def test_instanciate_slic(self):
        # wrapped slic provides the same segments (previously an empty
        # placeholder that always passed)
        self._assert_same_segments('slic', slic, n_segments=50,
                                   compactness=10, sigma=1)

    def test_instanciate_felzenszwalb(self):
        # wrapped felzenszwalb provides the same segments (previously an
        # empty placeholder that always passed)
        self._assert_same_segments('felzenszwalb', felzenszwalb, scale=100,
                                   sigma=0.5, min_size=50)


if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly.
    unittest.main()