Spaces:
Build error
Build error
File size: 11,885 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 |
"""
Explanation class, with visualization functions.
"""
from io import open
import os
import os.path
import json
import string
import numpy as np
from .exceptions import LimeError
from sklearn.utils import check_random_state
def id_generator(size=15, random_state=None):
    """Helper function to generate random div ids. This is useful for embedding
    HTML into ipython notebooks.

    Args:
        size: number of characters in the generated id.
        random_state: a numpy ``RandomState`` (anything exposing
            ``choice(a, size, replace=...)``), an integer seed, or None.
            None draws from numpy's global random state.

    Returns:
        A string of ``size`` characters drawn with replacement from the
        uppercase ASCII letters and the digits.
    """
    if random_state is None:
        # None is the declared default but previously crashed with
        # AttributeError (None.choice); use numpy's global random state,
        # which is also what sklearn's check_random_state(None) yields.
        random_state = np.random
    elif isinstance(random_state, (int, np.integer)):
        # Accept a plain integer seed for reproducible ids.
        random_state = np.random.RandomState(random_state)
    chars = list(string.ascii_uppercase + string.digits)
    return ''.join(random_state.choice(chars, size, replace=True))
class DomainMapper(object):
    """Base class that maps explanation features onto a concrete domain.

    Each data domain (text, tables, images, ...) is meant to provide its own
    subclass. That keeps the generic Explanation class domain-agnostic, while
    domain-specific naming and visualization of features lives here.
    """

    def __init__(self):
        """Create a mapper. The base class holds no state."""

    def map_exp_ids(self, exp, **kwargs):
        """Translate feature ids into concrete names.

        The base implementation is the identity function: ``exp`` is handed
        back unchanged. Subclasses may override this as they see fit.

        Args:
            exp: list of ``(id, weight)`` tuples.
            kwargs: optional keyword arguments (unused here).

        Returns:
            list of ``(name, weight)`` tuples; here, the very ``exp`` object
            that was passed in.
        """
        return exp

    def visualize_instance_html(self, exp, label, div_name, exp_object_name,
                                **kwargs):
        """Produce HTML/JS that renders the explained instance.

        The base implementation renders nothing. Subclasses may override
        this as they see fit.

        Args:
            exp: list of ``(id, weight)`` tuples.
            label: label id (integer).
            div_name: name of the div object used for rendering (in js).
            exp_object_name: name of the js explanation object.
            kwargs: optional keyword arguments (unused here).

        Returns:
            js code for visualizing the instance; always the empty string
            in this base class.
        """
        return ''
