# Page residue from the hosting site, kept as a comment so the module parses:
# author kritsg · commit c38a7bb ("cleaning up hardcoded aspects of code") · 3.54 kB
from cgitb import enable
import numpy as np
import gradio as gr
"""An example of generating a gif explanation for an image of my dog."""
import argparse
import os
from os.path import exists, dirname
import sys
import flask
# Append the parent of the *current working directory* to sys.path, presumably
# so the local `bayes` package imported below can be resolved.
# NOTE(review): this depends on where the process is launched from, not on
# where this file lives — confirm that is intended.
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)
from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif
# Class index (Blenheim spaniel) that the demo explains: it is passed as the
# explanation label and checked against the model's prediction in
# segmentation_generation.
BLENHEIM_SPANIEL_CLASS = 156
def get_image_data(image_name):
    """Load the example image together with its model/data bundle.

    Only the bundled example ``"imagenet_diego.png"`` is supported. It is
    loaded via ``get_dataset_by_name("imagenet_diego", get_label=False)`` and
    then passed to ``process_imagenet_get_model``.

    Args:
        image_name: Name of the example image (as entered in the Gradio
            textbox).

    Returns:
        Tuple ``(image, model_and_data)``.

    Raises:
        ValueError: if ``image_name`` is not a supported example. Without this
            guard the function would fall through and fail with an unbound
            local variable.
    """
    if image_name != "imagenet_diego.png":
        raise ValueError(
            f"Unsupported image {image_name!r}; the only available example is "
            "'imagenet_diego.png'."
        )
    image = get_dataset_by_name("imagenet_diego", get_label=False)
    model_and_data = process_imagenet_get_model(image)
    return image, model_and_data
def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
    """Explain the model's prediction for an example image and render it as a GIF.

    Fits a ``BayesLocalExplanations`` explanation (``kernel="lime"``,
    ``data="image"``) of the wrapped model's output for
    ``BLENHEIM_SPANIEL_CLASS`` and hands the result's ``'blr'`` entry to
    ``create_gif``.

    Args:
        image_name: Example image name; forwarded to ``get_image_data`` and
            ``create_gif``.
        c_width: Value passed to ``explain`` as ``cred_width``.
        n_top: Number of top segments (the ``n_top_segs`` slider).
        n_gif_imgs: Number of GIF images (the ``n_gif_images`` slider).

    Returns:
        Whatever ``create_gif`` returns (rendered by the interface as HTML).

    Raises:
        RuntimeError: if the model's prediction for the background sample is
            not ``BLENHEIM_SPANIEL_CLASS``.
    """
    print("GRADIO INPUTS:", image_name, c_width, n_top, n_gif_imgs)
    cred_width = float(c_width)
    # Slider values may arrive as floats; these two are counts, so coerce them.
    n_top_segs = int(n_top)
    n_gif_images = int(n_gif_imgs)

    _, model_and_data = get_image_data(image_name)

    # Unpack data
    xtest = model_and_data["xtest"]
    segs = model_and_data["xtest_segs"]
    get_model = model_and_data["model"]
    label = model_and_data["label"]
    print("LABEL:", label)

    # Explain the first test instance and its segmentation.
    instance = xtest[0]
    segments = segs[0]

    # Wrap the model around this instance/segmentation.
    cur_model = get_model(instance, segments)

    # Background data for the explainer.
    xtrain = get_xtrain(segments)

    # Sanity check: the explanation below is computed for
    # BLENHEIM_SPANIEL_CLASS, so the model must predict that class. An explicit
    # raise (rather than assert) keeps the check active under `python -O`.
    prediction = int(np.argmax(cur_model(xtrain[:1]), axis=1)[0])
    if prediction != BLENHEIM_SPANIEL_CLASS:
        raise RuntimeError(
            f"Prediction is {prediction} not {BLENHEIM_SPANIEL_CLASS}"
        )
    # NOTE(review): `label` is only printed; the explanation uses the
    # BLENHEIM_SPANIEL_CLASS constant — confirm the two always agree.

    # Compute explanation
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                      data="image",
                                      kernel="lime",
                                      categorical_features=np.arange(xtrain.shape[1]),
                                      verbose=True)
    # The all-ones vector is presumably the instance with every segment
    # present — confirm against get_xtrain's encoding.
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=BLENHEIM_SPANIEL_CLASS,
                            cred_width=cred_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation
    return create_gif(rout['blr'], image_name, segments, instance,
                      n_gif_images, n_top_segs)
def image_mod(image):
    """Return the result of rotating *image* by 45 degrees via its ``rotate`` method.

    Presumably receives a PIL image (see the note in the ``__main__`` block);
    this helper is not referenced elsewhere in the visible code.
    """
    angle_degrees = 45
    rotated = image.rotate(angle_degrees)
    return rotated
if __name__ == "__main__":
    # NOTE: Gradio image inputs arrive as PIL images (e.g.
    # <PIL.Image.Image image mode=RGB size=305x266 at 0x...>). Rather than
    # handling those (or file uploads), this demo takes the example image's
    # name as plain text.
    image_path_box = gr.inputs.Textbox(
        lines=1,
        placeholder="Select an example from below",
        default="",
        label="Input Image Path",
        optional=False,
    )
    gif_html = gr.outputs.HTML(label="Output GIF")

    # Sliders map positionally onto segmentation_generation's remaining
    # parameters: c_width, n_top, n_gif_imgs.
    cred_width_slider = gr.inputs.Slider(
        minimum=0.01, maximum=0.8, step=0.001, default=0.1,
        label="cred_width", optional=False,
    )
    n_top_segs_slider = gr.inputs.Slider(
        minimum=1, maximum=10, step=1, default=5,
        label="n_top_segs", optional=False,
    )
    n_gif_images_slider = gr.inputs.Slider(
        minimum=10, maximum=50, step=1, default=20,
        label="n_gif_images", optional=False,
    )
    input_components = [
        image_path_box,
        cred_width_slider,
        n_top_segs_slider,
        n_gif_images_slider,
    ]

    demo_examples = [["imagenet_diego.png", 0.01, 7, 50]]

    iface = gr.Interface(
        segmentation_generation,
        input_components,
        outputs=gif_html,
        examples=demo_examples,
    )
    iface.launch(show_error=True, enable_queue=True)