Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import ipywidgets as widgets | |
from captum.insights import AttributionVisualizer | |
from captum.insights.attr_vis.server import namedtuple_to_dict | |
from traitlets import Dict, Instance, List, observe, Unicode | |
class CaptumInsights(widgets.DOMWidget):
    """A widget for interacting with Captum Insights.

    The widget is driven through its synced traits: writing a non-empty
    value to ``config`` makes the kernel side run the visualizer and
    publish the result in ``output``; writing a non-empty value to
    ``label_details`` makes it compute an attribution and publish it in
    ``attribution``.

    A ``visualizer`` keyword argument must be supplied at construction
    time, because ``__init__`` reads it immediately to fill
    ``insights_config``.
    """

    # Names and versions of the front-end view/model classes that pair with
    # this Python-side widget.
    _view_name = Unicode("CaptumInsightsView").tag(sync=True)
    _model_name = Unicode("CaptumInsightsModel").tag(sync=True)
    _view_module = Unicode("jupyter-captum-insights").tag(sync=True)
    _model_module = Unicode("jupyter-captum-insights").tag(sync=True)
    _view_module_version = Unicode("^0.1.0").tag(sync=True)
    _model_module_version = Unicode("^0.1.0").tag(sync=True)

    # Kernel-side only (not tagged for sync): the object that does the work.
    visualizer = Instance(klass=AttributionVisualizer)

    # Synced state shared with the front end.
    insights_config = Dict().tag(sync=True)  # filled once in __init__
    label_details = Dict().tag(sync=True)  # request: which attribution to compute
    attribution = Dict().tag(sync=True)  # response to a label_details request
    config = Dict().tag(sync=True)  # request: visualizer configuration
    output = List().tag(sync=True)  # response to a config request

    def __init__(self, **kwargs) -> None:
        """Create the widget and publish the visualizer's insights config.

        Args:
            **kwargs: Trait values forwarded to ``DOMWidget``; expected to
                include ``visualizer``.
        """
        super().__init__(**kwargs)
        self.insights_config = self.visualizer.get_insights_config()
        # Handler work runs inside this Output widget's context so that
        # anything printed there is captured by it.
        self.out = widgets.Output()
        with self.out:
            print("Captum Insights widget created.")

    @observe("config")
    def _fetch_data(self, change) -> None:
        """Run the visualizer whenever ``config`` receives a non-empty value.

        Args:
            change: The traitlets change notification. Unused; the current
                trait value is read from ``self.config`` instead.
        """
        # An empty config is either the initial state or the reset performed
        # at the end of this handler (which re-fires the observer).
        if not self.config:
            return
        with self.out:
            self.visualizer._update_config(self.config)
            self.output = namedtuple_to_dict(self.visualizer.visualize())
            # Clear the request; presumably so that an identical follow-up
            # request is still seen as a change -- TODO confirm with front end.
            self.config = {}

    @observe("label_details")
    def _fetch_attribution(self, change) -> None:
        """Compute an attribution whenever ``label_details`` becomes non-empty.

        ``label_details`` must provide the keys ``"inputIndex"``,
        ``"modelIndex"`` and ``"labelIndex"``.

        Args:
            change: The traitlets change notification. Unused; the current
                trait value is read from ``self.label_details`` instead.
        """
        # Same guard as in _fetch_data: ignore the initial/reset empty value.
        if not self.label_details:
            return
        with self.out:
            details = self.label_details
            self.attribution = namedtuple_to_dict(
                self.visualizer._calculate_attribution_from_cache(
                    details["inputIndex"],
                    details["modelIndex"],
                    details["labelIndex"],
                )
            )
            # Clear the request once it has been served.
            self.label_details = {}