File size: 4,248 Bytes
b1b3f23
fa636b5
 
e3914b4
 
 
 
 
8e67e6e
e3914b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971bf27
 
e3914b4
 
 
 
 
 
 
ec53bcf
e3914b4
ead7d7b
e3914b4
 
 
429b311
93f12e7
e572b7a
 
 
 
 
 
 
 
 
429b311
 
 
e3914b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa636b5
 
 
971bf27
e3914b4
0bf30a9
 
ec53bcf
971bf27
e3914b4
429b311
 
 
 
971bf27
429b311
 
971bf27
 
 
 
b1b3f23
971bf27
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from cgitb import enable
import numpy as np
import gradio as gr
"""An example of generating a gif explanation for an image of my dog."""
import argparse
import os
from os.path import exists, dirname
import sys
import flask

parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)

from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif	

# Command-line options for the explanation run.
# NOTE(review): parse_args() is never called in this file, so these flags are
# currently unused (the Gradio sliders below expose the same three numeric
# settings) — confirm whether this parser is still needed.
parser = argparse.ArgumentParser()
parser.add_argument("--cred_width", default=0.1, type=float)
parser.add_argument("--save_loc", required=True, type=str)
parser.add_argument("--n_top_segs", default=5, type=int)
parser.add_argument("--n_gif_images", default=20, type=int)

# app = flask.Flask(__name__, template_folder="./")

# Dataset name handed to get_dataset_by_name() to load the demo image.
IMAGE_NAME = "imagenet_diego"
# Class index the model's prediction is checked against and the label that
# gets explained (per the name, ImageNet's "Blenheim spaniel" class).
BLENHEIM_SPANIEL_CLASS = 156


def get_image_data():
    """Load the demo image and build the model/data bundle for it.

    The image is fetched with ``get_dataset_by_name(IMAGE_NAME, get_label=False)``
    and then passed to ``process_imagenet_get_model``. Both intermediate
    results are printed to stdout for debugging.

    Returns:
        A ``(puppy_image, model_and_data)`` tuple: the loaded dataset entry and
        the object returned by ``process_imagenet_get_model`` for it.
    """
    image = get_dataset_by_name(IMAGE_NAME, get_label=False)
    print("IMAGE RETURNED FROM GETTING DATASET:\n", image)

    bundle = process_imagenet_get_model(image)
    print("MODEL RETURNED FROM PROCESSING IMAGE:\n", bundle)

    return image, bundle


def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
    print("GRADIO INPUTS:", image_name, c_width, n_top, n_gif_imgs)

    html =  '''   
        <div style='max-width:100%; max-height:360px; overflow:auto'>
            <video width="320" height="240" autoplay>
                <source src="./test.mp4" type=video/mp4>
            </video>
        </div>''' 
    return html

    cred_width = c_width
    n_top_segs = n_top
    n_gif_images = n_gif_imgs
    puppy_image, model_and_data = get_image_data()

    # Unpack datax
    xtest = model_and_data["xtest"]
    ytest = model_and_data["ytest"]
    segs = model_and_data["xtest_segs"]
    get_model = model_and_data["model"]
    label = model_and_data["label"]

    # Unpack instance and segments
    instance = xtest[0]
    segments = segs[0]

    # Get wrapped model
    cur_model = get_model(instance, segments)

    # Get background data
    xtrain = get_xtrain(segments)

    prediction = np.argmax(cur_model(xtrain[:1]), axis=1)
    assert prediction == BLENHEIM_SPANIEL_CLASS, f"Prediction is {prediction} not {BLENHEIM_SPANIEL_CLASS}"

    # Compute explanation
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                              data="image",
                                              kernel="lime",
                                              categorical_features=np.arange(xtrain.shape[1]),
                                              verbose=True)
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=BLENHEIM_SPANIEL_CLASS,
                            cred_width=cred_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation
    return create_gif(rout['blr'], segments, instance, n_gif_images, n_top_segs)

def image_mod(image):
    """Return ``image.rotate(45)``.

    Presumably *image* is a PIL image, for which ``rotate`` returns a new
    image turned 45 degrees counter-clockwise. This helper is not called
    anywhere in the visible code.
    """
    angle = 45
    return image.rotate(angle)
    
if __name__ == "__main__":
    # Gradio hands image inputs to callbacks as PIL objects, e.g.
    # <PIL.Image.Image image mode=RGB size=305x266 at 0x7F3D01C91FA0>.
    # TODO: decide between image inputs, file uploads, or plain path strings;
    # for now the input is a single-line file-path textbox.
    path_box = gr.inputs.Textbox(
        lines=1,
        placeholder="Insert file path here",
        default="",
        label="Input Image Path",
        optional=False,
    )
    video_panel = gr.outputs.HTML(label="Output Video")

    # Slider order must match segmentation_generation's parameters after the
    # path: (c_width, n_top, n_gif_imgs).
    cred_width_slider = gr.inputs.Slider(
        minimum=0.01, maximum=0.8, step=0.001, default=0.1,
        label="cred_width", optional=False,
    )
    n_top_slider = gr.inputs.Slider(
        minimum=1, maximum=10, step=1, default=5,
        label="n_top_segs", optional=False,
    )
    n_gif_slider = gr.inputs.Slider(
        minimum=10, maximum=50, step=1, default=20,
        label="n_gif_images", optional=False,
    )

    demo = gr.Interface(
        segmentation_generation,
        [path_box, cred_width_slider, n_top_slider, n_gif_slider],
        outputs=video_panel,
        # One example row: (image path, cred_width, n_top_segs, n_gif_images).
        examples=[["./imagenet_diego.png", 0.05, 7, 50]],
    )
    demo.launch(enable_queue=True)

    # Alternative Flask entry point, currently disabled:
    # app.run(host='0.0.0.0',  port=int(os.environ.get('PORT', 7860)))