File size: 2,057 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3
import ipywidgets as widgets
from captum.insights import AttributionVisualizer
from captum.insights.attr_vis.server import namedtuple_to_dict
from traitlets import Dict, Instance, List, observe, Unicode


@widgets.register
class CaptumInsights(widgets.DOMWidget):
    """A widget for interacting with Captum Insights.

    Communication with the front end happens through synced traits.

    Request traits:
        ``config``: a write runs ``_fetch_data``, which publishes its result
            on ``output``.
        ``label_details``: a write runs ``_fetch_attribution``, which
            publishes its result on ``attribution``.

    Each handler asks ``visualizer`` for the data and then resets its
    request trait to an empty dict. The reset allows the same request to be
    issued again.
    """

    # Identifiers/versions of the front-end (JavaScript) view and model.
    _view_name = Unicode("CaptumInsightsView").tag(sync=True)
    _model_name = Unicode("CaptumInsightsModel").tag(sync=True)
    _view_module = Unicode("jupyter-captum-insights").tag(sync=True)
    _model_module = Unicode("jupyter-captum-insights").tag(sync=True)
    _view_module_version = Unicode("^0.1.0").tag(sync=True)
    _model_module_version = Unicode("^0.1.0").tag(sync=True)

    # Backend that does the actual work; not synced to the front end.
    visualizer = Instance(klass=AttributionVisualizer)

    # Filled from ``visualizer.get_insights_config()`` in ``__init__``.
    insights_config = Dict().tag(sync=True)
    # Attribution request; read keys: "inputIndex", "modelIndex", "labelIndex".
    label_details = Dict().tag(sync=True)
    # Attribution response written by ``_fetch_attribution``.
    attribution = Dict().tag(sync=True)
    # Visualization request passed to ``visualizer._update_config``.
    config = Dict().tag(sync=True)
    # Visualization response written by ``_fetch_data``.
    output = List().tag(sync=True)

    def __init__(self, **kwargs) -> None:
        """Create the widget.

        Args:
            **kwargs: Trait values forwarded to ``DOMWidget``. ``visualizer``
                must be provided here: it is used immediately below to
                populate ``insights_config``.
        """
        super().__init__(**kwargs)
        self.insights_config = self.visualizer.get_insights_config()
        # Output area capturing prints (and tracebacks) produced by the
        # observers below, which all run inside ``with self.out``.
        self.out = widgets.Output()
        with self.out:
            print("Captum Insights widget created.")

    @observe("config")
    def _fetch_data(self, change):
        """Recompute the visualization when a new ``config`` arrives.

        Args:
            change: traitlets change notification (unused; the current
                value is read from ``self.config``).
        """
        # Resetting ``config`` below re-fires this observer with an empty
        # dict; that (like any empty request) is ignored here.
        if not self.config:
            return
        with self.out:
            try:
                self.visualizer._update_config(self.config)
                self.output = namedtuple_to_dict(self.visualizer.visualize())
            finally:
                # Always clear the request, even if visualization failed.
                # traitlets skips notification when a trait is set to an
                # equal value. A stale ``config`` would therefore swallow a
                # retry of the identical request.
                self.config = dict()

    @observe("label_details")
    def _fetch_attribution(self, change):
        """Compute attributions when a new ``label_details`` request arrives.

        Args:
            change: traitlets change notification (unused; the current
                value is read from ``self.label_details``).
        """
        # Resetting ``label_details`` below re-fires this observer with an
        # empty dict; that (like any empty request) is ignored here.
        if not self.label_details:
            return
        with self.out:
            try:
                self.attribution = namedtuple_to_dict(
                    self.visualizer._calculate_attribution_from_cache(
                        self.label_details["inputIndex"],
                        self.label_details["modelIndex"],
                        self.label_details["labelIndex"],
                    )
                )
            finally:
                # Always clear the request so an identical retry after a
                # failure still changes the trait and triggers this observer.
                self.label_details = dict()