strexp / lime /explanation.py
markytools's picture
added strexp
d61b9c7
"""
Explanation class, with visualization functions.
"""
from io import open
import os
import os.path
import json
import string
import numpy as np
from .exceptions import LimeError
from sklearn.utils import check_random_state
def id_generator(size=15, random_state=None):
    """Helper function to generate random div ids. This is useful for embedding
    HTML into ipython notebooks.

    Args:
        size: number of characters in the generated id.
        random_state: None, an integer seed, or an object exposing a
            numpy-style ``choice(a, size, replace)`` method such as
            ``numpy.random.RandomState``. None uses numpy's global RNG.

    Returns:
        A string of ``size`` characters drawn from A-Z and 0-9.
    """
    chars = list(string.ascii_uppercase + string.digits)
    if random_state is None:
        # None is the declared default, so it must work: use numpy's global
        # RNG (the same source sklearn's check_random_state(None) returns).
        random_state = np.random
    elif isinstance(random_state, (int, np.integer)):
        # Treat integers as seeds so ids are reproducible.
        random_state = np.random.RandomState(random_state)
    return ''.join(random_state.choice(chars, size, replace=True))
class DomainMapper(object):
    """Base class that translates feature ids into domain-specific output.

    Each domain (text, tables, images, ...) is meant to provide its own
    subclass. This keeps the generic Explanation class free of any
    domain-specific naming or rendering logic.
    """

    def __init__(self):
        """Create a mapper; the base implementation holds no state."""

    def map_exp_ids(self, exp, **kwargs):
        """Translate feature ids into human-readable names.

        The base implementation is the identity: ``exp`` is handed back
        unchanged. Subclasses override this to produce domain names.

        Args:
            exp: list of ``(id, weight)`` tuples.
            kwargs: optional keyword arguments (ignored here).

        Returns:
            list of ``(name, weight)`` tuples; here, ``exp`` itself.
        """
        mapped = exp
        return mapped

    def visualize_instance_html(self, exp, label, div_name, exp_object_name,
                                **kwargs):
        """Build the JS snippet that renders the explained instance.

        The base implementation renders nothing. Subclasses override this
        to emit domain-specific visualisation code.

        Args:
            exp: list of ``(id, weight)`` tuples.
            label: label id (integer).
            div_name: name of the js div object to render into.
            exp_object_name: name of the js explanation object.
            kwargs: optional keyword arguments (ignored here).

        Returns:
            js code as a string; here, always the empty string.
        """
        no_markup = ''
        return no_markup
