File size: 2,722 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import React from "react";
import ReactDOM from "react-dom";
import AppBase from "../../src/App";
import * as widgets from "@jupyter-widgets/base";
import * as _ from "lodash";

class Widget extends React.Component {
  constructor(props) {
    super(props);
    this.state = {
      data: [],
      config: {
        classes: [],
        methods: [],
        method_arguments: {},
      },
      loading: false,
      callback: null,
    };
    this.backbone = this.props.backbone;
  }

  componentDidMount() {
    this.backbone.model.on("change:output", this._outputChanged, this);
    this.backbone.model.on(
      "change:attribution",
      this._attributionChanged,
      this
    );
  }

  _outputChanged(model, output, options) {
    if (_.isEmpty(output)) return;
    this.setState({ data: output, loading: false });
  }

  _attributionChanged(model, attribution, options) {
    if (_.isEmpty(attribution)) return;
    const data = Object.assign([], this.state.data);
    const callback = this.state.callback;
    const labelDetails = model.attributes.label_details;
    data[labelDetails.inputIndex][labelDetails.modelIndex] = attribution;
    this.setState({ data: data, callback: null }, () => {
      callback();
    });
  }

  _fetchInit = () => {
    this.setState({
      config: this.backbone.model.get("insights_config"),
    });
  };

  fetchData = (filterConfig) => {
    this.setState({ loading: true }, () => {
      this.backbone.model.save({ config: filterConfig, output: [] });
    });
  };

  onTargetClick = (labelIndex, inputIndex, modelIndex, callback) => {
    this.setState({ callback: callback }, () => {
      this.backbone.model.save({
        label_details: { labelIndex, inputIndex, modelIndex },
        attribution: {},
      });
    });
  };

  render() {
    return (
      <AppBase
        fetchData={this.fetchData}
        fetchInit={this._fetchInit}
        onTargetClick={this.onTargetClick}
        data={this.state.data}
        config={this.state.config}
        loading={this.state.loading}
      />
    );
  }
}

/**
 * Widget model for Captum Insights.
 *
 * The defaults start from a fresh copy of DOMWidgetModel's defaults and add
 * the model/view class names and the module name/version
 * ("jupyter-captum-insights" / "0.1.0") that identify this widget.
 */
const CaptumInsightsModel = widgets.DOMWidgetModel.extend({
  defaults: Object.assign({}, widgets.DOMWidgetModel.prototype.defaults(), {
    _model_name: "CaptumInsightsModel",
    _view_name: "CaptumInsightsView",
    _model_module: "jupyter-captum-insights",
    _view_module: "jupyter-captum-insights",
    _model_module_version: "0.1.0",
    _view_module_version: "0.1.0",
  }),
});

var CaptumInsightsView = widgets.DOMWidgetView.extend({
  initialize() {
    const $app = document.createElement("div");
    ReactDOM.render(<Widget backbone={this} />, $app);
    this.el.append($app);
  },
});

export { Widget as default, CaptumInsightsModel, CaptumInsightsView };