Spaces:
Build error
Build error
""" | |
Explanation class, with visualization functions. | |
""" | |
from io import open | |
import os | |
import os.path | |
import json | |
import string | |
import numpy as np | |
from .exceptions import LimeError | |
from sklearn.utils import check_random_state | |
def id_generator(size=15, random_state=None):
    """Generate a random identifier, e.g. for div ids in notebook HTML.

    Args:
        size: number of characters in the generated identifier.
        random_state: source of randomness. May be None (numpy's global
            random state is used), an integer seed, or any object exposing
            ``choice(a, size, replace=...)`` such as numpy.random.RandomState.

    Returns:
        A string of ``size`` characters drawn with replacement from the
        uppercase ASCII letters and the digits.
    """
    if random_state is None:
        # The declared default used to crash on ``None.choice``; fall back to
        # numpy's module-level generator so the default is actually usable.
        rng = np.random
    elif isinstance(random_state, (int, np.integer)):
        # Integer seeds give a reproducible identifier.
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state
    chars = list(string.ascii_uppercase + string.digits)
    return ''.join(rng.choice(chars, size, replace=True))
class DomainMapper(object):
    """Base class that ties explanation features to a concrete domain.

    Each domain (text, tables, images, ...) is expected to provide its own
    subclass. This keeps the Explanation class generic while the
    domain-specific naming and visualization of features lives here.
    """

    def __init__(self):
        """Create a mapper; the base class keeps no state."""

    def map_exp_ids(self, exp, **kwargs):
        """Translate feature ids into concrete, domain-specific names.

        The base implementation is the identity: the input is handed back
        untouched. Subclasses may override this as needed.

        Args:
            exp: list of tuples [(id, weight), (id, weight), ...]
            kwargs: optional keyword arguments (ignored here)

        Returns:
            list of tuples [(name, weight), (name, weight), ...]
        """
        return exp

    def visualize_instance_html(self,
                                exp,
                                label,
                                div_name,
                                exp_object_name,
                                **kwargs):
        """Build the js snippet that renders the explained instance.

        The base implementation renders nothing and returns an empty string.
        Subclasses may override this as needed.

        Args:
            exp: list of tuples [(id, weight), (id, weight), ...]
            label: label id (integer)
            div_name: name of the div object used for rendering (in js)
            exp_object_name: name of the js explanation object
            kwargs: optional keyword arguments (ignored here)

        Returns:
            js code for visualizing the instance
        """
        return ''
