File size: 3,690 Bytes
0ff7e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e67e6e
0ff7e03
 
40d47b6
8d20412
0ff7e03
 
 
 
 
 
 
 
 
8d20412
 
 
 
 
0ff7e03
 
 
 
 
 
 
 
 
 
 
8d20412
0ff7e03
 
 
 
 
 
 
 
 
 
 
8d20412
5d3e27c
40d47b6
 
 
 
 
fe14988
8d20412
0ff7e03
 
 
 
 
 
 
 
 
 
7dd304e
0ff7e03
 
73c8e91
971bf27
73c8e91
 
decc6b2
0ff7e03
73c8e91
c38a7bb
 
0e27b08
73c8e91
c38a7bb
 
214a191
4790b80
c38a7bb
971bf27
0ff7e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Create a GIF of posterior samples for an image explanation.

The file includes routines to create gifs of posterior samples for image
explanations. To create the gif, we sample a number of draws from the posterior,
plot the explanation and the image, and repeat this to stitch together a gif.

The interpretation is that regions of the image that show up as green more
frequently are more likely to positively impact the prediction. Similarly, regions
that show up as red more frequently are more likely to negatively impact the prediction.
"""
import os
from os.path import exists, dirname
import sys

import imageio
import matplotlib.pyplot as plt
import numpy as np
from skimage.segmentation import mark_boundaries
import tempfile
from tqdm import tqdm
import ffmpeg
import lime.lime_tabular as baseline_lime_tabular
import shap
import shutil
import json

# Make sure we can get bayes explanations
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)

from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *

# Mapping from class-index strings to human-readable label text, read from
# "labels.json" in the current working directory. create_gif looks up
# labels_dict[str(prediction)]. JSON is UTF-8 by specification, so decode it
# explicitly rather than relying on the platform's locale encoding.
with open("labels.json", encoding="utf-8") as file:
    labels_dict = json.load(file)


def fill_segmentation(values, segmentation, image, n_max=5):
    """Paint the ``n_max`` most influential superpixels of one explanation.

    Arguments:
        values: One coefficient per superpixel; entry ``i`` belongs to the
            pixels where ``segmentation == i``. Lists are accepted.
        segmentation: Integer superpixel label map with the same height/width
            as ``image``.
        image: The explained image. Indexed as ``image[mask, channel]`` with
            channels 0 and 1, so it is assumed (H, W, C) with C >= 2.
        n_max: Number of superpixels, ranked by absolute coefficient, to
            paint. Values <= 0 paint nothing.

    Returns:
        ``(c_image, out)``, both cast to ``int``:
        ``c_image`` has ``image``'s shape. Each selected superpixel is set to
        ``np.max(image)`` in channel 1 (green) when its coefficient is
        positive, or channel 0 (red) otherwise. All other pixels are 0.
        ``out`` has ``segmentation``'s shape and holds +1 / -1 for the
        selected positive / non-positive superpixels and 0 elsewhere.
    """
    values = np.asarray(values)
    segmentation = np.asarray(segmentation)
    # Size the mask from the segmentation instead of assuming 224x224 input.
    out = np.zeros(segmentation.shape)
    c_image = np.zeros(image.shape)
    # np.argsort(...)[-0:] would select *every* superpixel, so handle
    # "nothing to paint" explicitly.
    if n_max <= 0 or values.size == 0:
        return c_image.astype(int), out.astype(int)

    fill_value = np.max(image)
    # Iterate only the selected indices. They are distinct, so each touches a
    # disjoint set of pixels and the visiting order does not matter.
    for i in np.argsort(np.abs(values))[-n_max:]:
        mask = segmentation == i
        positive = values[i] > 0
        out[mask] = 1 if positive else -1
        c_image[mask, 1 if positive else 0] = fill_value
    return c_image.astype(int), out.astype(int)

def create_gif(explanation_blr, img_name, segments, image, prediction, n_images=20, n_max=5):
    """Create the gif corresponding to the image explanation.

    Draws ``n_images`` samples from the explanation posterior and renders each
    one over the image. A frame shows superpixel boundaries from the filled
    segmentation plus a translucent red/green overlay from
    ``fill_segmentation``. The frames are stitched into
    ``<prediction>_explanation.gif`` in the current working directory, which
    is overwritten if it already exists.

    Arguments:
        explanation_blr: The explanation object; its
            ``draw_posterior_samples(n_images)`` supplies one coefficient
            vector per frame.
        img_name: Name of the explained image (unused here; kept for
            interface compatibility).
        segments: The image segmentation (superpixel label map).
        image: The image for which to compute the explanation.
        prediction: The predicted class. It names the gif file and is looked
            up as ``str(prediction)`` in ``labels_dict``.
        n_images: Number of posterior draws, i.e. frames in the gif.
        n_max: The number of superpixels to draw on each frame.

    Returns:
        ``(html, message)``: an HTML snippet embedding the gif and a
        human-readable description of the prediction.
    """
    draws = explanation_blr.draw_posterior_samples(n_images)
    label = str(prediction)
    gif_path = f"{label}_explanation.gif"

    # Frames go into a scratch directory that is deleted when the block exits.
    with tempfile.TemporaryDirectory() as tmpdirname:
        paths = []
        # Draw into a dedicated figure so the caller's current figure is
        # neither painted over nor left holding the last frame.
        fig, ax = plt.subplots()
        try:
            # Wrap draws (not enumerate) so tqdm can show the total.
            for i, d in enumerate(tqdm(draws)):
                c_image, filled_segs = fill_segmentation(d, segments, image, n_max=n_max)
                ax.cla()
                ax.axis('off')
                ax.imshow(mark_boundaries(image, filled_segs))
                ax.imshow(c_image, alpha=0.3)
                paths.append(os.path.join(tmpdirname, f"{i}.png"))
                fig.savefig(paths[-1])
        finally:
            plt.close(fig)

        # Save to gif
        # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
        print(f"Saving gif to {gif_path}")

        if os.path.exists(gif_path):
            os.remove(gif_path)

        ims = [imageio.imread(f) for f in paths]
        imageio.mimwrite(gif_path, ims)

    html = (
        "<div>"
        + f"<img  src='file/{gif_path}' alt='explanation gif'/>"
        + "</div>"
    )

    return html, f"Prediction was {prediction}: {labels_dict[label]}"