# NOTE(review): the six lines below are web-viewer scrape residue (uploader
# avatar caption, commit message, commit hash, viewer buttons, file size) —
# not part of the original source file. Commented out so the module parses.
# deanna-emery's picture
# updates
# 93528c6
# raw
# history blame
# 3.52 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper to apply any loss function on the true logits tensor."""
from __future__ import annotations
from typing import Any, Callable, Mapping, MutableMapping
import tensorflow as tf, tf_keras
from official.recommendation.uplift import types
@tf_keras.utils.register_keras_serializable(package="Uplift")
class TrueLogitsLoss(tf_keras.__internal__.losses.LossFunctionWrapper):
  """Computes any arbitrary loss between the labels and the true logits tensor.

  Note that the prediction tensor is expected to be of type
  `TwoTowerTrainingOutputs`; only its `true_logits` field is compared against
  the labels.

  Example standalone usage:

  >>> y_true = tf.ones((3, 1))
  >>> y_pred = types.TwoTowerTrainingOutputs(
  ...     control_logits=tf.constant([[0], [1], [0]]),
  ...     treatment_logits=tf.constant([[1], [0], [1]]),
  ...     true_logits=tf.ones((3, 1)),
  ...     is_treatment=tf.constant([[True], [False], [True]])
  ... )
  >>> loss = TrueLogitsLoss(
  ...     loss_fn=tf_keras.losses.mean_squared_error,
  ...     name="mean_squared_error",
  ...     reduction=tf_keras.losses.Reduction.SUM,
  ... )
  >>> loss(y_true, y_pred)
  0.0

  Example usage with the `compile()` API:

  ```python
  model.compile(
      optimizer="sgd",
      loss=TrueLogitsLoss(
          loss_fn=tf_keras.losses.categorical_crossentropy,
          name="categorical_crossentropy",
          from_logits=True,
      ),
  )
  ```
  """

  def __init__(
      self,
      loss_fn: Callable[[Any, tf.Tensor], tf.Tensor],
      name: str = "true_logits_loss",
      reduction=tf_keras.losses.Reduction.AUTO,
      **loss_fn_kwargs,
  ):
    """Initializes a `TrueLogitsLoss` instance.

    Args:
      loss_fn: The loss function to apply between the labels and true logits
        tensor, with signature `loss_fn(y_true, y_pred, **loss_fn_kwargs)`.
      name: Optional name for the instance.
      reduction: Type of `tf_keras.losses.Reduction` to apply to loss. Default
        value is `AUTO`. `AUTO` indicates that the reduction option will be
        determined by the usage context. For almost all cases this defaults to
        `SUM_OVER_BATCH_SIZE`. When used under a `tf.distribute.Strategy`,
        except via `Model.compile()` and `Model.fit()`, using `AUTO` or
        `SUM_OVER_BATCH_SIZE` will raise an error.
      **loss_fn_kwargs: The keyword arguments that are passed on to `loss_fn`.
    """
    # LossFunctionWrapper stores `fn` plus the kwargs and applies them on call.
    super().__init__(
        fn=loss_fn, name=name, reduction=reduction, **loss_fn_kwargs
    )

  def call(
      self, y_true: Any, y_pred: types.TwoTowerTrainingOutputs
  ) -> tf.Tensor:
    """Applies the wrapped `loss_fn` to the labels and `true_logits` only."""
    return super().call(y_true, y_pred.true_logits)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the serializable config, exposing `fn` under `loss_fn`."""
    config = super().get_config()
    # The wrapper serializes the callable under the key "fn"; rename it to
    # match this class's constructor argument so round-tripping works.
    config["loss_fn"] = config.pop("fn")
    return config

  @classmethod
  def from_config(cls, config: MutableMapping[str, Any]) -> TrueLogitsLoss:
    """Recreates an instance from a config produced by `get_config`."""
    # `loss_fn` is serialized as a string/dict identifier; resolve it back to
    # a callable before constructing the instance.
    config["loss_fn"] = tf_keras.losses.get(config["loss_fn"])
    return cls(**config)