class Explanation(object):
    """Object returned by explainers.

    Stores the per-label local explanation (``local_exp``) together with the
    intercept, score and local prediction slots, and offers several ways of
    presenting it: as a list, a map, a pyplot figure or an html page.
    """

    def __init__(self,
                 domain_mapper,
                 mode='classification',
                 class_names=None,
                 random_state=None):
        """Initializer.

        Args:
            domain_mapper: must inherit from DomainMapper class
            mode: "classification" or "regression"
            class_names: list of class names (only used for classification)
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.

        Raises:
            LimeError: if mode is neither "classification" nor "regression".
        """
        self.random_state = random_state
        self.mode = mode
        self.domain_mapper = domain_mapper
        # Map from label to list of (feature_id, weight) tuples.
        self.local_exp = {}
        self.intercept = {}
        self.score = None
        self.local_pred = None
        if mode == 'classification':
            self.class_names = class_names
            self.top_labels = None
            self.predict_proba = None
        elif mode == 'regression':
            # Regression explanations are stored under a single dummy label
            # and always use these two fixed class names.
            self.class_names = ['negative', 'positive']
            self.predicted_value = None
            self.min_value = 0.0
            self.max_value = 1.0
            self.dummy_label = 1
        else:
            raise LimeError('Invalid explanation mode "{}". '
                            'Should be either "classification" '
                            'or "regression".'.format(mode))

    def available_labels(self):
        """Returns the list of classification labels that have explanations.

        Returns:
            ``top_labels`` when it is set and non-empty, otherwise the keys
            of the stored explanations, as a list.

        Raises:
            NotImplementedError: for regression explanations.
        """
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently skip the validation.
        if self.mode != "classification":
            raise NotImplementedError('Not supported for regression explanations.')
        ans = self.top_labels if self.top_labels else self.local_exp.keys()
        return list(ans)

    def as_list(self, label=1, **kwargs):
        """Returns the explanation as a list.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            list of tuples (representation, weight), where representation is
            given by domain_mapper. Weight is a float.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
        # Coerce weights to plain floats (as_html serializes this with json).
        return [(x[0], float(x[1])) for x in ans]

    def as_map(self):
        """Returns the map of explanations.

        Returns:
            Map from label to list of tuples (feature_id, weight).
        """
        return self.local_exp

    def as_pyplot_figure(self, label=1, **kwargs):
        """Returns the explanation as a pyplot figure.

        Will throw an error if you don't have matplotlib installed

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            pyplot figure (barchart).
        """
        import matplotlib.pyplot as plt
        exp = self.as_list(label=label, **kwargs)
        fig = plt.figure()
        # Reverse so that the first explanation entry ends up as the top bar.
        vals = [x[1] for x in exp]
        names = [x[0] for x in exp]
        vals.reverse()
        names.reverse()
        colors = ['green' if x > 0 else 'red' for x in vals]
        pos = np.arange(len(exp)) + .5
        plt.barh(pos, vals, align='center', color=colors)
        plt.yticks(pos, names)
        if self.mode == "classification":
            title = 'Local explanation for class %s' % self.class_names[label]
        else:
            title = 'Local explanation'
        plt.title(title)
        return fig

    def show_in_notebook(self,
                         labels=None,
                         predict_proba=True,
                         show_predicted_value=True,
                         **kwargs):
        """Shows html explanation in ipython notebook.

        See as_html() for parameters.
        This will throw an error if you don't have IPython installed
        """
        # ``IPython.display`` is the public location of these helpers;
        # importing ``display`` from ``IPython.core.display`` is deprecated.
        from IPython.display import display, HTML
        display(HTML(self.as_html(labels=labels,
                                  predict_proba=predict_proba,
                                  show_predicted_value=show_predicted_value,
                                  **kwargs)))

    def save_to_file(self,
                     file_path,
                     labels=None,
                     predict_proba=True,
                     show_predicted_value=True,
                     **kwargs):
        """Saves html explanation to file.

        Params:
            file_path: file to save explanations to

        See as_html() for additional parameters.
        """
        # ``with`` guarantees the handle is closed even if rendering or
        # writing raises.
        with open(file_path, 'w', encoding='utf8') as file_:
            file_.write(self.as_html(labels=labels,
                                     predict_proba=predict_proba,
                                     show_predicted_value=show_predicted_value,
                                     **kwargs))

    def as_html(self,
                labels=None,
                predict_proba=True,
                show_predicted_value=True,
                **kwargs):
        """Returns the explanation as an html page.

        Args:
            labels: desired labels to show explanations for (as barcharts).
                If you ask for a label for which an explanation wasn't
                computed, will throw an exception. If None, will show
                explanations for all available labels. (only used for classification)
            predict_proba: if true, add barchart with prediction probabilities
                for the top classes. (only used for classification)
            show_predicted_value: if true, add barchart with expected value
                (only used for regression)
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            code for an html page, including javascript includes.
        """

        def jsonize(x):
            return json.dumps(x, ensure_ascii=False)

        if labels is None and self.mode == "classification":
            labels = self.available_labels()

        # The javascript bundle is read from next to this module and inlined.
        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, 'bundle.js'),
                  encoding="utf8") as bundle_file:
            bundle = bundle_file.read()

        out = u'''<html>
        <meta http-equiv="content-type" content="text/html; charset=UTF8">
        <head><script>%s </script></head><body>''' % bundle
        # Random suffix for the top-level div id.
        random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
        out += u'''
        <div class="lime top_div" id="top_div%s"></div>
        ''' % random_id

        predict_proba_js = ''
        if self.mode == "classification" and predict_proba:
            predict_proba_js = u'''
            var pp_div = top_div.append('div')
                                .classed('lime predict_proba', true);
            var pp_svg = pp_div.append('svg').style('width', '100%%');
            var pp = new lime.PredictProba(pp_svg, %s, %s);
            ''' % (jsonize([str(x) for x in self.class_names]),
                   jsonize(list(self.predict_proba.astype(float))))

        predict_value_js = ''
        if self.mode == "regression" and show_predicted_value:
            # reference self.predicted_value
            # (svg, predicted_value, min_value, max_value)
            predict_value_js = u'''
                    var pp_div = top_div.append('div')
                                        .classed('lime predicted_value', true);
                    var pp_svg = pp_div.append('svg').style('width', '100%%');
                    var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
                    ''' % (jsonize(float(self.predicted_value)),
                           jsonize(float(self.min_value)),
                           jsonize(float(self.max_value)))

        exp_js = '''var exp_div;
            var exp = new lime.Explanation(%s);
        ''' % (jsonize([str(x) for x in self.class_names]))

        if self.mode == "classification":
            # One explanation chart per requested label.
            for label in labels:
                exp = jsonize(self.as_list(label))
                exp_js += u'''
                exp_div = top_div.append('div').classed('lime explanation', true);
                exp.show(%s, %d, exp_div);
                ''' % (exp, label)
        else:
            exp = jsonize(self.as_list())
            exp_js += u'''
            exp_div = top_div.append('div').classed('lime explanation', true);
            exp.show(%s, %s, exp_div);
            ''' % (exp, self.dummy_label)

        raw_js = '''var raw_div = top_div.append('div');'''

        # The raw instance view uses the first requested label
        # (classification) or the dummy label (regression).
        if self.mode == "classification":
            html_data = self.local_exp[labels[0]]
        else:
            html_data = self.local_exp[self.dummy_label]

        raw_js += self.domain_mapper.visualize_instance_html(
            html_data,
            labels[0] if self.mode == "classification" else self.dummy_label,
            'raw_div',
            'exp',
            **kwargs)
        out += u'''
        <script>
        var top_div = d3.select('#top_div%s').classed('lime top_div', true);
        %s
        %s
        %s
        %s
        </script>
        ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js)
        out += u'</body></html>'

        return out