"""An example of generating a gif explanation for an image of my dog."""
import argparse
import os
import sys
from os.path import dirname

import flask  # only needed if the commented-out Flask fallback at the bottom is re-enabled
import gradio as gr
import numpy as np

# Make the parent directory importable so the bayes package can be found.
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)

from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif

parser = argparse.ArgumentParser()
parser.add_argument("--cred_width", type=float, default=0.1)
parser.add_argument("--save_loc", type=str, required=True)
parser.add_argument("--n_top_segs", type=int, default=5)
parser.add_argument("--n_gif_images", type=int, default=20)

# app = flask.Flask(__name__, template_folder="./")

IMAGE_NAME = "imagenet_diego"
BLENHEIM_SPANIEL_CLASS = 156  # ImageNet-1k index for "Blenheim spaniel"


def get_image_data():
    """Gets the image data and model."""
    puppy_image = get_dataset_by_name(IMAGE_NAME, get_label=False)
    print("IMAGE RETURNED FROM GETTING DATASET:\n", puppy_image)
    model_and_data = process_imagenet_get_model(puppy_image)
    print("MODEL RETURNED FROM PROCESSING IMAGE:\n", model_and_data)
    return puppy_image, model_and_data


def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
    """Compute a Bayesian local explanation for the demo image and return it as a gif.

    Note: image_name is only logged; the fixed IMAGE_NAME dataset is always used.
    """
    print("GRADIO INPUTS:", image_name, c_width, n_top, n_gif_imgs)
    cred_width = c_width
    n_top_segs = n_top
    n_gif_images = n_gif_imgs
    puppy_image, model_and_data = get_image_data()

    # Unpack data
    xtest = model_and_data["xtest"]
    ytest = model_and_data["ytest"]
    segs = model_and_data["xtest_segs"]
    get_model = model_and_data["model"]
    label = model_and_data["label"]

    # Unpack instance and segments
    instance = xtest[0]
    segments = segs[0]

    # Get wrapped model
    cur_model = get_model(instance, segments)

    # Get background data
    xtrain = get_xtrain(segments)

    # Sanity check: the wrapped model should predict the Blenheim spaniel class.
    prediction = int(np.argmax(cur_model(xtrain[:1]), axis=1)[0])
    assert prediction == BLENHEIM_SPANIEL_CLASS, f"Prediction is {prediction} not {BLENHEIM_SPANIEL_CLASS}"

    # Compute the explanation; cred_width sets the desired width of the
    # credible intervals around the segment importances.
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                      data="image",
                                      kernel="lime",
                                      categorical_features=np.arange(xtrain.shape[1]),
                                      verbose=True)
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=BLENHEIM_SPANIEL_CLASS,
                            cred_width=cred_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation; rout['blr'] is the fitted Bayesian
    # explanation model over the image segments.
    return create_gif(rout['blr'], segments, instance, n_gif_images, n_top_segs)


def image_mod(image):
    """Unused helper: rotate a PIL image by 45 degrees (not wired into the interface)."""
    return image.rotate(45)

if __name__ == "__main__":
    # gradio's image inputs look like this: <PIL.Image.Image image mode=RGB size=305x266 at 0x7F3D01C91FA0>
    # need to learn how to handle image inputs, or deal with file inputs or just file path strings
    inp = gr.inputs.Textbox(lines=1, placeholder="Insert file path here", default="", label="Input Image Path", optional=False)
    out = gr.outputs.HTML(label="Output Video")

    iface = gr.Interface(
        segmentation_generation, 
        [
            inp,
            gr.inputs.Slider(minimum=0.01, maximum=0.8, step=0.001, default=0.1, label="cred_width", optional=False),
            gr.inputs.Slider(minimum=1, maximum=10, step=1, default=5, label="n_top_segs", optional=False),
            gr.inputs.Slider(minimum=10, maximum=50, step=1, default=20, label="n_gif_images", optional=False), 
        ], 
        outputs=out, 
        examples=[["./imagenet_diego.png", 0.05, 7, 50]]
    )
    iface.launch(enable_queue=True)

    # app.run(host='0.0.0.0',  port=int(os.environ.get('PORT', 7860)))
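
    # A minimal sketch of exercising the pipeline without the web UI (assumes the
    # bayes package and the imagenet_diego assets are available locally; the return
    # value is whatever create_gif produces, which the HTML output above renders):
    #
    #     result = segmentation_generation("imagenet_diego", 0.1, 5, 20)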