Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Wrapper to apply any loss function on the true logits tensor.""" | |
from __future__ import annotations | |
from typing import Any, Callable, Mapping, MutableMapping | |
import tensorflow as tf, tf_keras | |
from official.recommendation.uplift import types | |
class TrueLogitsLoss(tf_keras.__internal__.losses.LossFunctionWrapper): | |
"""Computes any arbitrary loss between the labels and the true logits tensor. | |
Note that the prediction tensor is expected to be of a tensor of type | |
`TwoTowerTrainingOutputs`. | |
Example standalone usage: | |
>>> y_true = tf.ones((3, 1)) | |
>>> y_pred = types.TwoTowerTrainingOutputs( | |
... control_logits=tf.constant([[0], [1], [0]]), | |
... treatment_logits=tf.constant([[1], [0], [1]]), | |
... true_logits=tf.ones((3, 1)), | |
... is_treatment=tf.constant([[True], [False], [True]]) | |
... ) | |
>>> loss = TrueLogitsLoss( | |
... loss_fn=tf_keras.losses.mean_squared_error, | |
... name="mean_squared_error", | |
... reduction=tf_keras.losses.Reduction.SUM, | |
... ) | |
>>> loss(y_true, y_pred) | |
0.0 | |
Example usage with the `compile()` API: | |
```python | |
model.compile( | |
optimizer='sgd'. | |
loss=TrueLogitsLoss( | |
loss_fn=tf_keras.losses.categorical_crossentropy, | |
name="categorical_crossentropy", | |
from_logits=True | |
) | |
) | |
``` | |
""" | |
def __init__( | |
self, | |
loss_fn: Callable[[Any, tf.Tensor], tf.Tensor], | |
name: str = "true_logits_loss", | |
reduction=tf_keras.losses.Reduction.AUTO, | |
**loss_fn_kwargs, | |
): | |
"""Initialize `TrueLogitsLoss` instance. | |
Args: | |
loss_fn: The loss function to apply between the labels and true logits | |
tensor, with signature `loss_fn(y_true, y_pred, **loss_fn_kwargs)`. | |
name: Optional name for the instance. | |
reduction: Type of `tf_keras.losses.Reduction` to apply to loss. Default | |
value is `AUTO`. `AUTO` indicates that the reduction option will be | |
determined by the usage context. For almost all cases this defaults to | |
`SUM_OVER_BATCH_SIZE`. When used under a `tf.distribute.Strategy`, | |
except via `Model.compile()` and `Model.fit()`, using `AUTO` or | |
`SUM_OVER_BATCH_SIZE` will raise an error. | |
**loss_fn_kwargs: The keyword arguments that are passed on to `loss_fn`. | |
""" | |
super().__init__( | |
fn=loss_fn, name=name, reduction=reduction, **loss_fn_kwargs | |
) | |
def call( | |
self, y_true: Any, y_pred: types.TwoTowerTrainingOutputs | |
) -> tf.Tensor: | |
return super().call(y_true, y_pred.true_logits) | |
def get_config(self) -> Mapping[str, Any]: | |
config = super().get_config() | |
config["loss_fn"] = config.pop("fn") | |
return config | |
def from_config(cls, config: MutableMapping[str, Any]) -> TrueLogitsLoss: | |
config["loss_fn"] = tf_keras.losses.get(config["loss_fn"]) | |
return cls(**config) | |