#!/usr/bin/env python3
import logging
import os
import socket
import threading
from time import sleep
from typing import Optional
from captum.log import log_usage
from flask import Flask, jsonify, render_template, request
from flask_compress import Compress
from torch import Tensor
# Flask app that serves the pre-built frontend assets out of frontend/build.
app = Flask(
    __name__, static_folder="frontend/build/static", template_folder="frontend/build"
)
# Module-level state, populated by start_server(): the visualizer instance the
# routes delegate to, and the TCP port the server listens on.
visualizer = None
port = None
# Enable response compression (gzip/deflate) for all routes.
Compress(app)
def namedtuple_to_dict(obj):
    """Recursively convert namedtuples inside ``obj`` into plain dicts.

    Single-element tensors are unwrapped to Python scalars via ``.item()``;
    strings pass through untouched; mappings and other iterables are rebuilt
    with their contents converted. Anything else is returned as-is.
    """
    if isinstance(obj, Tensor):
        return obj.item()
    if hasattr(obj, "_asdict"):
        # namedtuple: pair each field name with its converted value
        return dict(zip(obj._fields, (namedtuple_to_dict(field) for field in obj)))
    if isinstance(obj, str):
        # strings are iterable but must stay whole
        return obj
    if hasattr(obj, "keys"):
        # mapping: keep keys, convert the values
        converted_values = (namedtuple_to_dict(value) for value in obj.values())
        return dict(zip(obj.keys(), converted_values))
    if hasattr(obj, "__iter__"):
        # generic sequence: rebuild the same container type around converted items
        return type(obj)(namedtuple_to_dict(element) for element in obj)
    # atomic, non-iterable value — cannot contain namedtuples
    return obj
@app.route("/attribute", methods=["POST"])
def attribute():
    """Compute the attribution for one cached sample and return it as JSON."""
    # force=True needed for Colab notebooks, which doesn't use the correct
    # Content-Type header when forwarding requests through the Colab proxy
    payload = request.get_json(force=True)
    attribution = visualizer._calculate_attribution_from_cache(
        payload["inputIndex"], payload["modelIndex"], payload["labelIndex"]
    )
    return jsonify(namedtuple_to_dict(attribution))
@app.route("/fetch", methods=["POST"])
def fetch():
    """Apply the posted config to the visualizer and return its output as JSON."""
    # force=True needed, see comment for "/attribute" route above
    config = request.get_json(force=True)
    visualizer._update_config(config)
    raw_output = visualizer.visualize()
    return jsonify(namedtuple_to_dict(raw_output))
@app.route("/init")
def init():
    """Return the visualizer's Insights configuration as JSON."""
    config = visualizer.get_insights_config()
    return jsonify(config)
@app.route("/")
def index(id=0):
    """Serve the single-page frontend application."""
    # NOTE(review): `id` is unused and shadows the builtin; kept unchanged
    # for interface stability.
    return render_template("index.html")
def get_free_tcp_port():
    """Ask the OS for a currently-unused TCP port number and return it.

    The probe socket is closed before returning, so another process could
    grab the port in the interim — an inherent race with this technique.
    """
    # bind to port 0 so the kernel assigns any free ephemeral port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        free_port = probe.getsockname()[1]
    return free_port
def run_app(debug: bool = True, bind_all: bool = False):
    """Run the Flask server on the module-level `port` (blocks until shutdown).

    When `bind_all` is True the server listens on all interfaces instead of
    localhost only.
    """
    run_kwargs = {"port": port, "use_reloader": False, "debug": debug}
    if bind_all:
        run_kwargs["host"] = "0.0.0.0"
    app.run(**run_kwargs)
@log_usage()
def start_server(
    _viz,
    blocking: bool = False,
    debug: bool = False,
    _port: Optional[int] = None,
    bind_all: bool = False,
):
    """Launch the Insights Flask server for `_viz` in a background thread.

    Stores `_viz` in the module-level `visualizer` global. On first call
    (module-level `port` is None), picks a port (`_port` or a free one),
    silences werkzeug logging unless `debug`, and starts the server thread;
    with `blocking=True` the call joins that thread instead of returning.
    Returns the port the server listens on.
    """
    global visualizer, port
    visualizer = _viz
    if port is None:
        os.environ["WERKZEUG_RUN_MAIN"] = "true"  # hides starting message
        if not debug:
            # suppress werkzeug's per-request log lines and the app logger
            werkzeug_logger = logging.getLogger("werkzeug")
            werkzeug_logger.disabled = True
            app.logger.disabled = True
        port = _port or get_free_tcp_port()
        # Start in a new thread to not block notebook execution
        server_thread = threading.Thread(
            target=run_app, kwargs={"debug": debug, "bind_all": bind_all}
        )
        server_thread.start()
        sleep(0.01)  # add a short delay to allow server to start up
        if blocking:
            server_thread.join()
    print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n")
    return port