Spaces:
Runtime error
Runtime error
Added image_posterior (to create segmentation)
Browse files- image_posterior.py +73 -0
image_posterior.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Create a gif sampling from the posterior from an image.
|
2 |
+
|
3 |
+
The file includes routines to create gifs of posterior samples for image
|
4 |
+
explanations. To create the gif, we sample a number of draws from the posterior,
|
5 |
+
plot the explanation and the image, and repeat this to stitch together a gif.
|
6 |
+
|
7 |
+
The interpretation is that regions of the image that more frequency show up as
|
8 |
+
green are more likely to positively impact the prediction. Similarly, regions that
|
9 |
+
more frequently show up as red are more likey to negatively impact the prediction.
|
10 |
+
"""
|
11 |
+
import os
|
12 |
+
from os.path import exists, dirname
|
13 |
+
import sys
|
14 |
+
|
15 |
+
import imageio
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
import numpy as np
|
18 |
+
from skimage.segmentation import mark_boundaries
|
19 |
+
import tempfile
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
import lime.lime_tabular as baseline_lime_tabular
|
23 |
+
import shap
|
24 |
+
|
25 |
+
# Make sure we can get bayes explanations
|
26 |
+
parent_dir = dirname(os.path.abspath(os.getcwd()))
|
27 |
+
sys.path.append(parent_dir)
|
28 |
+
|
29 |
+
from bayes.explanations import BayesLocalExplanations, explain_many
|
30 |
+
from bayes.data_routines import get_dataset_by_name
|
31 |
+
from bayes.models import *
|
32 |
+
|
33 |
+
def fill_segmentation(values, segmentation, image, n_max=5):
|
34 |
+
max_segs = np.argsort(abs(values))[-n_max:]
|
35 |
+
out = np.zeros((224, 224))
|
36 |
+
c_image = np.zeros(image.shape)
|
37 |
+
for i in range(len(values)):
|
38 |
+
if i in max_segs:
|
39 |
+
out[segmentation == i] = 1 if values[i] > 0 else -1
|
40 |
+
c = 1 if values[i] > 0 else 0
|
41 |
+
c_image[segmentation == i, c] = np.max(image)
|
42 |
+
return c_image.astype(int), out.astype(int)
|
43 |
+
|
44 |
+
def create_gif(explanation_blr, segments, image, n_images=20, n_max=5):
|
45 |
+
"""Create the gif corresponding to the image explanation.
|
46 |
+
|
47 |
+
Arguments:
|
48 |
+
explanation_coefficients: The explanation blr object.
|
49 |
+
segments: The image segmentation.
|
50 |
+
image: The image for which to compute the explantion.
|
51 |
+
save_loc: The location to save the gif.
|
52 |
+
n_images: Number of images to create the gif with.
|
53 |
+
n_max: The number of superpixels to draw on the image.
|
54 |
+
"""
|
55 |
+
draws = explanation_blr.draw_posterior_samples(n_images)
|
56 |
+
# Setup temporary directory to store paths in
|
57 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
58 |
+
paths = []
|
59 |
+
for i, d in tqdm(enumerate(draws)):
|
60 |
+
c_image, filled_segs = fill_segmentation(d, segments, image, n_max=n_max)
|
61 |
+
plt.cla()
|
62 |
+
plt.axis('off')
|
63 |
+
plt.imshow(mark_boundaries(image, filled_segs))
|
64 |
+
plt.imshow(c_image, alpha=0.3)
|
65 |
+
paths.append(os.path.join(tmpdirname, f"{i}.png"))
|
66 |
+
plt.savefig(paths[-1])
|
67 |
+
|
68 |
+
# Save to gif
|
69 |
+
# https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
|
70 |
+
print(f"Saving gif to {save_loc}")
|
71 |
+
ims = [imageio.imread(f) for f in paths]
|
72 |
+
return imageio.mimwrite(imageio.RETURN_BYTES, ims)
|
73 |
+
|