Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import logging | |
import os | |
import socket | |
import threading | |
from time import sleep | |
from typing import Optional | |
from captum.log import log_usage | |
from flask import Flask, jsonify, render_template, request | |
from flask_compress import Compress | |
from torch import Tensor | |
app = Flask( | |
__name__, static_folder="frontend/build/static", template_folder="frontend/build" | |
) | |
visualizer = None | |
port = None | |
Compress(app) | |
def namedtuple_to_dict(obj): | |
if isinstance(obj, Tensor): | |
return obj.item() | |
if hasattr(obj, "_asdict"): # detect namedtuple | |
return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj))) | |
elif isinstance(obj, str): # iterables - strings | |
return obj | |
elif hasattr(obj, "keys"): # iterables - mapping | |
return dict( | |
zip(obj.keys(), (namedtuple_to_dict(item) for item in obj.values())) | |
) | |
elif hasattr(obj, "__iter__"): # iterables - sequence | |
return type(obj)((namedtuple_to_dict(item) for item in obj)) | |
else: # non-iterable cannot contain namedtuples | |
return obj | |
def attribute(): | |
# force=True needed for Colab notebooks, which doesn't use the correct | |
# Content-Type header when forwarding requests through the Colab proxy | |
r = request.get_json(force=True) | |
return jsonify( | |
namedtuple_to_dict( | |
visualizer._calculate_attribution_from_cache( | |
r["inputIndex"], r["modelIndex"], r["labelIndex"] | |
) | |
) | |
) | |
def fetch(): | |
# force=True needed, see comment for "/attribute" route above | |
visualizer._update_config(request.get_json(force=True)) | |
visualizer_output = visualizer.visualize() | |
clean_output = namedtuple_to_dict(visualizer_output) | |
return jsonify(clean_output) | |
def init(): | |
return jsonify(visualizer.get_insights_config()) | |
def index(id=0): | |
return render_template("index.html") | |
def get_free_tcp_port(): | |
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
tcp.bind(("", 0)) | |
addr, port = tcp.getsockname() | |
tcp.close() | |
return port | |
def run_app(debug: bool = True, bind_all: bool = False): | |
if bind_all: | |
app.run(port=port, use_reloader=False, debug=debug, host="0.0.0.0") | |
else: | |
app.run(port=port, use_reloader=False, debug=debug) | |
def start_server( | |
_viz, | |
blocking: bool = False, | |
debug: bool = False, | |
_port: Optional[int] = None, | |
bind_all: bool = False, | |
): | |
global visualizer | |
visualizer = _viz | |
global port | |
if port is None: | |
os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message | |
if not debug: | |
log = logging.getLogger("werkzeug") | |
log.disabled = True | |
app.logger.disabled = True | |
port = _port or get_free_tcp_port() | |
# Start in a new thread to not block notebook execution | |
t = threading.Thread( | |
target=run_app, kwargs={"debug": debug, "bind_all": bind_all} | |
) | |
t.start() | |
sleep(0.01) # add a short delay to allow server to start up | |
if blocking: | |
t.join() | |
print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n") | |
return port | |