class Explanation(object):
    """Object returned by explainers."""

    def __init__(self,
                 domain_mapper,
                 mode='classification',
                 class_names=None,
                 random_state=None):
        """
        Initializer.

        Args:
            domain_mapper: must inherit from DomainMapper class
            mode: "classification" or "regression"
            class_names: list of class names (only used for classification)
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.

        Raises:
            LimeError: if ``mode`` is neither "classification" nor
                "regression".
        """
        self.random_state = random_state
        self.mode = mode
        self.domain_mapper = domain_mapper
        # label -> list of (feature_id, weight); populated outside this class
        # (presumably by the explainer that returns this object).
        self.local_exp = {}
        self.intercept = {}
        self.score = None
        self.local_pred = None
        if mode == 'classification':
            self.class_names = class_names
            self.top_labels = None
            self.predict_proba = None
        elif mode == 'regression':
            self.class_names = ['negative', 'positive']
            self.predicted_value = None
            self.min_value = 0.0
            self.max_value = 1.0
            # Regression explanations are looked up under this single key of
            # local_exp; any label passed by the caller is ignored.
            self.dummy_label = 1
        else:
            raise LimeError('Invalid explanation mode "{}". '
                            'Should be either "classification" '
                            'or "regression".'.format(mode))

    def available_labels(self):
        """
        Returns the list of classification labels for which we have any explanations.

        Returns ``top_labels`` when it is set and non-empty, otherwise the
        keys of ``local_exp``.

        Raises:
            NotImplementedError: for regression explanations.
        """
        # Explicit check rather than ``assert``: assert statements are
        # stripped under ``python -O``, which previously turned this into an
        # AttributeError (regression objects have no ``top_labels``).
        if self.mode != "classification":
            raise NotImplementedError('Not supported for regression explanations.')
        ans = self.top_labels if self.top_labels else self.local_exp.keys()
        return list(ans)

    def as_list(self, label=1, **kwargs):
        """Returns the explanation as a list.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            list of tuples (representation, weight), where representation is
            given by domain_mapper. Weight is a float.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
        # Coerce weights (possibly numpy scalars) to plain Python floats.
        ans = [(x[0], float(x[1])) for x in ans]
        return ans

    def as_map(self):
        """Returns the map of explanations.

        Returns:
            Map from label to list of tuples (feature_id, weight).
        """
        return self.local_exp

    def as_pyplot_figure(self, label=1, **kwargs):
        """Returns the explanation as a pyplot figure.

        Will throw an error if you don't have matplotlib installed

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            pyplot figure (barchart).
        """
        import matplotlib.pyplot as plt
        exp = self.as_list(label=label, **kwargs)
        fig = plt.figure()
        vals = [x[1] for x in exp]
        names = [x[0] for x in exp]
        # Reverse so the first feature of the list ends up at the top of the
        # horizontal bar chart.
        vals.reverse()
        names.reverse()
        colors = ['green' if x > 0 else 'red' for x in vals]
        pos = np.arange(len(exp)) + .5
        plt.barh(pos, vals, align='center', color=colors)
        plt.yticks(pos, names)
        if self.mode == "classification":
            title = 'Local explanation for class %s' % self.class_names[label]
        else:
            title = 'Local explanation'
        plt.title(title)
        return fig

    def show_in_notebook(self,
                         labels=None,
                         predict_proba=True,
                         show_predicted_value=True,
                         **kwargs):
        """Shows html explanation in ipython notebook.

        See as_html() for parameters.
        This will throw an error if you don't have IPython installed"""
        # IPython.display is the public location of these helpers; importing
        # ``display`` from IPython.core.display is deprecated.
        from IPython.display import display, HTML
        display(HTML(self.as_html(labels=labels,
                                  predict_proba=predict_proba,
                                  show_predicted_value=show_predicted_value,
                                  **kwargs)))

    def save_to_file(self,
                     file_path,
                     labels=None,
                     predict_proba=True,
                     show_predicted_value=True,
                     **kwargs):
        """Saves html explanation to file.

        Params:
            file_path: file to save explanations to

        See as_html() for additional parameters.
        """
        # Render first, then write inside ``with`` so the handle is always
        # closed, even if writing fails.
        html = self.as_html(labels=labels,
                            predict_proba=predict_proba,
                            show_predicted_value=show_predicted_value,
                            **kwargs)
        with open(file_path, 'w', encoding='utf8') as file_:
            file_.write(html)

    def as_html(self,
                labels=None,
                predict_proba=True,
                show_predicted_value=True,
                **kwargs):
        """Returns the explanation as an html page.

        Args:
            labels: desired labels to show explanations for (as barcharts).
                If you ask for a label for which an explanation wasn't
                computed, will throw an exception. If None, will show
                explanations for all available labels. (only used for classification)
            predict_proba: if true, add barchart with prediction probabilities
                for the top classes. (only used for classification)
            show_predicted_value: if true, add barchart with expected value
                (only used for regression)
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            code for an html page, including javascript includes.
        """
        def jsonize(x):
            return json.dumps(x, ensure_ascii=False)

        if self.mode == "classification":
            if labels is None:
                labels = self.available_labels()
            else:
                # Materialize so a one-shot iterable survives both the loop
                # below and the labels[0] lookup that follows it.
                labels = list(labels)

        # The JS bundle ships next to this module; close it deterministically.
        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, 'bundle.js'),
                  encoding="utf8") as bundle_file:
            bundle = bundle_file.read()

        out = u'''<html>
<meta http-equiv="content-type" content="text/html; charset=UTF8">
<head><script>%s </script></head><body>''' % bundle
        # Random div id so several explanations can coexist in one notebook.
        random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
        out += u'''
<div class="lime top_div" id="top_div%s"></div>
''' % random_id

        predict_proba_js = ''
        if self.mode == "classification" and predict_proba:
            predict_proba_js = u'''
var pp_div = top_div.append('div')
.classed('lime predict_proba', true);
var pp_svg = pp_div.append('svg').style('width', '100%%');
var pp = new lime.PredictProba(pp_svg, %s, %s);
''' % (jsonize([str(x) for x in self.class_names]),
       jsonize(list(self.predict_proba.astype(float))))

        predict_value_js = ''
        if self.mode == "regression" and show_predicted_value:
            # reference self.predicted_value
            # (svg, predicted_value, min_value, max_value)
            predict_value_js = u'''
var pp_div = top_div.append('div')
.classed('lime predicted_value', true);
var pp_svg = pp_div.append('svg').style('width', '100%%');
var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
''' % (jsonize(float(self.predicted_value)),
       jsonize(float(self.min_value)),
       jsonize(float(self.max_value)))

        exp_js = '''var exp_div;
var exp = new lime.Explanation(%s);
''' % (jsonize([str(x) for x in self.class_names]))
        if self.mode == "classification":
            for label in labels:
                exp_json = jsonize(self.as_list(label))
                exp_js += u'''
exp_div = top_div.append('div').classed('lime explanation', true);
exp.show(%s, %d, exp_div);
''' % (exp_json, label)
        else:
            exp_json = jsonize(self.as_list())
            exp_js += u'''
exp_div = top_div.append('div').classed('lime explanation', true);
exp.show(%s, %s, exp_div);
''' % (exp_json, self.dummy_label)

        # The raw instance is rendered against the first requested label
        # (classification) or the single dummy label (regression).
        if self.mode == "classification":
            instance_label = labels[0]
        else:
            instance_label = self.dummy_label
        raw_js = '''var raw_div = top_div.append('div');'''
        raw_js += self.domain_mapper.visualize_instance_html(
            self.local_exp[instance_label],
            instance_label,
            'raw_div',
            'exp',
            **kwargs)

        out += u'''
<script>
var top_div = d3.select('#top_div%s').classed('lime top_div', true);
%s
%s
%s
%s
</script>
''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js)
        out += u'</body></html>'
        return out
|