Spaces:
Runtime error
Runtime error
from cgitb import enable | |
from pyexpat import model | |
from statistics import mode | |
import numpy as np | |
import gradio as gr | |
"""An example of generating a gif explanation for an image of my dog.""" | |
import argparse | |
import os | |
from os.path import exists, dirname | |
import sys | |
import flask | |
parent_dir = dirname(os.path.abspath(os.getcwd())) | |
sys.path.append(parent_dir) | |
from bayes.explanations import BayesLocalExplanations, explain_many | |
from bayes.data_routines import get_dataset_by_name | |
from bayes.models import * | |
from image_posterior import create_gif | |
BLENHEIM_SPANIEL_CLASS = 156 | |
def get_image_data(image_name):
    """Load the image and model bundle for a supported example image.

    Args:
        image_name: Name of the example image. ``"imagenet_diego"`` and
            ``"imagenet_french_bulldog"`` are recognised.

    Returns:
        A ``(image, model_and_data)`` pair, where ``image`` comes from
        ``get_dataset_by_name`` and ``model_and_data`` from
        ``process_imagenet_get_model``. Both entries are ``None`` when
        ``image_name`` is not recognised.
    """
    supported_images = ("imagenet_diego", "imagenet_french_bulldog")
    if image_name not in supported_images:
        return None, None

    # Both supported images go through exactly the same loading pipeline.
    loaded_image = get_dataset_by_name(image_name, get_label=False)
    return loaded_image, process_imagenet_get_model(loaded_image)
def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
    """Build a GIF explanation for one of the supported example images.

    Loads the image's model bundle, checks that the wrapped model predicts
    the expected class, fits a ``BayesLocalExplanations`` explainer and
    hands the result to ``create_gif``.

    Args:
        image_name: Example image name understood by ``get_image_data``.
        c_width: Forwarded to ``explain`` as ``cred_width``.
        n_top: Forwarded to ``create_gif`` as its last positional argument.
        n_gif_imgs: Forwarded to ``create_gif`` as its fifth positional
            argument.

    Returns:
        Whatever ``create_gif`` returns for the computed explanation.

    Raises:
        ValueError: If ``image_name`` is not a supported example image.
        AssertionError: If the model's prediction differs from the
            expected label.
    """
    print("Inputs Received:", image_name, c_width, n_top, n_gif_imgs)

    _, model_and_data = get_image_data(image_name)
    print("model_and_data", model_and_data)
    if model_and_data is None:
        # get_image_data reports an unsupported name as (None, None); fail
        # here with a clear message instead of an opaque TypeError on the
        # dictionary lookups below.
        raise ValueError(f"Unsupported image name: {image_name!r}")

    # Unpack data
    xtest = model_and_data["xtest"]
    segs = model_and_data["xtest_segs"]
    get_model = model_and_data["model"]

    # The example images have their expected class pinned here; any other
    # name falls back to the label stored alongside the data.
    pinned_labels = {"imagenet_diego": 156, "imagenet_french_bulldog": 245}
    label = pinned_labels.get(image_name, model_and_data["label"])

    # Unpack instance and segments
    instance = xtest[0]
    segments = segs[0]

    # Get wrapped model
    cur_model = get_model(instance, segments)

    # Get background data
    xtrain = get_xtrain(segments)

    # Sanity check: the model must predict the expected class on the first
    # background row before an explanation for that class is computed.
    prediction = np.argmax(cur_model(xtrain[:1]), axis=1)
    assert prediction == label, f"Prediction is {prediction} not {label}"

    # Compute explanation
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                      data="image",
                                      kernel="lime",
                                      categorical_features=np.arange(xtrain.shape[1]),
                                      verbose=True)
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=label,
                            cred_width=c_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation
    return create_gif(rout['blr'], image_name, segments, instance, n_gif_imgs, n_top)
def image_mod(image):
    """Return the result of rotating ``image`` by 45 via its ``rotate`` method."""
    rotation_angle = 45
    rotated = image.rotate(rotation_angle)
    return rotated
if __name__ == "__main__":
    # Gradio's image inputs arrive as PIL objects, e.g.
    # <PIL.Image.Image image mode=RGB size=305x266 at 0x7F3D01C91FA0>.
    # TODO: learn how to handle image inputs, or switch to file inputs or
    # plain file path strings; for now the image is chosen through a text box.
    image_name_box = gr.inputs.Textbox(
        lines=1,
        placeholder="Select an example from below",
        default="",
        label="Input Image Path",
        optional=False,
    )
    gif_output = gr.outputs.HTML(label="Output GIF")

    # Sliders are listed in the order segmentation_generation takes its
    # remaining positional arguments.
    cred_width_slider = gr.inputs.Slider(
        minimum=0.01, maximum=0.8, step=0.01, default=0.1,
        label="cred_width", optional=False,
    )
    n_top_segs_slider = gr.inputs.Slider(
        minimum=1, maximum=10, step=1, default=5,
        label="n_top_segs", optional=False,
    )
    n_gif_images_slider = gr.inputs.Slider(
        minimum=10, maximum=50, step=1, default=20,
        label="n_gif_images", optional=False,
    )

    # One example row per supported image.
    example_rows = [
        ["imagenet_diego", 0.01, 7, 50],
        ["imagenet_french_bulldog", 0.05, 5, 50],
    ]

    demo = gr.Interface(
        fn=segmentation_generation,
        inputs=[
            image_name_box,
            cred_width_slider,
            n_top_segs_slider,
            n_gif_images_slider,
        ],
        outputs=gif_output,
        examples=example_rows,
    )
    demo.launch(enable_queue=True)