#!/usr/bin/env python3
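"""Flask server that powers the Captum Insights web UI."""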
import logging
import os
import socket
import threading
from time import sleep
from typing import Optional

from captum.log import log_usage
from flask import Flask, jsonify, render_template, request
from flask_compress import Compress
from torch import Tensor

app = Flask(
    __name__, static_folder="frontend/build/static", template_folder="frontend/build"
)
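# Module-level state shared between the Flask routes and start_server()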
visualizer = None
port = None
Compress(app)


def namedtuple_to_dict(obj):
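    """Recursively convert namedtuples and nested containers into
    JSON-serializable equivalents; scalar tensors become Python numbers."""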
    if isinstance(obj, Tensor):
        return obj.item()
    if hasattr(obj, "_asdict"):  # detect namedtuple
        return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj)))
    elif isinstance(obj, str):  # iterables - strings
        return obj
    elif hasattr(obj, "keys"):  # iterables - mapping
        return dict(
            zip(obj.keys(), (namedtuple_to_dict(item) for item in obj.values()))
        )
    elif hasattr(obj, "__iter__"):  # iterables - sequence
        return type(obj)((namedtuple_to_dict(item) for item in obj))
    else:  # non-iterable cannot contain namedtuples
        return obj


@app.route("/attribute", methods=["POST"])
def attribute():
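    """Compute attributions for the requested cached sample and return them as JSON."""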
    # force=True is needed for Colab notebooks, which don't set the correct
    # Content-Type header when forwarding requests through the Colab proxy
    r = request.get_json(force=True)
    return jsonify(
        namedtuple_to_dict(
            visualizer._calculate_attribution_from_cache(
                r["inputIndex"], r["modelIndex"], r["labelIndex"]
            )
        )
    )


@app.route("/fetch", methods=["POST"])
def fetch():
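    """Apply the posted config to the visualizer and return its visualizations."""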
    # force=True needed, see comment for "/attribute" route above
    visualizer._update_config(request.get_json(force=True))
    visualizer_output = visualizer.visualize()
    clean_output = namedtuple_to_dict(visualizer_output)
    return jsonify(clean_output)


@app.route("/init")
def init():
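    """Return the visualizer's configuration for initializing the frontend."""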
    return jsonify(visualizer.get_insights_config())


@app.route("/")
def index():
    return render_template("index.html")


def get_free_tcp_port():
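    """Ask the OS for a free TCP port by binding to port 0, then release it."""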
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        _, port = tcp.getsockname()
    return port


def run_app(debug: bool = True, bind_all: bool = False):
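    """Run the Flask app on the module-level port, optionally on all interfaces."""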
    if bind_all:
        app.run(port=port, use_reloader=False, debug=debug, host="0.0.0.0")
    else:
        app.run(port=port, use_reloader=False, debug=debug)


@log_usage()
def start_server(
    _viz,
    blocking: bool = False,
    debug: bool = False,
    _port: Optional[int] = None,
    bind_all: bool = False,
):
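    """Launch the Insights server on a background thread and return its port.

    The server is only started on the first call; subsequent calls reuse the
    existing port. With blocking=True the calling thread joins the server
    thread instead of returning immediately.
    """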
    global visualizer
    visualizer = _viz

    global port
    if port is None:
        os.environ["WERKZEUG_RUN_MAIN"] = "true"  # hides starting message
        if not debug:
            log = logging.getLogger("werkzeug")
            log.disabled = True
            app.logger.disabled = True

        port = _port or get_free_tcp_port()
        # Start in a new thread to not block notebook execution
        t = threading.Thread(
            target=run_app, kwargs={"debug": debug, "bind_all": bind_all}
        )
        t.start()
        sleep(0.01)  # add a short delay to allow server to start up
        if blocking:
            t.join()

    print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n")
    return port
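

# Example usage (a hedged sketch, not part of the module): in a notebook,
# an AttributionVisualizer from captum.insights would typically be passed
# in like this; its constructor arguments are elided here.
#
#     from captum.insights import AttributionVisualizer
#
#     visualizer = AttributionVisualizer(...)  # models, features, dataset, ...
#     port = start_server(visualizer, blocking=False, debug=False)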