"""An example of generating a gif explanation for an image of my dog."""
import argparse
import os
import sys
from os.path import dirname

import flask  # only used by the (commented-out) Flask entry point below
import gradio as gr
import numpy as np

# Make the parent directory importable so the bayes package can be resolved.
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)
from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif
# CLI arguments (not consumed in the Gradio path; see note below).
parser = argparse.ArgumentParser()
parser.add_argument("--cred_width", type=float, default=0.1)
parser.add_argument("--save_loc", type=str, required=True)
parser.add_argument("--n_top_segs", type=int, default=5)
parser.add_argument("--n_gif_images", type=int, default=20)
# app = flask.Flask(__name__, template_folder="./")
IMAGE_NAME = "imagenet_diego"
BLENHEIM_SPANIEL_CLASS = 156  # ImageNet class index for "Blenheim spaniel"
def get_image_data():
    """Gets the image data and model."""
    puppy_image = get_dataset_by_name(IMAGE_NAME, get_label=False)
    model_and_data = process_imagenet_get_model(puppy_image)
    return puppy_image, model_and_data
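# The model_and_data dictionary returned above is expected to expose the keys
# consumed in segmentation_generation below: "xtest", "ytest", "xtest_segs",
# "model", and "label".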
def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
    """Runs the Bayesian local explanation pipeline and returns the gif explanation."""
    # image_name comes from the Gradio image input but is unused: the demo
    # always explains the bundled IMAGE_NAME image.
    cred_width = c_width
    n_top_segs = n_top
    n_gif_images = n_gif_imgs
    puppy_image, model_and_data = get_image_data()

    # Unpack data
    xtest = model_and_data["xtest"]
    ytest = model_and_data["ytest"]
    segs = model_and_data["xtest_segs"]
    get_model = model_and_data["model"]
    label = model_and_data["label"]

    # Unpack instance and segments
    instance = xtest[0]
    segments = segs[0]

    # Get wrapped model
    cur_model = get_model(instance, segments)

    # Get background data
    xtrain = get_xtrain(segments)

    # Sanity-check that the model predicts the expected class.
    prediction = np.argmax(cur_model(xtrain[:1]), axis=1)
    assert prediction == BLENHEIM_SPANIEL_CLASS, f"Prediction is {prediction} not {BLENHEIM_SPANIEL_CLASS}"

    # Compute explanation
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                      data="image",
                                      kernel="lime",
                                      categorical_features=np.arange(xtrain.shape[1]),
                                      verbose=True)
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=BLENHEIM_SPANIEL_CLASS,
                            cred_width=cred_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation
    return create_gif(rout['blr'], segments, instance, n_gif_images, n_top_segs)
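# A minimal sketch of invoking the pipeline directly, bypassing the Gradio UI.
# The first argument is unused by segmentation_generation, so None suffices;
# the remaining values mirror the slider defaults below:
#   html_gif = segmentation_generation(None, 0.1, 5, 20)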
def image_mod(image):
    # Currently unused by the interface below.
    return image.rotate(45)
if __name__ == "__main__":
    inp = gr.inputs.Image(label="Input Image", type="pil")
    out = gr.outputs.HTML(label="Output Video")
    iface = gr.Interface(
        segmentation_generation,
        [
            inp,
            gr.inputs.Slider(minimum=0.01, maximum=0.8, step=0.001, default=0.1, label="cred_width", optional=False),
            gr.inputs.Slider(minimum=1, maximum=10, step=1, default=5, label="n_top_segs", optional=False),
            gr.inputs.Slider(minimum=10, maximum=50, step=1, default=20, label="n_gif_images", optional=False),
        ],
        outputs=out,
        examples=[["./imagenet_diego.png", 0.05, 7, 50]],
    )
    iface.launch()
    # app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))
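# To serve the demo locally (assuming this file is saved as app.py):
#   python app.py
# Gradio launches on http://127.0.0.1:7860 by default.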