deanna-emery's picture
updates
93528c6
raw
history blame
4.96 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a Keras model for the `TwoTowerUpliftNetwork` layer."""
from __future__ import annotations
from typing import Any, Callable, Mapping, MutableMapping
import tensorflow as tf, tf_keras
from official.recommendation.uplift import keys
from official.recommendation.uplift import types
from official.recommendation.uplift.layers.uplift_networks import base_uplift_networks
from official.recommendation.uplift.layers.uplift_networks import two_tower_output_head
@tf_keras.utils.register_keras_serializable(package="Uplift")
class TwoTowerUpliftModel(tf_keras.Model):
  """Training and inference model for a `BaseTwoTowerUpliftNetwork` layer.

  The model wraps the given uplift network in a `TwoTowerOutputHead` and
  delegates its forward pass to that head. Training and evaluation steps
  additionally verify that the treatment indicator feature is present in the
  inputs, and prediction results are repackaged into a dictionary keyed by
  `keys.TwoTowerPredictionKeys`.
  """

  def __init__(
      self,
      treatment_indicator_feature_name: str,
      uplift_network: base_uplift_networks.BaseTwoTowerUpliftNetwork,
      inverse_link_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
      **kwargs,
  ):
    """Initializes the instance.

    Args:
      treatment_indicator_feature_name: the name of the feature representing the
        treatment_indicator tensor, which should be castable to a boolean tensor
        (False for control and True for treatment). This tensor is required
        during training and evaluation to compute the true logits needed for
        loss computation.
      uplift_network: a layer for computing control and treatment logits. Its
        input is expected to be a dictionary of feature tensors and its output
        is expected to be a `TwoTowerNetworkOutputs` instance.
      inverse_link_fn: a function for computing the control and treatment
        predictions from their respective logits. If left as `None` it is
        functionally equivalent to the identity function.
      **kwargs: base model keyword arguments.
    """
    super().__init__(**kwargs)

    # The constructor arguments are kept on the instance so that `get_config`
    # can serialize them again.
    self._treatment_indicator_feature_name = treatment_indicator_feature_name
    self._uplift_network = uplift_network
    self._inverse_link_fn = inverse_link_fn

    self._output_head = two_tower_output_head.TwoTowerOutputHead(
        treatment_indicator_feature_name=treatment_indicator_feature_name,
        uplift_network=uplift_network,
        inverse_link_fn=inverse_link_fn,
    )

  def call(
      self,
      inputs: types.DictOfTensors,
      training: bool | None = None,
      mask: tf.Tensor | None = None,
  ) -> types.TwoTowerPredictionOutputs | types.TwoTowerTrainingOutputs:
    """Runs the output head on `inputs` and returns its result unchanged.

    Args:
      inputs: dictionary of feature tensors.
      training: forwarded as-is to the output head.
      mask: forwarded as-is to the output head.

    Returns:
      Whatever the wrapped `TwoTowerOutputHead` returns for these arguments.
    """
    return self._output_head(inputs=inputs, training=training, mask=mask)

  def _assert_treatment_indicator_in_data(self, data):
    """Raises if the treatment indicator feature is missing from the inputs.

    Args:
      data: a batch in any form accepted by
        `tf_keras.utils.unpack_x_y_sample_weight`; only its inputs component is
        inspected.

    Raises:
      ValueError: if the treatment indicator feature name is not a key of the
        inputs.
    """
    inputs, _, _ = tf_keras.utils.unpack_x_y_sample_weight(data)
    if self._treatment_indicator_feature_name not in inputs:
      raise ValueError(
          "The treatment_indicator feature (specified as"
          f" '{self._treatment_indicator_feature_name}') must be part of the"
          " inputs during training and evaluation, but got input features"
          f" {set(inputs.keys())} instead."
      )

  def train_step(self, data) -> types.TwoTowerTrainingOutputs:
    """Validates the batch, then delegates to the base model's `train_step`.

    Raises:
      ValueError: if the treatment indicator feature is missing from the inputs.
    """
    self._assert_treatment_indicator_in_data(data)
    return super().train_step(data)

  def test_step(self, data) -> types.TwoTowerTrainingOutputs:
    """Validates the batch, then delegates to the base model's `test_step`.

    Raises:
      ValueError: if the treatment indicator feature is missing from the inputs.
    """
    self._assert_treatment_indicator_in_data(data)
    return super().test_step(data)

  def predict_step(self, data) -> dict[str, tf.Tensor]:
    """Runs the base `predict_step` and repackages its outputs as a dictionary.

    Unlike the training and evaluation steps, no check for the treatment
    indicator feature is performed here.

    Returns:
      A dictionary with the control predictions, treatment predictions and
      uplift, keyed by the corresponding `keys.TwoTowerPredictionKeys` values.
    """
    outputs = super().predict_step(data)
    return {
        keys.TwoTowerPredictionKeys.CONTROL: outputs.control_predictions,
        keys.TwoTowerPredictionKeys.TREATMENT: outputs.treatment_predictions,
        keys.TwoTowerPredictionKeys.UPLIFT: outputs.uplift,
    }

  def get_config(self) -> Mapping[str, Any]:
    """Returns the base config extended with this model's constructor arguments.

    The uplift network and the inverse link function are stored in serialized
    form (via `tf_keras.utils.serialize_keras_object`).
    """
    config = super().get_config()
    config.update({
        "treatment_indicator_feature_name": (
            self._treatment_indicator_feature_name
        ),
        "uplift_network": tf_keras.utils.serialize_keras_object(
            self._uplift_network
        ),
        "inverse_link_fn": tf_keras.utils.serialize_keras_object(
            self._inverse_link_fn
        ),
    })
    return config

  @classmethod
  def from_config(
      cls, config: MutableMapping[str, Any]
  ) -> TwoTowerUpliftModel:
    """Builds a model from a config produced by `get_config`.

    Args:
      config: mapping holding the constructor arguments, with `uplift_network`
        and `inverse_link_fn` in serialized form. The mapping is not modified.

    Returns:
      A new model instance.
    """
    # Deserialize into a shallow copy: writing the live layer/function objects
    # back into the caller's mapping would leave it unusable for a second
    # `from_config` call or for re-serialization.
    config = dict(config)
    config["uplift_network"] = tf_keras.layers.deserialize(
        config["uplift_network"]
    )
    config["inverse_link_fn"] = tf_keras.utils.deserialize_keras_object(
        config["inverse_link_fn"]
    )
    return cls(**config)