""" Functions for explaining classifiers that use Image data. """ import copy from functools import partial import numpy as np import sklearn import sklearn.preprocessing from sklearn.utils import check_random_state from skimage.color import gray2rgb from tqdm.auto import tqdm from . import lime_base from .wrappers.scikit_image import SegmentationAlgorithm class ImageExplanation(object): def __init__(self, image, segments): """Init function. Args: image: 3d numpy array segments: 2d numpy array, with the output from skimage.segmentation """ self.image = image self.segments = segments self.intercept = {} self.local_exp = {} self.local_pred = None def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False, num_features=5, min_weight=0.): """Init function. Args: label: label to explain positive_only: if True, only take superpixels that positively contribute to the prediction of the label. negative_only: if True, only take superpixels that negatively contribute to the prediction of the label. If false, and so is positive_only, then both negativey and positively contributions will be taken. Both can't be True at the same time hide_rest: if True, make the non-explanation part of the return image gray num_features: number of superpixels to include in explanation min_weight: minimum weight of the superpixels to include in explanation Returns: (image, mask), where image is a 3d numpy array and mask is a 2d numpy array that can be used with skimage.segmentation.mark_boundaries """ if label not in self.local_exp: raise KeyError('Label not in explanation') if positive_only & negative_only: raise ValueError("Positive_only and negative_only cannot be true at the same time.") segments = self.segments image = self.image exp = self.local_exp[label] mask = np.zeros(segments.shape, segments.dtype) if hide_rest: temp = np.zeros(self.image.shape) else: temp = self.image.copy() if positive_only: fs = [x[0] for x in exp if x[1] > 0 and x[1] > min_weight][:num_features] if negative_only: fs = [x[0] for x in exp if x[1] < 0 and abs(x[1]) > min_weight][:num_features] if positive_only or negative_only: for f in fs: temp[segments == f] = image[segments == f].copy() mask[segments == f] = 1 return temp, mask else: for f, w in exp[:num_features]: if np.abs(w) < min_weight: continue c = 0 if w < 0 else 1 mask[segments == f] = -1 if w < 0 else 1 temp[segments == f] = image[segments == f].copy() temp[segments == f, c] = np.max(image) return temp, mask class LimeImageExplainer(object): """Explains predictions on Image (i.e. matrix) data. For numerical features, perturb them by sampling from a Normal(0,1) and doing the inverse operation of mean-centering and scaling, according to the means and stds in the training data. For categorical features, perturb by sampling according to the training distribution, and making a binary feature that is 1 when the value is the same as the instance being explained.""" def __init__(self, kernel_width=.25, kernel=None, verbose=False, feature_selection='auto', random_state=None): """Init function. Args: kernel_width: kernel width for the exponential kernel. If None, defaults to sqrt(number of columns) * 0.75. kernel: similarity kernel that takes euclidean distances and kernel width as input and outputs weights in (0,1). If None, defaults to an exponential kernel. verbose: if true, print local prediction values from linear model feature_selection: feature selection method. can be 'forward_selection', 'lasso_path', 'none' or 'auto'. See function 'explain_instance_with_data' in lime_base.py for details on what each of the options does. random_state: an integer or numpy.RandomState that will be used to generate random numbers. If None, the random state will be initialized using the internal numpy seed. """ kernel_width = float(kernel_width) if kernel is None: def kernel(d, kernel_width): return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) kernel_fn = partial(kernel, kernel_width=kernel_width) self.random_state = check_random_state(random_state) self.feature_selection = feature_selection self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state) ### Custom function to acquire segmentation only, same as in the explain_instance() function def acquireSegmOnly(self, img): random_seed = self.random_state.randint(0, high=1000) segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, max_dist=200, ratio=0.2, random_seed=random_seed) segments = segmentation_fn(img) return segments def explain_instance(self, image, inputImg, classifier_fn, labels=(1,), hide_color=None, top_labels=5, num_features=100000, num_samples=1000, batch_size=10, segmentation_fn=None, distance_metric='cosine', model_regressor=None, random_seed=None, squaredSegm=None, loadedSegmData=None): """Generates explanations for a prediction. First, we generate neighborhood data by randomly perturbing features from the instance (see __data_inverse). We then learn locally weighted linear models on this neighborhood data to explain each of the classes in an interpretable way (see lime_base.py). Args: image: 3 dimension RGB image. If this is only two dimensional, we will assume it's a grayscale image and call gray2rgb. classifier_fn: classifier prediction probability function, which takes a numpy array and outputs prediction probabilities. For ScikitClassifiers , this is classifier.predict_proba. labels: iterable with labels to be explained. hide_color: TODO top_labels: if not None, ignore labels and produce explanations for the K labels with highest prediction probabilities, where K is this parameter. num_features: maximum number of features present in explanation num_samples: size of the neighborhood to learn the linear model batch_size: TODO distance_metric: the distance metric to use for weights. model_regressor: sklearn regressor to use in explanation. Defaults to Ridge regression in LimeBase. Must have model_regressor.coef_ and 'sample_weight' as a parameter to model_regressor.fit() segmentation_fn: SegmentationAlgorithm, wrapped skimage segmentation function random_seed: integer used as random seed for the segmentation algorithm. If None, a random integer, between 0 and 1000, will be generated using the internal random number generator. squaredSegm: integer or None (default): Returns: An ImageExplanation object (see lime_image.py) with the corresponding explanations. """ if len(image.shape) == 2: image = gray2rgb(image) if random_seed is None: random_seed = self.random_state.randint(0, high=1000) if squaredSegm == 4: segments = np.zeros((image.shape[0], image.shape[1]), dtype=np.int64) imgW = image.shape[1] halfW1 = 1*imgW//4 halfW2 = 2*imgW//4 halfW3 = 3*imgW//4 segments[:,0:halfW1] = 0 segments[:,halfW1:halfW2] = 1 segments[:,halfW2:halfW3] = 2 segments[:,halfW3:imgW] = 3 elif squaredSegm == -2: ### Use to load custom resized segm data segments = loadedSegmData else: if segmentation_fn is None: segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, max_dist=200, ratio=0.2, random_seed=random_seed) try: segments = segmentation_fn(image) except ValueError as e: raise e fudged_image = image.copy() if hide_color is None: for x in np.unique(segments): fudged_image[segments == x] = ( np.mean(image[segments == x][:, 0]), np.mean(image[segments == x][:, 1]), np.mean(image[segments == x][:, 2])) else: fudged_image[:] = hide_color top = labels data, labels = self.data_labels(image, inputImg, fudged_image, segments, classifier_fn, num_samples, batch_size=batch_size) distances = sklearn.metrics.pairwise_distances( data, data[0].reshape(1, -1), metric=distance_metric ).ravel() ret_exp = ImageExplanation(image, segments) if top_labels: top = np.argsort(labels[0])[-top_labels:] ret_exp.top_labels = list(top) ret_exp.top_labels.reverse() for label in top: (ret_exp.intercept[label], ret_exp.local_exp[label], ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data( data, labels, distances, label, num_features, model_regressor=model_regressor, feature_selection=self.feature_selection) return ret_exp def data_labels(self, image, inputImg, fudged_image, segments, classifier_fn, num_samples, batch_size=10): """Generates images and predictions in the neighborhood of this image. Args: image: 3d numpy array, the image fudged_image: 3d numpy array, image to replace original image when superpixel is turned off segments: segmentation of the image classifier_fn: function that takes a list of images and returns a matrix of prediction probabilities num_samples: size of the neighborhood to learn the linear model batch_size: classifier_fn will be called on batches of this size. Returns: A tuple (data, labels), where: data: dense num_samples * num_superpixels labels: prediction probabilities matrix """ n_features = np.unique(segments).shape[0] data = self.random_state.randint(0, 2, num_samples * n_features)\ .reshape((num_samples, n_features)) labels = [] data[0, :] = 1 imgs = [] # print("data new shape: ", data.shape) # assert(False) # for row in tqdm(data): for row in data: temp = copy.deepcopy(image) zeros = np.where(row == 0)[0] mask = np.zeros(segments.shape).astype(bool) for z in zeros: mask[segments == z] = True temp[mask] = fudged_image[mask] imgs.append(temp) if len(imgs) == batch_size: preds = classifier_fn(inputImg) preds = preds.cpu().detach().numpy() labels.extend(preds) imgs = [] if len(imgs) > 0: preds = classifier_fn(inputImg) preds = preds.cpu().detach().numpy() labels.extend(preds) return data, np.array(labels)