Spaces:
Runtime error
Runtime error
File size: 8,564 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras metric for reporting metrics sliced by a feature."""
import copy
import tensorflow as tf, tf_keras
class SlicedMetric(tf_keras.metrics.Metric):
  """A metric sliced by integer, boolean, or string features.

  A metric wrapper that computes a metric for different slices of an arbitrary
  feature. The slicing is specified via a slicing spec, which is a dictionary
  from a slice name to the unique value to be sliced on. For each pair of
  `slice_name`, `slicing_value` passed, the suffix `/slice_name` will be
  appended to the name of the result of the corresponding slice.

  An overall result is also computed without any slicing applied.

  In order for this to work correctly, the given metric must support passing
  `sample_weight` to its `update_state()` method. Additionally, the slicing
  feature must also be passed to the `update_state()` method of this class and
  must be of a broadcastable shape to the metric inputs.

  This wrapper creates deep copies of the metric passed to it: one for the
  overall (unsliced) result and one for each slice. At every call to
  `update_state()`, the wrapper calls the `update_state()` method of the
  metric for each slice with the `sample_weight` set to zero where the slicing
  feature is not equal to the corresponding slicing value.

  If the given metric returns a tensor, the result of this metric will be a
  dictionary mapping from the sliced metric's name to the result for that
  slice. If the given metric returns a dictionary of tensors, the result of
  this metric will be a flattened dictionary consisting of each of the sliced
  metrics' results for every slice.

  Example usage (result values shown as plain numbers for readability):

  >>> sliced_metric = SlicedMetric(
  ...     tf_keras.metrics.Accuracy('accuracy'),
  ...     slicing_spec={"control": False, "treatment": True},
  ... )
  >>> sliced_metric.update_state(
  ...     y_true=tf.constant([[0], [1], [0], [1]]),
  ...     y_pred=tf.constant([[1], [0], [1], [1]]),
  ...     slicing_feature=tf.constant([[True], [False], [True], [False]]),
  ... )
  >>> sliced_metric.result()
  {
      "accuracy": 0.25,
      "accuracy/control": 0.5,
      "accuracy/treatment": 0
  }
  """

  def __init__(
      self,
      metric: tf_keras.metrics.Metric,
      slicing_spec: dict[str, str] | dict[str, int] | dict[str, bool],
      slicing_feature_dtype: tf.DType | None = None,
      name: str | None = None,
  ):
    """Initializes the instance.

    Args:
      metric: A `tf_keras.metrics.Metric` instance. It is deep-copied, so the
        instance passed in is never updated by this wrapper.
      slicing_spec: A dictionary that maps from string slice names to one of
        integer, boolean, or string slicing values. All values must be of the
        same kind; mixing `int` and `bool` values is not allowed.
      slicing_feature_dtype: The expected dtype of the slicing feature. The
        values in the slicing spec are cast to this type if passed. If None,
        the dtype of the slicing feature is inferred based on the values in
        the slicing spec.
      name: The name of the wrapper metric. Defaults to `sliced_{metric.name}`.

    Raises:
      ValueError: If `slicing_spec` is empty, contains duplicate slicing
        values, or has slicing values of different (or unsupported) types.
    """
    super().__init__(name=name or f"sliced_{metric.name}", dtype=metric.dtype)

    if not slicing_spec:
      raise ValueError("The slicing spec must be a non-empty dictionary.")

    slice_names, slicing_values = zip(*slicing_spec.items())

    # Classify every value as bool/int/str (None if unsupported) and require a
    # single kind. A plain `isinstance(v, type(first))` check is not enough:
    # `bool` subclasses `int`, so an int/bool mix would slip through and later
    # produce slicing tensors of different dtypes.
    value_kinds = {self._slicing_value_kind(v) for v in slicing_values}
    if None in value_kinds or len(value_kinds) > 1:
      raise ValueError(
          "All slicing values in the slicing spec must be one of `int`, "
          "`bool`, or `str`, and all values must have the same type. "
          f"Got types: {list(map(type, slicing_values))}."
      )

    if len(slicing_values) > len(set(slicing_values)):
      raise ValueError(
          "The slicing values passed to the slicing spec must be unique. Got "
          f"{slicing_values}."
      )

    # TODO(b/276811843): Look into validating whether `metric` accepts
    # `sample_weights` in its `update_state` method.

    # The instance fully owns a deep copy of the metric for the overall result.
    self._metric = copy.deepcopy(metric)
    self._slice_names = list(slice_names)
    self._slicing_values = list(slicing_values)
    self._slicing_values_tensors = [
        tf.constant(v, slicing_feature_dtype) for v in slicing_values
    ]
    # All values share one kind, so the first tensor's dtype is representative
    # (it equals `slicing_feature_dtype` whenever that was passed).
    self._slicing_feature_dtype = self._slicing_values_tensors[0].dtype
    # One independent copy of the metric per slice, in `slicing_spec` order.
    self._sliced_metrics = [copy.deepcopy(metric) for _ in self._slicing_values]

  @staticmethod
  def _slicing_value_kind(value) -> type | None:
    """Returns `bool`, `int`, or `str` for `value`, or None if unsupported."""
    # `bool` is checked before `int` because `bool` is a subclass of `int`.
    for kind in (bool, int, str):
      if isinstance(value, kind):
        return kind
    return None

  def update_state(
      self,
      *args: tf.Tensor,
      sample_weight: tf.Tensor | None = None,
      slicing_feature: tf.Tensor,
      **kwargs,
  ):
    """Updates the state of the overall metric and the metric for each slice.

    Args:
      *args: A variable number of `tf.Tensor` instances that will be passed to
        the `update_state` method of each metric.
      sample_weight: An optional `tf.Tensor` used to weight the samples. Its
        dimensions must be broadcastable to the shape(s) of *args.
      slicing_feature: A `tf.Tensor` consisting of the feature to be sliced
        on. Its dimensions must be broadcastable to the shape(s) of *args.
      **kwargs: Keyword arguments that will be passed to the `update_state`
        method of each metric.

    Raises:
      ValueError: If the dtype of `slicing_feature` differs from the dtype of
        the slicing values.
    """
    if slicing_feature.dtype != self._slicing_feature_dtype:
      raise ValueError(
          "The `slicing_feature` and slicing values in `slicing_spec` must "
          "have the same type. Got types: "
          f"{(slicing_feature.dtype, self._slicing_feature_dtype)}."
      )

    if sample_weight is not None:
      # Append trailing unit dimensions to whichever of `sample_weight` and
      # `slicing_feature` has the lower rank so their product broadcasts.
      for _ in range(len(slicing_feature.shape) - len(sample_weight.shape)):
        sample_weight = tf.expand_dims(sample_weight, axis=-1)
      for _ in range(len(sample_weight.shape) - len(slicing_feature.shape)):
        slicing_feature = tf.expand_dims(slicing_feature, axis=-1)

    # The overall metric sees every sample with the original weights.
    self._metric.update_state(*args, sample_weight=sample_weight, **kwargs)

    for slicing_val, metric in zip(
        self._slicing_values_tensors, self._sliced_metrics
    ):
      # Samples outside the slice get weight zero, so they do not contribute.
      slice_mask = tf.cast(slicing_feature == slicing_val, dtype=tf.float32)
      if sample_weight is not None:
        weight = slice_mask * tf.cast(sample_weight, dtype=tf.float32)
      else:
        weight = slice_mask
      metric.update_state(*args, sample_weight=weight, **kwargs)

  def result(self) -> dict[str, tf.Tensor]:
    """Aggregates all the metrics' results into a flattened dictionary.

    Returns:
      If the wrapped metric returns a tensor: `{metric_name: overall}` plus
      `{f"{metric_name}/{slice_name}": slice_result}` for every slice.
      If it returns a `dict[str, tf.Tensor]`: the overall dictionary plus
      `{f"{key}/{slice_name}": value}` for every key of every slice's result.

    Raises:
      ValueError: If the wrapped metric's result is neither a `tf.Tensor` nor
        a dictionary of `tf.Tensor`s.
    """
    metric_name = self._metric.name
    metric_result = self._metric.result()
    slice_results = [metric.result() for metric in self._sliced_metrics]

    if isinstance(metric_result, tf.Tensor):
      results = {metric_name: metric_result}
      for slice_name, slice_result in zip(self._slice_names, slice_results):
        results[f"{metric_name}/{slice_name}"] = slice_result
      return results

    if isinstance(metric_result, dict) and all(
        isinstance(result, tf.Tensor) for result in metric_result.values()
    ):
      results = dict(metric_result)
      for slice_name, slice_result in zip(self._slice_names, slice_results):
        # Iterate the items directly rather than `zip(*items)` so that an
        # empty result dictionary does not raise an unpacking error.
        results.update({
            f"{result_name}/{slice_name}": result_value
            for result_name, result_value in slice_result.items()
        })
      return results

    raise ValueError(
        "The output of the given metric must either be a `tf.Tensor` or "
        "a `dict[str, tf.Tensor]`, but got unsupported output: "
        f"{metric_result}."
    )

  def get_config(self):
    """Returns the arguments needed to rebuild this metric via `from_config`."""
    return {
        "name": self.name,
        "metric": tf_keras.metrics.serialize(self._metric),
        "slicing_spec": dict(zip(self._slice_names, self._slicing_values)),
        "slicing_feature_dtype": self._slicing_feature_dtype.name,
    }

  @classmethod
  def from_config(cls, config):
    """Creates a `SlicedMetric` from a dictionary produced by `get_config`.

    The given `config` is copied first, so the caller's dictionary is never
    modified.
    """
    config = dict(config)
    config["metric"] = tf_keras.metrics.deserialize(config["metric"])
    # `__init__` accepts None (infer the dtype from the slicing values), so
    # only convert when a dtype name is actually present.
    if config.get("slicing_feature_dtype") is not None:
      config["slicing_feature_dtype"] = tf.as_dtype(
          config["slicing_feature_dtype"]
      )
    return cls(**config)
|