# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model base classes and utilities."""
from typing import Dict, List, Tuple

import chex
import numpy as np

from clrs._src import probing
from clrs._src import specs
_Array = chex.Array | |
Result = Dict[str, probing.DataPoint] | |
def fuse_perm_and_mask(perm: probing.DataPoint, | |
mask: probing.DataPoint) -> probing.DataPoint: | |
"""Replace permutation pointers active in the mask with self-pointers. | |
Args: | |
perm: a node permutation_pointer; data shape is expected to be | |
[..., N, N], and ideally one-hot over the last two dimensions, although | |
this method does not check for one-hotness. | |
mask: a mask_one over nodes; data shape is expected to be | |
[..., N], and ideally one-hot over the last dimension, although | |
this method does not check for one-hotness. | |
Returns: | |
A node pointer with shape [..., N]. | |
""" | |
assert perm.type_ == specs.Type.PERMUTATION_POINTER | |
assert perm.location == specs.Location.NODE | |
assert mask.name == perm.name + '_mask' | |
assert mask.type_ == specs.Type.MASK_ONE | |
assert mask.location == specs.Location.NODE | |
assert perm.data.shape[-1] == perm.data.shape[-2] | |
assert perm.data.shape[:-1] == mask.data.shape | |
data = np.where(mask.data > 0.5, | |
np.arange(perm.data.shape[-1]), # self-pointers | |
np.argmax(perm.data, axis=-1)) # original pointers | |
return probing.DataPoint(name=perm.name, | |
type_=specs.Type.POINTER, | |
location=perm.location, | |
data=data) | |
def _reduce_permutations_tuple( | |
targets: Tuple[probing.DataPoint, ...]) -> Tuple[probing.DataPoint, ...]: | |
"""Reduce node pointer + mask_one permutation to just node pointer.""" | |
out_targets = [] | |
n_perms = 0 | |
i = 0 | |
while i < len(targets): | |
truth = targets[i] | |
if truth.type_ != specs.Type.PERMUTATION_POINTER: | |
out_targets.append(truth) | |
i += 1 | |
continue | |
truth_mask = targets[i + 1] | |
out_targets.append(fuse_perm_and_mask(truth, truth_mask)) | |
i += 2 | |
n_perms += 1 | |
assert len(out_targets) == len(targets) - n_perms | |
return tuple(out_targets) | |
def _reduce_permutations_dict(predictions: Result) -> Result: | |
"""Reduce node pointer + mask_one permutation to just node pointer.""" | |
out_preds = {} | |
n_perms = 0 | |
for k, pred in predictions.items(): | |
if (k.endswith('_mask') and k[:-5] in predictions and | |
predictions[k[:-5]].type_ == specs.Type.PERMUTATION_POINTER): | |
# This mask will be processed with its associated permutation datapoint | |
continue | |
if pred.type_ != specs.Type.PERMUTATION_POINTER: | |
out_preds[k] = pred | |
continue | |
pred_mask = predictions[k + '_mask'] | |
out_preds[k] = fuse_perm_and_mask(pred, pred_mask) | |
n_perms += 1 | |
assert len(out_preds) == len(predictions) - n_perms | |
return out_preds | |
def evaluate_hints( | |
hints: Tuple[probing.DataPoint, ...], | |
lengths: _Array, | |
hint_preds: List[Result], | |
) -> Dict[str, _Array]: | |
"""Evaluate hint predictions.""" | |
evals = {} | |
hints = _reduce_permutations_tuple(hints) | |
hint_preds = [_reduce_permutations_dict(h) for h in hint_preds] | |
for truth in hints: | |
assert truth.name in hint_preds[0] | |
eval_along_time = [_evaluate(truth, p[truth.name], | |
idx=i+1, lengths=lengths) | |
for (i, p) in enumerate(hint_preds)] | |
evals[truth.name] = np.sum( | |
[x * np.sum(i+1 < lengths) | |
for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1) | |
evals[truth.name + '_along_time'] = np.array(eval_along_time) | |
# Unlike outputs, the hints sometimes include scalars, which don't have | |
# a meaningful eval score. So we don't compute a global 'hint score' as we | |
# do for outputs. | |
return evals | |
def evaluate( | |
outputs: Tuple[probing.DataPoint, ...], | |
predictions: Result, | |
) -> Dict[str, float]: | |
"""Evaluate output predictions.""" | |
evals = {} | |
outputs = _reduce_permutations_tuple(outputs) | |
predictions = _reduce_permutations_dict(predictions) | |
for truth in outputs: | |
assert truth.name in predictions | |
pred = predictions[truth.name] | |
evals[truth.name] = _evaluate(truth, pred) | |
# Return a single scalar score that is the mean of all output scores. | |
evals['score'] = sum([v.item() for v in evals.values()]) / len(evals) | |
return evals | |
def _evaluate(truth, pred, idx=None, lengths=None): | |
"""Evaluate single prediction of hint or output.""" | |
assert pred.name == truth.name | |
assert pred.location == truth.location | |
assert pred.type_ == truth.type_ | |
if truth.type_ not in _EVAL_FN: | |
raise ValueError('Invalid type') | |
truth_data = truth.data | |
pred_data = pred.data | |
if idx is not None: | |
if np.all(idx >= lengths): | |
return 0. | |
truth_data = truth_data[idx][idx < lengths] | |
pred_data = pred_data[idx < lengths] | |
return _EVAL_FN[truth.type_](pred_data, truth_data) | |
def _eval_one(pred, truth): | |
mask = np.all(truth != specs.OutputClass.MASKED, axis=-1) | |
return np.sum( | |
(np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask) | |
def _mask_fn(pred, truth): | |
"""Evaluate outputs of type MASK, and account for any class imbalance.""" | |
mask = (truth != specs.OutputClass.MASKED).astype(np.float32) | |
# Use F1 score for the masked outputs to address any imbalance | |
tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask) | |
fp = np.sum((((pred > 0.5) * (truth < 0.5)) * 1.0) * mask) | |
fn = np.sum((((pred < 0.5) * (truth > 0.5)) * 1.0) * mask) | |
# Protect against division by zero | |
if tp + fp > 0: | |
precision = tp / (tp + fp) | |
else: | |
precision = np.float32(1.0) | |
if tp + fn > 0: | |
recall = tp / (tp + fn) | |
else: | |
recall = np.float32(1.0) | |
if precision + recall > 0.0: | |
f_1 = 2.0 * precision * recall / (precision + recall) | |
else: | |
f_1 = np.float32(0.0) | |
return f_1 | |
_EVAL_FN = { | |
specs.Type.SCALAR: | |
lambda pred, truth: np.mean((pred - truth)**2), | |
specs.Type.MASK: _mask_fn, | |
specs.Type.MASK_ONE: | |
_eval_one, | |
specs.Type.CATEGORICAL: | |
_eval_one, | |
specs.Type.POINTER: | |
lambda pred, truth: np.mean((pred == truth) * 1.0), | |
} | |