class Explanation(object):
    """Object returned by explainers."""

    def __init__(self,
                 domain_mapper,
                 mode='classification',
                 class_names=None,
                 random_state=None):
        """
        Initializer.

        Args:
            domain_mapper: must inherit from DomainMapper class
            mode: "classification" or "regression"
            class_names: list of class names (only used for classification)
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.

        Raises:
            LimeError: if mode is neither "classification" nor "regression".
        """
        self.random_state = random_state
        self.mode = mode
        self.domain_mapper = domain_mapper
        # Map from label to list of (feature_id, weight) tuples (see as_map).
        self.local_exp = {}
        self.intercept = {}
        self.score = None
        self.local_pred = None
        if mode == 'classification':
            self.class_names = class_names
            self.top_labels = None
            self.predict_proba = None
        elif mode == 'regression':
            self.class_names = ['negative', 'positive']
            self.predicted_value = None
            self.min_value = 0.0
            self.max_value = 1.0
            # Regression explanations live under this single key of local_exp.
            self.dummy_label = 1
        else:
            raise LimeError('Invalid explanation mode "{}". '
                            'Should be either "classification" '
                            'or "regression".'.format(mode))

    def available_labels(self):
        """
        Returns the list of classification labels for which we have any explanations.

        Uses ``top_labels`` when it is set and non-empty; otherwise uses the
        keys of ``local_exp``.

        Raises:
            NotImplementedError: for regression explanations.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let regression explanations fall through
        # to an AttributeError on ``top_labels``.
        if self.mode != "classification":
            raise NotImplementedError('Not supported for regression explanations.')
        ans = self.top_labels if self.top_labels else self.local_exp.keys()
        return list(ans)

    def as_list(self, label=1, **kwargs):
        """Returns the explanation as a list.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            list of tuples (representation, weight), where representation is
            given by domain_mapper. Weight is a float.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
        # Coerce weights (possibly numpy scalars) to plain Python floats.
        ans = [(x[0], float(x[1])) for x in ans]
        return ans

    def as_map(self):
        """Returns the map of explanations.

        Returns:
            Map from label to list of tuples (feature_id, weight).
        """
        return self.local_exp

    def as_pyplot_figure(self, label=1, **kwargs):
        """Returns the explanation as a pyplot figure.

        Will throw an error if you don't have matplotlib installed.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            pyplot figure (barchart).
        """
        import matplotlib.pyplot as plt
        exp = self.as_list(label=label, **kwargs)
        fig = plt.figure()
        vals = [x[1] for x in exp]
        names = [x[0] for x in exp]
        # Reverse so the first feature in the list is drawn at the top.
        vals.reverse()
        names.reverse()
        colors = ['green' if x > 0 else 'red' for x in vals]
        pos = np.arange(len(exp)) + .5
        plt.barh(pos, vals, align='center', color=colors)
        plt.yticks(pos, names)
        if self.mode == "classification":
            title = 'Local explanation for class %s' % self.class_names[label]
        else:
            title = 'Local explanation'
        plt.title(title)
        return fig

    def show_in_notebook(self,
                         labels=None,
                         predict_proba=True,
                         show_predicted_value=True,
                         **kwargs):
        """Shows html explanation in ipython notebook.

        See as_html() for parameters.
        This will throw an error if you don't have IPython installed.
        """
        # IPython.display is the public location; importing display from
        # IPython.core.display has been deprecated since IPython 7.14.
        from IPython.display import display, HTML
        display(HTML(self.as_html(labels=labels,
                                  predict_proba=predict_proba,
                                  show_predicted_value=show_predicted_value,
                                  **kwargs)))

    def save_to_file(self,
                     file_path,
                     labels=None,
                     predict_proba=True,
                     show_predicted_value=True,
                     **kwargs):
        """Saves html explanation to file.

        Params:
            file_path: file to save explanations to

        See as_html() for additional parameters.
        """
        # Build the page first so a failure in as_html() leaves no open or
        # truncated file behind; ``with`` guarantees the handle is closed.
        html = self.as_html(labels=labels,
                            predict_proba=predict_proba,
                            show_predicted_value=show_predicted_value,
                            **kwargs)
        with open(file_path, 'w', encoding='utf8') as file_:
            file_.write(html)

    def as_html(self,
                labels=None,
                predict_proba=True,
                show_predicted_value=True,
                **kwargs):
        """Returns the explanation as an html page.

        Args:
            labels: desired labels to show explanations for (as barcharts).
                If you ask for a label for which an explanation wasn't
                computed, will throw an exception. If None, will show
                explanations for all available labels. (only used for classification)
            predict_proba: if true, add barchart with prediction probabilities
                for the top classes. (only used for classification)
            show_predicted_value: if true, add barchart with expected value
                (only used for regression)
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            code for an html page, including javascript includes.
        """
        def jsonize(x):
            return json.dumps(x, ensure_ascii=False)

        if labels is None and self.mode == "classification":
            labels = self.available_labels()

        # bundle.js is read from this module's directory and inlined.
        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, 'bundle.js'),
                  encoding="utf8") as bundle_file:
            bundle = bundle_file.read()

        out = u'''<html>
<meta http-equiv="content-type" content="text/html; charset=UTF8">
<head><script>%s </script></head><body>''' % bundle
        random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
        out += u'''
<div class="lime top_div" id="top_div%s"></div>
''' % random_id

        predict_proba_js = ''
        if self.mode == "classification" and predict_proba:
            # '%%' survives %-formatting as a literal '%' in the JS.
            predict_proba_js = u'''
var pp_div = top_div.append('div')
.classed('lime predict_proba', true);
var pp_svg = pp_div.append('svg').style('width', '100%%');
var pp = new lime.PredictProba(pp_svg, %s, %s);
''' % (jsonize([str(x) for x in self.class_names]),
       jsonize(list(self.predict_proba.astype(float))))

        predict_value_js = ''
        if self.mode == "regression" and show_predicted_value:
            # JS signature: (svg, predicted_value, min_value, max_value)
            predict_value_js = u'''
var pp_div = top_div.append('div')
.classed('lime predicted_value', true);
var pp_svg = pp_div.append('svg').style('width', '100%%');
var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
''' % (jsonize(float(self.predicted_value)),
       jsonize(float(self.min_value)),
       jsonize(float(self.max_value)))

        exp_js = '''var exp_div;
var exp = new lime.Explanation(%s);
''' % (jsonize([str(x) for x in self.class_names]))

        if self.mode == "classification":
            for label in labels:
                exp_json = jsonize(self.as_list(label))
                exp_js += u'''
exp_div = top_div.append('div').classed('lime explanation', true);
exp.show(%s, %d, exp_div);
''' % (exp_json, label)
        else:
            exp_json = jsonize(self.as_list())
            exp_js += u'''
exp_div = top_div.append('div').classed('lime explanation', true);
exp.show(%s, %s, exp_div);
''' % (exp_json, self.dummy_label)

        # The raw instance is rendered for the first requested label
        # (classification) or the single regression key.
        raw_js = '''var raw_div = top_div.append('div');'''
        if self.mode == "classification":
            raw_label = labels[0]
        else:
            raw_label = self.dummy_label
        raw_js += self.domain_mapper.visualize_instance_html(
            self.local_exp[raw_label],
            raw_label,
            'raw_div',
            'exp',
            **kwargs)

        out += u'''
<script>
var top_div = d3.select('#top_div%s').classed('lime top_div', true);
%s
%s
%s
%s
</script>
''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js)
        out += u'</body></html>'
        return out