""" Explanation class, with visualization functions. """ from io import open import os import os.path import json import string import numpy as np from .exceptions import LimeError from sklearn.utils import check_random_state def id_generator(size=15, random_state=None): """Helper function to generate random div ids. This is useful for embedding HTML into ipython notebooks.""" chars = list(string.ascii_uppercase + string.digits) return ''.join(random_state.choice(chars, size, replace=True)) class DomainMapper(object): """Class for mapping features to the specific domain. The idea is that there would be a subclass for each domain (text, tables, images, etc), so that we can have a general Explanation class, and separate out the specifics of visualizing features in here. """ def __init__(self): pass def map_exp_ids(self, exp, **kwargs): """Maps the feature ids to concrete names. Default behaviour is the identity function. Subclasses can implement this as they see fit. Args: exp: list of tuples [(id, weight), (id,weight)] kwargs: optional keyword arguments Returns: exp: list of tuples [(name, weight), (name, weight)...] """ return exp def visualize_instance_html(self, exp, label, div_name, exp_object_name, **kwargs): """Produces html for visualizing the instance. Default behaviour does nothing. Subclasses can implement this as they see fit. Args: exp: list of tuples [(id, weight), (id,weight)] label: label id (integer) div_name: name of div object to be used for rendering(in js) exp_object_name: name of js explanation object kwargs: optional keyword arguments Returns: js code for visualizing the instance """ return '' class Explanation(object): """Object returned by explainers.""" def __init__(self, domain_mapper, mode='classification', class_names=None, random_state=None): """ Initializer. Args: domain_mapper: must inherit from DomainMapper class type: "classification" or "regression" class_names: list of class names (only used for classification) random_state: an integer or numpy.RandomState that will be used to generate random numbers. If None, the random state will be initialized using the internal numpy seed. """ self.random_state = random_state self.mode = mode self.domain_mapper = domain_mapper self.local_exp = {} self.intercept = {} self.score = None self.local_pred = None if mode == 'classification': self.class_names = class_names self.top_labels = None self.predict_proba = None elif mode == 'regression': self.class_names = ['negative', 'positive'] self.predicted_value = None self.min_value = 0.0 self.max_value = 1.0 self.dummy_label = 1 else: raise LimeError('Invalid explanation mode "{}". ' 'Should be either "classification" ' 'or "regression".'.format(mode)) def available_labels(self): """ Returns the list of classification labels for which we have any explanations. """ try: assert self.mode == "classification" except AssertionError: raise NotImplementedError('Not supported for regression explanations.') else: ans = self.top_labels if self.top_labels else self.local_exp.keys() return list(ans) def as_list(self, label=1, **kwargs): """Returns the explanation as a list. Args: label: desired label. If you ask for a label for which an explanation wasn't computed, will throw an exception. Will be ignored for regression explanations. kwargs: keyword arguments, passed to domain_mapper Returns: list of tuples (representation, weight), where representation is given by domain_mapper. Weight is a float. """ label_to_use = label if self.mode == "classification" else self.dummy_label ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs) ans = [(x[0], float(x[1])) for x in ans] return ans def as_map(self): """Returns the map of explanations. Returns: Map from label to list of tuples (feature_id, weight). """ return self.local_exp def as_pyplot_figure(self, label=1, **kwargs): """Returns the explanation as a pyplot figure. Will throw an error if you don't have matplotlib installed Args: label: desired label. If you ask for a label for which an explanation wasn't computed, will throw an exception. Will be ignored for regression explanations. kwargs: keyword arguments, passed to domain_mapper Returns: pyplot figure (barchart). """ import matplotlib.pyplot as plt exp = self.as_list(label=label, **kwargs) fig = plt.figure() vals = [x[1] for x in exp] names = [x[0] for x in exp] vals.reverse() names.reverse() colors = ['green' if x > 0 else 'red' for x in vals] pos = np.arange(len(exp)) + .5 plt.barh(pos, vals, align='center', color=colors) plt.yticks(pos, names) if self.mode == "classification": title = 'Local explanation for class %s' % self.class_names[label] else: title = 'Local explanation' plt.title(title) return fig def show_in_notebook(self, labels=None, predict_proba=True, show_predicted_value=True, **kwargs): """Shows html explanation in ipython notebook. See as_html() for parameters. This will throw an error if you don't have IPython installed""" from IPython.core.display import display, HTML display(HTML(self.as_html(labels=labels, predict_proba=predict_proba, show_predicted_value=show_predicted_value, **kwargs))) def save_to_file(self, file_path, labels=None, predict_proba=True, show_predicted_value=True, **kwargs): """Saves html explanation to file. . Params: file_path: file to save explanations to See as_html() for additional parameters. """ file_ = open(file_path, 'w', encoding='utf8') file_.write(self.as_html(labels=labels, predict_proba=predict_proba, show_predicted_value=show_predicted_value, **kwargs)) file_.close() def as_html(self, labels=None, predict_proba=True, show_predicted_value=True, **kwargs): """Returns the explanation as an html page. Args: labels: desired labels to show explanations for (as barcharts). If you ask for a label for which an explanation wasn't computed, will throw an exception. If None, will show explanations for all available labels. (only used for classification) predict_proba: if true, add barchart with prediction probabilities for the top classes. (only used for classification) show_predicted_value: if true, add barchart with expected value (only used for regression) kwargs: keyword arguments, passed to domain_mapper Returns: code for an html page, including javascript includes. """ def jsonize(x): return json.dumps(x, ensure_ascii=False) if labels is None and self.mode == "classification": labels = self.available_labels() this_dir, _ = os.path.split(__file__) bundle = open(os.path.join(this_dir, 'bundle.js'), encoding="utf8").read() out = u''' ''' % bundle random_id = id_generator(size=15, random_state=check_random_state(self.random_state)) out += u'''
''' % random_id predict_proba_js = '' if self.mode == "classification" and predict_proba: predict_proba_js = u''' var pp_div = top_div.append('div') .classed('lime predict_proba', true); var pp_svg = pp_div.append('svg').style('width', '100%%'); var pp = new lime.PredictProba(pp_svg, %s, %s); ''' % (jsonize([str(x) for x in self.class_names]), jsonize(list(self.predict_proba.astype(float)))) predict_value_js = '' if self.mode == "regression" and show_predicted_value: # reference self.predicted_value # (svg, predicted_value, min_value, max_value) predict_value_js = u''' var pp_div = top_div.append('div') .classed('lime predicted_value', true); var pp_svg = pp_div.append('svg').style('width', '100%%'); var pp = new lime.PredictedValue(pp_svg, %s, %s, %s); ''' % (jsonize(float(self.predicted_value)), jsonize(float(self.min_value)), jsonize(float(self.max_value))) exp_js = '''var exp_div; var exp = new lime.Explanation(%s); ''' % (jsonize([str(x) for x in self.class_names])) if self.mode == "classification": for label in labels: exp = jsonize(self.as_list(label)) exp_js += u''' exp_div = top_div.append('div').classed('lime explanation', true); exp.show(%s, %d, exp_div); ''' % (exp, label) else: exp = jsonize(self.as_list()) exp_js += u''' exp_div = top_div.append('div').classed('lime explanation', true); exp.show(%s, %s, exp_div); ''' % (exp, self.dummy_label) raw_js = '''var raw_div = top_div.append('div');''' if self.mode == "classification": html_data = self.local_exp[labels[0]] else: html_data = self.local_exp[self.dummy_label] raw_js += self.domain_mapper.visualize_instance_html( html_data, labels[0] if self.mode == "classification" else self.dummy_label, 'raw_div', 'exp', **kwargs) out += u''' ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js) out += u'' return out