# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for calculating losses."""
from typing import Dict, List, Tuple
import chex
from clrs._src import probing
from clrs._src import specs
import haiku as hk
import jax
import jax.numpy as jnp
_Array = chex.Array
_DataPoint = probing.DataPoint
_Location = specs.Location
_OutputClass = specs.OutputClass
_PredTrajectory = Dict[str, _Array]
_PredTrajectories = List[_PredTrajectory]
_Type = specs.Type
EPS = 1e-12


def _expand_to(x: _Array, y: _Array) -> _Array:
  while len(y.shape) > len(x.shape):
    x = jnp.expand_dims(x, -1)
  return x


def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array:
  return jnp.broadcast_to(_expand_to(x, y), y.shape)
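

# Shape sanity check for the two helpers above (an illustrative sketch; the
# shapes are hypothetical and this function is not part of the CLRS API).
# With a per-step flag of shape (T,) and a loss tensor of shape (T, B, N),
# `_expand_to` right-pads singleton dimensions so the flag can broadcast
# against the loss elementwise:
def _example_expand_to():
  x = jnp.ones((4,))       # e.g. per-step `is_last` flags, shape (T,)
  y = jnp.ones((4, 3, 5))  # e.g. a loss tensor of shape (T, B, N)
  chex.assert_shape(_expand_to(x, y), (4, 1, 1))
  chex.assert_shape(_expand_and_broadcast_to(x, y), (4, 3, 5))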


def output_loss_chunked(truth: _DataPoint, pred: _Array,
                        is_last: _Array, nb_nodes: int) -> float:
  """Output loss for time-chunked training."""

  mask = None

  if truth.type_ == _Type.SCALAR:
    loss = (pred - truth.data)**2

  elif truth.type_ == _Type.MASK:
    loss = (
        jnp.maximum(pred, 0) - pred * truth.data +
        jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth.data != _OutputClass.MASKED)

  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1)
    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1)

  elif truth.type_ == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1)

  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between the doubly stochastic pred
    # and truth_data.
    loss = -jnp.sum(truth.data * pred, axis=-1)

  if mask is not None:
    mask = mask * _expand_and_broadcast_to(is_last, loss)
  else:
    mask = _expand_and_broadcast_to(is_last, loss)
  total_mask = jnp.maximum(jnp.sum(mask), EPS)
  return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask
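

# Note on the MASK branch above: `max(x, 0) - x*z + log1p(exp(-|x|))` is the
# numerically stable form of sigmoid cross entropy with logits `x` and
# targets `z`, algebraically equal to
# `-z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x))` but safe for large |x|.
# A sanity check (illustrative only; not part of the CLRS API):
def _example_stable_sigmoid_xent():
  x = jnp.array([-30.0, 0.5, 4.0])  # logits, including a large-magnitude one
  z = jnp.array([0.0, 1.0, 1.0])    # binary targets
  stable = jnp.maximum(x, 0) - x * z + jnp.log1p(jnp.exp(-jnp.abs(x)))
  naive = -z * jnp.log(jax.nn.sigmoid(x)) - (1 - z) * jnp.log1p(
      -jax.nn.sigmoid(x))
  chex.assert_trees_all_close(stable, naive, atol=1e-6)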


def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
  """Output loss for full-sample training."""

  if truth.type_ == _Type.SCALAR:
    total_loss = jnp.mean((pred - truth.data)**2)

  elif truth.type_ == _Type.MASK:
    loss = (
        jnp.maximum(pred, 0) - pred * truth.data +
        jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
    total_loss = jnp.sum(loss * mask) / jnp.sum(mask)

  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) /
                  jnp.sum(truth.data == _OutputClass.POSITIVE))

  elif truth.type_ == _Type.POINTER:
    total_loss = (
        jnp.mean(-jnp.sum(
            hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred),
            axis=-1)))

  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between the doubly stochastic pred
    # and truth_data.
    total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))

  return total_loss
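

# Usage sketch for `output_loss` (hypothetical shapes and values; not part of
# the CLRS API). A CATEGORICAL truth stores one-hot classes over the last
# axis; with uniform logits over C=3 classes the loss is exactly log(3):
def _example_output_loss():
  one_hot = jax.nn.one_hot(jnp.array([[0, 2], [1, 1]]), 3)  # (B=2, N=2, C=3)
  truth = probing.DataPoint(name='y', location=_Location.NODE,
                            type_=_Type.CATEGORICAL, data=one_hot)
  logits = jnp.zeros((2, 2, 3))  # uniform predictions over the 3 classes
  return output_loss(truth, logits, nb_nodes=2)  # == jnp.log(3.)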


def hint_loss_chunked(
    truth: _DataPoint,
    pred: _Array,
    is_first: _Array,
    nb_nodes: int,
):
  """Hint loss for time-chunked training."""
  loss, mask = _hint_loss(
      truth_data=truth.data,
      truth_type=truth.type_,
      pred=pred,
      nb_nodes=nb_nodes,
  )

  mask *= (1 - _expand_to(is_first, loss)).astype(jnp.float32)
  loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
  return loss
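

# Usage sketch (hypothetical shapes; not part of the CLRS API). `is_first`
# marks chunk positions holding the first step of a sample; those hints are
# fed as inputs rather than predicted, so `hint_loss_chunked` excludes them
# via the `(1 - is_first)` factor:
def _example_hint_loss_chunked():
  truth = probing.DataPoint(name='h', location=_Location.NODE,
                            type_=_Type.SCALAR,
                            data=jnp.zeros((4, 2, 5)))  # (T, B, N)
  pred = jnp.ones((4, 2, 5))
  is_first = jnp.zeros((4, 2)).at[0].set(1.0)  # step 0 starts a new sample
  return hint_loss_chunked(truth, pred, is_first, nb_nodes=5)  # == 1.0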


def hint_loss(
    truth: _DataPoint,
    preds: List[_Array],
    lengths: _Array,
    nb_nodes: int,
    verbose: bool = False,
):
  """Hint loss for full-sample training."""
  total_loss = 0.
  verbose_loss = {}
  length = truth.data.shape[0] - 1

  loss, mask = _hint_loss(
      truth_data=truth.data[1:],
      truth_type=truth.type_,
      pred=jnp.stack(preds),
      nb_nodes=nb_nodes,
  )
  mask *= _is_not_done_broadcast(lengths, jnp.arange(length)[:, None], loss)
  loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
  if verbose:
    verbose_loss['loss_' + truth.name] = loss
  else:
    total_loss += loss

  return verbose_loss if verbose else total_loss
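

# Usage sketch (hypothetical shapes; not part of the CLRS API). `truth.data`
# holds T+1 hint steps (step 0 is given as input), `preds` holds the T
# per-step predictions, and `lengths` masks out steps past each sample's end:
def _example_hint_loss():
  truth = probing.DataPoint(name='h', location=_Location.NODE,
                            type_=_Type.SCALAR,
                            data=jnp.zeros((4, 2, 5)))  # (T+1, B, N), T=3
  preds = [jnp.ones((2, 5)) for _ in range(3)]          # one array per step
  lengths = jnp.array([4, 2])                           # per-sample lengths
  return hint_loss(truth, preds, lengths, nb_nodes=5)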


def _hint_loss(
    truth_data: _Array,
    truth_type: str,
    pred: _Array,
    nb_nodes: int,
) -> Tuple[_Array, _Array]:
  """Hint loss helper."""
  mask = None
  if truth_type == _Type.SCALAR:
    loss = (pred - truth_data)**2

  elif truth_type == _Type.MASK:
    loss = (jnp.maximum(pred, 0) - pred * truth_data +
            jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32)  # pytype: disable=attribute-error  # numpy-scalars

  elif truth_type == _Type.MASK_ONE:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1,
                    keepdims=True)

  elif truth_type == _Type.CATEGORICAL:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1)
    mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype(
        jnp.float32)

  elif truth_type == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred),
        axis=-1)

  elif truth_type == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between the doubly stochastic pred
    # and truth_data.
    loss = -jnp.sum(truth_data * pred, axis=-1)

  if mask is None:
    mask = jnp.ones_like(loss)
  return loss, mask
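

# Note on the PERMUTATION_POINTER branches above: `pred` is treated as
# log-probabilities (e.g. a log-space Sinkhorn output approximating a doubly
# stochastic matrix), so `-sum(truth * pred)` is a plain cross entropy with
# the one-hot rows of the ground-truth permutation matrix. An illustrative
# check (hypothetical values; not part of the CLRS API):
def _example_permutation_pointer_loss():
  perm = jnp.eye(3)                             # identity permutation, (N, N)
  log_pred = jnp.log(jnp.full((3, 3), 1. / 3))  # uniform doubly stochastic
  loss, _ = _hint_loss(truth_data=perm, truth_type=_Type.PERMUTATION_POINTER,
                       pred=log_pred, nb_nodes=3)
  chex.assert_trees_all_close(loss, jnp.full((3,), jnp.log(3.)))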


def _is_not_done_broadcast(lengths, i, tensor):
  is_not_done = (lengths > i + 1) * 1.0
  while len(is_not_done.shape) < len(tensor.shape):  # pytype: disable=attribute-error  # numpy-scalars
    is_not_done = jnp.expand_dims(is_not_done, -1)
  return is_not_done
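

# Illustrative example (hypothetical values; not part of the CLRS API): with
# per-sample lengths [3, 2] and step indices 0..2 in a (3, 1) column, the
# result flags which (step, sample) pairs still precede each sample's final
# step, expanded so it broadcasts against the loss tensor:
def _example_is_not_done_broadcast():
  lengths = jnp.array([3, 2])
  i = jnp.arange(3)[:, None]     # steps 0..2 as a column, shape (3, 1)
  tensor = jnp.zeros((3, 2, 5))  # a (T, B, N) loss tensor
  out = _is_not_done_broadcast(lengths, i, tensor)
  chex.assert_shape(out, (3, 2, 1))
  return out  # [[1., 1.], [1., 0.], [0., 0.]], with a trailing unit axis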