Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras metric for computing sliced metric values based on treatment group.""" | |
from typing import Optional | |
import tensorflow as tf, tf_keras | |
from official.recommendation.uplift.metrics import sliced_metric | |
class TreatmentSlicedMetric(sliced_metric.SlicedMetric):
  """Computes (weighted) sliced metric values based on treatment group.

  Two copies are made of the given metric to compute the same metric for the
  control and treatment groups. The metric name is used for the output results,
  with a "/control" and "/treatment" suffix for the control and treatment
  groups. If the given metric outputs a dictionary of tensors the result will be
  a flattened dictionary with all the sliced metric results.

  Example usage:

  >>> sliced_metric = TreatmentSlicedMetric(
  ...     metric=tf_keras.metrics.Mean(name="example/mean")
  ... )
  >>> sliced_metric.update_state(
  ...     values=tf.constant([[0], [1], [5]]),
  ...     is_treatment=tf.constant([[1], [0], [1]]),
  ... )
  >>> sliced_metric.result()
  {
      "example/mean": 2.0,
      "example/mean/control": 1.0,
      "example/mean/treatment": 2.5
  }
  """

  def __init__(self, metric: tf_keras.metrics.Metric):
    """Initializes a `TreatmentSlicedMetric` instance.

    Args:
      metric: A `keras.metrics.Metric` instance with implemented get_config and
        from_config methods. Its update_state method is expected to have a
        `update_state(self, values, sample_weight=None)` signature. The control
        and treatment metrics will have the same name as the given metric, with
        "/control" and "/treatment" suffixes.
    """
    super().__init__(
        metric=metric,
        # Slice names map to the boolean value of the slicing feature that
        # selects them: False -> control, True -> treatment.
        slicing_spec={"control": False, "treatment": True},
        name=f"treatment_sliced_{metric.name}",
    )

  def update_state(
      self,
      values: tf.Tensor,
      is_treatment: tf.Tensor,
      sample_weight: Optional[tf.Tensor] = None,
      **kwargs,
  ):
    """Computes aggregate, control and treatment metric updates.

    Args:
      values: A `tf.Tensor` instance of shape (D0, ..., DN) passed to the
        `update_state` method of the treatment, control, and overall metrics.
      is_treatment: a `tf.Tensor` of shape (D0,) or (D0, 1) castable to boolean
        indicating if the example belongs to the treatment group (True) or
        control group (False).
      sample_weight: optional `tf.Tensor` for the sample weight. If given it
        will also be sliced by the is_treatment tensor.
      **kwargs: Keyword arguments that will be passed to the `update_state`
        method of each metric.
    """
    # Flatten the treatment indicator to a 1-D boolean tensor so that both
    # (D0,) and (D0, 1) inputs are handled the same way.
    slicing_feature = tf.reshape(tf.cast(is_treatment, tf.bool), [-1])

    # The sample weight is flattened the same way as the slicing feature.
    if sample_weight is not None:
      sample_weight = tf.reshape(sample_weight, [-1])

    super().update_state(
        values,
        sample_weight=sample_weight,
        slicing_feature=slicing_feature,
        **kwargs,
    )

  def get_config(self):
    """Returns a config dict holding the serialized wrapped metric."""
    return {"metric": tf_keras.metrics.serialize(self._metric)}

  @classmethod
  def from_config(cls, config):
    """Builds an instance from a config produced by `get_config`.

    Args:
      config: A dictionary whose "metric" entry is a serialized metric. The
        dictionary is not modified.

    Returns:
      A new instance of `cls` wrapping the deserialized metric.
    """
    # Work on a shallow copy so the caller's config keeps its serialized
    # "metric" entry and can be reused.
    config = dict(config)
    config["metric"] = tf_keras.metrics.deserialize(config["metric"])
    return cls(**config)