Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras metric for reporting metrics sliced by a feature.""" | |
import copy | |
import tensorflow as tf, tf_keras | |
class SlicedMetric(tf_keras.metrics.Metric):
  """A metric sliced by integer, boolean, or string features.

  A metric wrapper that computes a metric for different slices of an arbitrary
  feature. The slicing is specified via a slicing spec, which is a dictionary
  from a slice name to the unique value to be sliced on. For each pair of
  `slice_name`, `slicing_value` passed, the suffix `/slice_name` will be
  appended to the name of the result of the corresponding slice.

  An overall result is also computed without any slicing applied.

  In order for this to work correctly, the given metric must support passing
  `sample_weight` to its `update_state()` method. Additionally, the slicing
  feature must also be passed to `update_state()` method of this class and
  must be of a broadcastable shape to the metric inputs.

  This wrapper creates a deep copy of the metric passed to it for each slice.
  At every call to `update_state()`, the wrapper will call the `update_state()`
  method of the metric for each slice with the `sample_weight` set to zero
  where the slicing feature is not equal to the corresponding slicing value.

  If the given metric returns a tensor, the result of this metric will be a
  dictionary mapping from the sliced metric's name to the result for that slice.
  If the given metric returns a dictionary of tensors, the result of this metric
  will be a flattened dictionary consisting of each of the sliced metrics'
  results for every slice.

  Example usage:

  >>> sliced_metric = SlicedMetric(
  ...     tf_keras.metrics.Accuracy('accuracy'),
  ...     slicing_spec={"control": False, "treatment": True},
  ... )
  >>> sliced_metric.update_state(
  ...     y_true=tf.constant([[0], [1], [0], [1]]),
  ...     y_pred=tf.constant([[1], [0], [1], [1]]),
  ...     slicing_feature=tf.constant([[True], [False], [True], [False]]),
  ... )
  >>> sliced_metric.result()
  {
      "accuracy": 0.25,
      "accuracy/control": 0.5,
      "accuracy/treatment": 0
  }
  """

  def __init__(
      self,
      metric: tf_keras.metrics.Metric,
      slicing_spec: dict[str, str] | dict[str, int],
      slicing_feature_dtype: tf.DType | None = None,
      name: str | None = None,
  ):
    """Initializes the instance.

    Args:
      metric: A `tf_keras.metrics.Metric` instance.
      slicing_spec: A dictionary that maps from string slice names, to one of
        integer, boolean, or string slicing values.
      slicing_feature_dtype: The expected dtype of the slicing feature. The
        values in the slicing spec are casted to this type if passed. If None,
        the dtype of the slicing feature is inferred based on the values in the
        slicing spec.
      name: The name of the wrapper metric. Defaults to `sliced_{metric.name}`.

    Raises:
      A ValueError if `slicing_spec` is empty, contains duplicate slicing
      values, or has slicing values of different types.
    """
    super().__init__(name=name or f"sliced_{metric.name}", dtype=metric.dtype)

    if not slicing_spec:
      raise ValueError("The slicing spec must be a non-empty dictionary.")

    slice_names, slicing_values = zip(*slicing_spec.items())
    first_value = slicing_values[0]
    first_is_bool = isinstance(first_value, bool)

    # `bool` is a subclass of `int`, so `isinstance` alone would let a spec mix
    # booleans with integers (e.g. `{"a": 0, "b": True}`). Such a mix violates
    # the "same type" contract, so every value must also agree with the first
    # one on whether it is a boolean.
    if not isinstance(first_value, (int, bool, str)) or not all(
        isinstance(value, type(first_value))
        and isinstance(value, bool) == first_is_bool
        for value in slicing_values
    ):
      raise ValueError(
          "All slicing values in the slicing spec must be one of `int`, "
          "`bool`, or `str`, and all values must have the same type. "
          f"Got types: {list(map(type, slicing_values))}."
      )

    if len(slicing_values) > len(set(slicing_values)):
      raise ValueError(
          "The slicing values passed to the slicing spec must be unique. Got "
          f"{slicing_values}."
      )

    # TODO(b/276811843): Look into validating whether `metric` accepts
    # `sample_weights` in its `update_state` method.

    # Instance fully owns a deep copy of the metric.
    self._metric = copy.deepcopy(metric)
    self._slice_names = list(slice_names)
    self._slicing_values = list(slicing_values)
    self._slicing_values_tensors = [
        tf.constant(value, dtype=slicing_feature_dtype)
        for value in slicing_values
    ]
    # When no dtype was requested, this is the dtype inferred by `tf.constant`.
    self._slicing_feature_dtype = self._slicing_values_tensors[0].dtype
    # One independent copy of the metric per slice.
    self._sliced_metrics = [copy.deepcopy(metric) for _ in self._slicing_values]

  def update_state(
      self,
      *args: tf.Tensor,
      sample_weight: tf.Tensor | None = None,
      slicing_feature: tf.Tensor,
      **kwargs,
  ):
    """Updates the state of the metrics for each slice.

    Args:
      *args: A variable amount of `tf.Tensor` instances that will be passed to
        the `update_state` method of each metric.
      sample_weight: An optional `tf.Tensor` used to weight the sample. Its
        dimensions must be broadcastable to the shape(s) of *args.
      slicing_feature: A `tf.Tensor` consisting of the feature to be sliced on.
        Its dimensions must be broadcastable to the shape(s) of *args.
      **kwargs: Keyword arguments that will be passed to the `update_state`
        method of each metric.

    Raises:
      A ValueError if the dtype of `slicing_feature` differs from the dtype of
      the slicing values.
    """
    if slicing_feature.dtype != self._slicing_feature_dtype:
      raise ValueError(
          "The `slicing_feature` and slicing values in `slicing_spec` must "
          "have the same type. Got types: "
          f"{(slicing_feature.dtype, self._slicing_feature_dtype)}."
      )

    if sample_weight is not None:
      # Bring both tensors to the same rank by appending trailing dimensions to
      # whichever has fewer, so that the slice mask and the sample weight can
      # be multiplied together below. At most one of the two loops runs.
      for _ in range(len(slicing_feature.shape) - len(sample_weight.shape)):
        sample_weight = tf.expand_dims(sample_weight, axis=-1)
      for _ in range(len(sample_weight.shape) - len(slicing_feature.shape)):
        slicing_feature = tf.expand_dims(slicing_feature, axis=-1)

    # Overall (unsliced) metric.
    self._metric.update_state(*args, sample_weight=sample_weight, **kwargs)

    for slicing_value, sliced_metric in zip(
        self._slicing_values_tensors, self._sliced_metrics
    ):
      # 1.0 where the example belongs to the slice, 0.0 elsewhere.
      slice_mask = tf.cast(slicing_feature == slicing_value, dtype=tf.float32)
      if sample_weight is None:
        slice_weight = slice_mask
      else:
        slice_weight = slice_mask * tf.cast(sample_weight, dtype=tf.float32)
      sliced_metric.update_state(*args, sample_weight=slice_weight, **kwargs)

  def result(self) -> dict[str, tf.Tensor]:
    """Aggregates all the metrics' results into a flattened dictionary.

    Returns:
      A dictionary holding the overall result(s) under the wrapped metric's
      own result name(s) and each slice's result(s) under
      `{result_name}/{slice_name}`.

    Raises:
      A ValueError if the wrapped metric's result is neither a `tf.Tensor` nor
      a dictionary whose values are all `tf.Tensor` instances.
    """
    metric_name = self._metric.name
    metric_result = self._metric.result()
    slice_results = [metric.result() for metric in self._sliced_metrics]

    if isinstance(metric_result, tf.Tensor):
      results = {metric_name: metric_result}
      results.update(
          (f"{metric_name}/{slice_name}", slice_result)
          for slice_name, slice_result in zip(self._slice_names, slice_results)
      )
      return results

    if isinstance(metric_result, dict) and all(
        isinstance(value, tf.Tensor) for value in metric_result.values()
    ):
      results = {**metric_result}
      for slice_name, slice_result in zip(self._slice_names, slice_results):
        # Iterating the items directly also copes with an empty result
        # dictionary, which `zip(*items)` unpacking would reject.
        results.update(
            (f"{result_name}/{slice_name}", value)
            for result_name, value in slice_result.items()
        )
      return results

    raise ValueError(
        "The output of the given metric must either be a `tf.Tensor` or "
        "a `dict[str, tf.Tensor]`, but got unsupported output: "
        f"{metric_result}."
    )

  def get_config(self):
    """Returns the configuration needed to re-create this wrapper."""
    return {
        "name": self.name,
        "metric": tf_keras.metrics.serialize(self._metric),
        "slicing_spec": dict(zip(self._slice_names, self._slicing_values)),
        "slicing_feature_dtype": self._slicing_feature_dtype.name,
    }

  @classmethod
  def from_config(cls, config):
    """Creates an instance from a config produced by `get_config()`.

    Args:
      config: A dictionary with the keys emitted by `get_config()`.

    Returns:
      A new instance built from the deserialized config.
    """
    # Work on a shallow copy so the caller's dictionary is left untouched.
    config = dict(config)
    config["metric"] = tf_keras.metrics.deserialize(config["metric"])
    config["slicing_feature_dtype"] = tf.as_dtype(
        config["slicing_feature_dtype"]
    )
    return cls(**config)