kritsg commited on
Commit
0ff7e03
·
1 Parent(s): e65a2db

Added image_posterior (to create segmentation)

Browse files
Files changed (1) hide show
  1. image_posterior.py +73 -0
image_posterior.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Create a gif sampling from the posterior from an image.
2
+
3
+ The file includes routines to create gifs of posterior samples for image
4
+ explanations. To create the gif, we sample a number of draws from the posterior,
5
+ plot the explanation and the image, and repeat this to stitch together a gif.
6
+
7
+ The interpretation is that regions of the image that more frequency show up as
8
+ green are more likely to positively impact the prediction. Similarly, regions that
9
+ more frequently show up as red are more likey to negatively impact the prediction.
10
+ """
11
+ import os
12
+ from os.path import exists, dirname
13
+ import sys
14
+
15
+ import imageio
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ from skimage.segmentation import mark_boundaries
19
+ import tempfile
20
+ from tqdm import tqdm
21
+
22
+ import lime.lime_tabular as baseline_lime_tabular
23
+ import shap
24
+
25
+ # Make sure we can get bayes explanations
26
+ parent_dir = dirname(os.path.abspath(os.getcwd()))
27
+ sys.path.append(parent_dir)
28
+
29
+ from bayes.explanations import BayesLocalExplanations, explain_many
30
+ from bayes.data_routines import get_dataset_by_name
31
+ from bayes.models import *
32
+
33
+ def fill_segmentation(values, segmentation, image, n_max=5):
34
+ max_segs = np.argsort(abs(values))[-n_max:]
35
+ out = np.zeros((224, 224))
36
+ c_image = np.zeros(image.shape)
37
+ for i in range(len(values)):
38
+ if i in max_segs:
39
+ out[segmentation == i] = 1 if values[i] > 0 else -1
40
+ c = 1 if values[i] > 0 else 0
41
+ c_image[segmentation == i, c] = np.max(image)
42
+ return c_image.astype(int), out.astype(int)
43
+
44
+ def create_gif(explanation_blr, segments, image, n_images=20, n_max=5):
45
+ """Create the gif corresponding to the image explanation.
46
+
47
+ Arguments:
48
+ explanation_coefficients: The explanation blr object.
49
+ segments: The image segmentation.
50
+ image: The image for which to compute the explantion.
51
+ save_loc: The location to save the gif.
52
+ n_images: Number of images to create the gif with.
53
+ n_max: The number of superpixels to draw on the image.
54
+ """
55
+ draws = explanation_blr.draw_posterior_samples(n_images)
56
+ # Setup temporary directory to store paths in
57
+ with tempfile.TemporaryDirectory() as tmpdirname:
58
+ paths = []
59
+ for i, d in tqdm(enumerate(draws)):
60
+ c_image, filled_segs = fill_segmentation(d, segments, image, n_max=n_max)
61
+ plt.cla()
62
+ plt.axis('off')
63
+ plt.imshow(mark_boundaries(image, filled_segs))
64
+ plt.imshow(c_image, alpha=0.3)
65
+ paths.append(os.path.join(tmpdirname, f"{i}.png"))
66
+ plt.savefig(paths[-1])
67
+
68
+ # Save to gif
69
+ # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
70
+ print(f"Saving gif to {save_loc}")
71
+ ims = [imageio.imread(f) for f in paths]
72
+ return imageio.mimwrite(imageio.RETURN_BYTES, ims)
73
+