"""Contains classes and methods related to interpretation for components in Gradio.""" | |
from __future__ import annotations | |
import copy | |
import math | |
from abc import ABC, abstractmethod | |
from typing import TYPE_CHECKING, Any | |
import numpy as np | |
from gradio_client import utils as client_utils | |
from gradio import components | |
if TYPE_CHECKING: # Only import for type checking (is False at runtime). | |
from gradio import Interface | |


class Interpretable(ABC):  # noqa: B024
    def __init__(self) -> None:
        self.set_interpret_parameters()

    def set_interpret_parameters(self):  # noqa: B027
        """
        Set any parameters for interpretation. Properties can be set here to be
        used in get_interpretation_neighbors and get_interpretation_scores.
        """
        pass

    def get_interpretation_scores(
        self, x: Any, neighbors: list[Any] | None, scores: list[float], **kwargs
    ) -> list:
        """
        Arrange the output values from the neighbors into interpretation scores for the interface to render.
        Parameters:
            x: Input to interface
            neighbors: Neighboring values to input x used for interpretation.
            scores: Output value corresponding to each neighbor in neighbors
        Returns:
            Arrangement of interpretation scores for interfaces to render.
        """
        return scores


class TokenInterpretable(Interpretable, ABC):
    @abstractmethod
    def tokenize(self, x: Any) -> tuple[list, list, None]:
        """
        Interprets an input data point x by splitting it into a list of tokens (e.g.
        a string into words or an image into super-pixels).
        """
        return [], [], None

    @abstractmethod
    def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
        return []
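

# --- Illustrative sketch, not part of the original module ---------------------
# The hypothetical class below only demonstrates the contract that
# TokenInterpretable subclasses are expected to satisfy: tokenize() returns
# (tokens, leave-one-out neighbors, masks), with masks left as None here, and
# get_masked_inputs() rebuilds one input per binary mask row. The class name and
# the whitespace tokenization are assumptions made purely for illustration.
class _ExampleWordTokens(TokenInterpretable):
    def tokenize(self, x: str) -> tuple[list, list, None]:
        tokens = x.split()
        # Leave-one-out neighbors: the input with each token removed in turn.
        neighbors = [
            " ".join(tokens[:i] + tokens[i + 1 :]) for i in range(len(tokens))
        ]
        return tokens, neighbors, None

    def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
        # One masked string per mask row, keeping only the tokens marked truthy.
        return [
            " ".join(token for token, keep in zip(tokens, mask) if keep)
            for mask in binary_mask_matrix
        ]
# -------------------------------------------------------------------------------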


class NeighborInterpretable(Interpretable, ABC):
    @abstractmethod
    def get_interpretation_neighbors(self, x: Any) -> tuple[list, dict]:
        """
        Generates values similar to the input, to be used to interpret the significance of the input in the final output.
        Parameters:
            x: Input to interface
        Returns: (neighbor_values, interpret_kwargs)
            neighbor_values: Neighboring values to input x to compute for interpretation
            interpret_kwargs: Keyword arguments to be passed to get_interpretation_scores
        """
        return [], {}
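

# --- Illustrative sketch, not part of the original module ---------------------
# A hypothetical numeric component interpreted by probing nearby values. The
# class name and the +/-10% perturbation are made-up assumptions; the point is
# the (neighbor_values, interpret_kwargs) pair consumed by run_interpret below.
class _ExampleNumberNeighbors(NeighborInterpretable):
    def get_interpretation_neighbors(self, x: float) -> tuple[list, dict]:
        delta = 0.1 * (abs(x) or 1.0)
        # Probe one value below and one above x; no extra keyword arguments are
        # needed by get_interpretation_scores in this sketch.
        return [x - delta, x + delta], {}
# -------------------------------------------------------------------------------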


async def run_interpret(interface: Interface, raw_input: list):
    """
    Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
    interpretation for a certain set of UI component types, as well as the custom interpretation case.
    Parameters:
        raw_input: a list of raw inputs to apply the interpretation(s) on.
    """
    if isinstance(interface.interpretation, list):  # Either "default" or "shap"
        processed_input = [
            input_component.preprocess(raw_input[i])
            for i, input_component in enumerate(interface.input_components)
        ]
        original_output = await interface.call_function(0, processed_input)
        original_output = original_output["prediction"]

        if len(interface.output_components) == 1:
            original_output = [original_output]

        scores, alternative_outputs = [], []
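
        # For each input component, run the interpretation requested for it:
        # "default" re-runs the interface once per neighbor/token perturbation and
        # scores each one by how much the first output changes; "shap" instead fits
        # a KernelExplainer over binary token masks (see below).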
        for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
            if interp == "default":
                input_component = interface.input_components[i]
                neighbor_raw_input = list(raw_input)
                if isinstance(input_component, TokenInterpretable):
                    tokens, neighbor_values, masks = input_component.tokenize(x)
                    interface_scores = []
                    alternative_output = []
                    for neighbor_input in neighbor_values:
                        neighbor_raw_input[i] = neighbor_input
                        processed_neighbor_input = [
                            input_component.preprocess(neighbor_raw_input[i])
                            for i, input_component in enumerate(
                                interface.input_components
                            )
                        ]
                        neighbor_output = await interface.call_function(
                            0, processed_neighbor_input
                        )
                        neighbor_output = neighbor_output["prediction"]
                        if len(interface.output_components) == 1:
                            neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            output_component.postprocess(neighbor_output[i])
                            for i, output_component in enumerate(
                                interface.output_components
                            )
                        ]

                        alternative_output.append(processed_neighbor_output)
                        interface_scores.append(
                            quantify_difference_in_label(
                                interface, original_output, neighbor_output
                            )
                        )
                    alternative_outputs.append(alternative_output)
                    scores.append(
                        input_component.get_interpretation_scores(
                            raw_input[i],
                            neighbor_values,
                            interface_scores,
                            masks=masks,
                            tokens=tokens,
                        )
                    )
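                # Neighbor-based components reuse the same re-run loop, but the
                # perturbations come from get_interpretation_neighbors(), and the
                # resulting scores are negated below (the rendering sign convention
                # is inverted relative to the token/removal case above).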
                elif isinstance(input_component, NeighborInterpretable):
                    (
                        neighbor_values,
                        interpret_kwargs,
                    ) = input_component.get_interpretation_neighbors(
                        x
                    )  # type: ignore
                    interface_scores = []
                    alternative_output = []
                    for neighbor_input in neighbor_values:
                        neighbor_raw_input[i] = neighbor_input
                        processed_neighbor_input = [
                            input_component.preprocess(neighbor_raw_input[i])
                            for i, input_component in enumerate(
                                interface.input_components
                            )
                        ]
                        neighbor_output = await interface.call_function(
                            0, processed_neighbor_input
                        )
                        neighbor_output = neighbor_output["prediction"]
                        if len(interface.output_components) == 1:
                            neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            output_component.postprocess(neighbor_output[i])
                            for i, output_component in enumerate(
                                interface.output_components
                            )
                        ]
                        alternative_output.append(processed_neighbor_output)
                        interface_scores.append(
                            quantify_difference_in_label(
                                interface, original_output, neighbor_output
                            )
                        )
                    alternative_outputs.append(alternative_output)
                    interface_scores = [-score for score in interface_scores]
                    scores.append(
                        input_component.get_interpretation_scores(
                            raw_input[i],
                            neighbor_values,
                            interface_scores,
                            **interpret_kwargs,
                        )
                    )
                else:
                    raise ValueError(
                        f"Component {input_component} does not support interpretation"
                    )
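            # SHAP-based interpretation: treat each token/segment of the input as a
            # binary feature and let shap.KernelExplainer estimate per-token
            # attributions by re-running the interface on masked variants.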
elif interp == "shap" or interp == "shapley": | |
try: | |
import shap # type: ignore | |
except (ImportError, ModuleNotFoundError) as err: | |
raise ValueError( | |
"The package `shap` is required for this interpretation method. Try: `pip install shap`" | |
) from err | |
input_component = interface.input_components[i] | |
if not isinstance(input_component, TokenInterpretable): | |
raise ValueError( | |
f"Input component {input_component} does not support `shap` interpretation" | |
) | |
                tokens, _, masks = input_component.tokenize(x)

                # construct a masked version of the input
                def get_masked_prediction(binary_mask):
                    assert isinstance(input_component, TokenInterpretable)
                    masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
                    preds = []
                    for masked_x in masked_xs:
                        processed_masked_input = copy.deepcopy(processed_input)
                        processed_masked_input[i] = input_component.preprocess(masked_x)
                        new_output = client_utils.synchronize_async(
                            interface.call_function, 0, processed_masked_input
                        )
                        new_output = new_output["prediction"]
                        if len(interface.output_components) == 1:
                            new_output = [new_output]
                        pred = get_regression_or_classification_value(
                            interface, original_output, new_output
                        )
                        preds.append(pred)
                    return np.array(preds)
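
                # A background of all zeros means "every token masked out"; the
                # explained point of all ones is the original input. The number of
                # model evaluations scales with the token count via interface.num_shap.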
                num_total_segments = len(tokens)
                explainer = shap.KernelExplainer(
                    get_masked_prediction, np.zeros((1, num_total_segments))
                )
                shap_values = explainer.shap_values(
                    np.ones((1, num_total_segments)),
                    nsamples=int(interface.num_shap * num_total_segments),
                    silent=True,
                )
                assert shap_values is not None, "SHAP values could not be calculated"
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i],
                        None,
                        shap_values[0].tolist(),
                        masks=masks,
                        tokens=tokens,
                    )
                )
                alternative_outputs.append([])
            elif interp is None:
                scores.append(None)
                alternative_outputs.append([])
            else:
                raise ValueError(f"Unknown interpretation method: {interp}")
        return scores, alternative_outputs
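    # Custom interpretation: the user-supplied function receives the preprocessed
    # inputs and must return one interpretation per input component (a single
    # return value is wrapped in a list below).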
    elif interface.interpretation:  # custom interpretation function
        processed_input = [
            input_component.preprocess(raw_input[i])
            for i, input_component in enumerate(interface.input_components)
        ]
        interpreter = interface.interpretation
        interpretation = interpreter(*processed_input)
        if len(raw_input) == 1:
            interpretation = [interpretation]
        return interpretation, []
    else:
        raise ValueError("No interpretation method specified.")
def diff(original: Any, perturbed: Any) -> int | float:
    try:  # try computing numerical difference
        score = float(original) - float(perturbed)
    except ValueError:  # otherwise, look at strict difference in label
        score = int(original != perturbed)
    return score


def quantify_difference_in_label(
    interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
    output_component = interface.output_components[0]
    post_original_output = output_component.postprocess(original_output[0])
    post_perturbed_output = output_component.postprocess(perturbed_output[0])
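
    # For Label outputs with confidences, the score is how much the confidence in
    # the originally predicted label drops under the perturbation; for plain labels
    # it falls back to diff(); for Number outputs it is the numeric difference of
    # the postprocessed values.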
    if isinstance(output_component, components.Label):
        original_label = post_original_output["label"]
        perturbed_label = post_perturbed_output["label"]

        # Handle different return types of Label interface
        if "confidences" in post_original_output:
            original_confidence = original_output[0][original_label]
            perturbed_confidence = perturbed_output[0][original_label]
            score = original_confidence - perturbed_confidence
        else:
            score = diff(original_label, perturbed_label)
        return score

    elif isinstance(output_component, components.Number):
        score = diff(post_original_output, post_perturbed_output)
        return score

    else:
        raise ValueError(
            f"This interpretation method doesn't support the Output component: {output_component}"
        )


def get_regression_or_classification_value(
    interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
    """Used to combine regression/classification for Shap interpretation method."""
    output_component = interface.output_components[0]
    post_original_output = output_component.postprocess(original_output[0])
    post_perturbed_output = output_component.postprocess(perturbed_output[0])
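
    # Unlike quantify_difference_in_label, this returns the perturbed confidence of
    # the originally predicted label itself rather than a difference; SHAP's
    # KernelExplainer accounts for the baseline via its background samples.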
    if isinstance(output_component, components.Label):
        original_label = post_original_output["label"]
        perturbed_label = post_perturbed_output["label"]

        # Handle different return types of Label interface
        if "confidences" in post_original_output:
            if math.isnan(perturbed_output[0][original_label]):
                return 0
            return perturbed_output[0][original_label]
        else:
            score = diff(
                perturbed_label, original_label
            )  # Intentionally inverted order of arguments.
            return score
    else:
        raise ValueError(
            f"This interpretation method doesn't support the Output component: {output_component}"
        )