Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Defines a Keras model for the `TwoTowerUpliftNetwork` layer.""" | |
from __future__ import annotations | |
from typing import Any, Callable, Mapping, MutableMapping | |
import tensorflow as tf, tf_keras | |
from official.recommendation.uplift import keys | |
from official.recommendation.uplift import types | |
from official.recommendation.uplift.layers.uplift_networks import base_uplift_networks | |
from official.recommendation.uplift.layers.uplift_networks import two_tower_output_head | |
class TwoTowerUpliftModel(tf_keras.Model):
  """Training and inference model for a `BaseTwoTowerUpliftNetwork` layer.

  The forward pass is delegated to a `TwoTowerOutputHead` built from the given
  uplift network. The training and evaluation steps first check that the
  treatment indicator feature is present in the inputs. The prediction step
  returns a dictionary of control, treatment and uplift predictions.
  """

  def __init__(
      self,
      treatment_indicator_feature_name: str,
      uplift_network: base_uplift_networks.BaseTwoTowerUpliftNetwork,
      inverse_link_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
      **kwargs,
  ):
    """Initializes the instance.

    Args:
      treatment_indicator_feature_name: the name of the feature representing the
        treatment_indicator tensor, which should be castable to a boolean tensor
        (False for control and True for treatment). This tensor is required
        during training and evaluation to compute the true logits needed for
        loss computation.
      uplift_network: a layer for computing control and treatment logits. Its
        input is expected to be a dictionary of feature tensors and its output
        is expected to be a `TwoTowerNetworkOutputs` instance.
      inverse_link_fn: a function for computing the control and treatment
        predictions from their respective logits. If left as `None` it is
        functionally equivalent to the identity function.
      **kwargs: base model keyword arguments.
    """
    super().__init__(**kwargs)
    # Constructor arguments are kept for input validation in the train/test
    # steps and so that `get_config` can serialize them.
    self._treatment_indicator_feature_name = treatment_indicator_feature_name
    self._uplift_network = uplift_network
    self._inverse_link_fn = inverse_link_fn
    # All forward computation is delegated to this head (see `call`).
    self._output_head = two_tower_output_head.TwoTowerOutputHead(
        treatment_indicator_feature_name=treatment_indicator_feature_name,
        uplift_network=uplift_network,
        inverse_link_fn=inverse_link_fn,
    )

  def call(
      self,
      inputs: types.DictOfTensors,
      training: bool | None = None,
      mask: tf.Tensor | None = None,
  ) -> types.TwoTowerPredictionOutputs | types.TwoTowerTrainingOutputs:
    """Runs the forward pass through the two tower output head.

    Args:
      inputs: a dictionary of feature tensors.
      training: whether the model is called in training mode. It is forwarded
        unchanged to the output head.
      mask: an optional mask. It is forwarded unchanged to the output head.

    Returns:
      The output of the `TwoTowerOutputHead`, either a
      `TwoTowerPredictionOutputs` or a `TwoTowerTrainingOutputs` instance.
    """
    return self._output_head(inputs=inputs, training=training, mask=mask)

  def _assert_treatment_indicator_in_data(self, data):
    """Checks that the treatment indicator feature is part of the inputs.

    Args:
      data: a batch of data in any structure accepted by
        `tf_keras.utils.unpack_x_y_sample_weight`. Only its inputs (`x`) are
        inspected.

    Raises:
      ValueError: if the treatment indicator feature is not a key of the inputs.
    """
    inputs, _, _ = tf_keras.utils.unpack_x_y_sample_weight(data)
    if self._treatment_indicator_feature_name not in inputs:
      raise ValueError(
          "The treatment_indicator feature (specified as"
          f" '{self._treatment_indicator_feature_name}') must be part of the"
          " inputs during training and evaluation, but got input features"
          f" {set(inputs.keys())} instead."
      )

  def train_step(self, data) -> dict[str, Any]:
    """Runs one training step after validating the treatment indicator.

    Args:
      data: a batch of training data.

    Returns:
      The result of the base `tf_keras.Model.train_step`. Per the Keras API this
      is a dictionary of metric values, not the model outputs.

    Raises:
      ValueError: if the treatment indicator feature is missing from the inputs.
    """
    self._assert_treatment_indicator_in_data(data)
    return super().train_step(data)

  def test_step(self, data) -> dict[str, Any]:
    """Runs one evaluation step after validating the treatment indicator.

    Args:
      data: a batch of evaluation data.

    Returns:
      The result of the base `tf_keras.Model.test_step`. Per the Keras API this
      is a dictionary of metric values, not the model outputs.

    Raises:
      ValueError: if the treatment indicator feature is missing from the inputs.
    """
    self._assert_treatment_indicator_in_data(data)
    return super().test_step(data)

  def predict_step(self, data) -> dict[str, tf.Tensor]:
    """Runs one inference step and returns the predictions as a dictionary.

    Unlike `train_step`/`test_step`, this step does not check for the treatment
    indicator feature.

    Args:
      data: a batch of inference data.

    Returns:
      A dictionary mapping the `TwoTowerPredictionKeys` CONTROL, TREATMENT and
      UPLIFT keys to the `control_predictions`, `treatment_predictions` and
      `uplift` fields of the model outputs, respectively.
    """
    outputs = super().predict_step(data)
    return {
        keys.TwoTowerPredictionKeys.CONTROL: outputs.control_predictions,
        keys.TwoTowerPredictionKeys.TREATMENT: outputs.treatment_predictions,
        keys.TwoTowerPredictionKeys.UPLIFT: outputs.uplift,
    }

  def get_config(self) -> Mapping[str, Any]:
    """Returns the base model config extended with this model's arguments.

    `uplift_network` and `inverse_link_fn` are serialized with
    `tf_keras.utils.serialize_keras_object`. This is the inverse of
    `from_config`.
    """
    config = super().get_config()
    config.update({
        "treatment_indicator_feature_name": (
            self._treatment_indicator_feature_name
        ),
        "uplift_network": tf_keras.utils.serialize_keras_object(
            self._uplift_network
        ),
        "inverse_link_fn": tf_keras.utils.serialize_keras_object(
            self._inverse_link_fn
        ),
    })
    return config

  @classmethod
  def from_config(cls, config: MutableMapping[str, Any]) -> TwoTowerUpliftModel:
    """Creates a model from a config produced by `get_config`.

    Args:
      config: a config mapping, as returned by `get_config`. It is not modified.

    Returns:
      A new `TwoTowerUpliftModel` built from the deserialized arguments.
    """
    # Work on a shallow copy. The caller's config keeps its serialized entries
    # and can be reused, e.g. for another `from_config` call.
    config = dict(config)
    config["uplift_network"] = tf_keras.layers.deserialize(
        config["uplift_network"]
    )
    config["inverse_link_fn"] = tf_keras.utils.deserialize_keras_object(
        config["inverse_link_fn"]
    )
    return cls(**config)