File size: 8,564 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras metric for reporting metrics sliced by a feature."""

import copy

import tensorflow as tf, tf_keras


class SlicedMetric(tf_keras.metrics.Metric):
  """A metric sliced by integer, boolean, or string features.

  A metric wrapper that computes a metric for different slices of an arbitrary
  feature. The slicing is specified via a slicing spec, which is a dictionary
  from a slice name to the unique value to be sliced on. For each pair of
  `slice_name`, `slicing_value` passed, the suffix `/slice_name` will be
  appended to the name of the result of the corresponding slice.
  An overall result is also computed without any slicing applied.

  In order for this to work correctly, the given metric must support passing
  `sample_weight` to its `update_state()` method. Additionally, the slicing
  feature must also be passed to `update_state()` method of this class and
  must be of a broadcastable shape to the metric inputs.
  This wrapper creates a deep copy of the metric passed to it for each slice.
  At every call to `update_state()`, the wrapper will call the `update_state()`
  method of the metric for each slice with the `sample_weight` set to zero
  where the slicing feature is not equal to the corresponding slicing value.

  If the given metric returns a tensor, the result of this metric will be a
  dictionary mapping from the sliced metric's name to the result for that slice.
  If the given metric returns a dictionary of tensors, the result of this metric
  will be a flattened dictionary consisting of each of the sliced metrics'
  results for every slice.

  Example usage:

  >>> sliced_metric = SlicedMetric(
  ...     tf_keras.metrics.Accuracy('accuracy'),
  ...     slicing_spec={"control": False, "treatment": True},
  ... )
  >>> sliced_metric.update_state(
  ...     y_true=tf.constant([[0], [1], [0], [1]]),
  ...     y_pred=tf.constant([[1], [0], [1], [1]]),
  ...     slicing_feature=tf.constant([[True], [False], [True], [False]]),
  ... )
  >>> sliced_metric.result()
  {
      "accuracy": 0.25,
      "accuracy/control": 0.5,
      "accuracy/treatment": 0
  }
  """

  def __init__(
      self,
      metric: tf_keras.metrics.Metric,
      slicing_spec: dict[str, str] | dict[str, int] | dict[str, bool],
      slicing_feature_dtype: tf.DType | None = None,
      name: str | None = None,
  ):
    """Initializes the instance.

    Args:
      metric: A `tf_keras.metrics.Metric` instance.
      slicing_spec: A dictionary that maps from string slice names, to one of
        integer, boolean, or string slicing values.
      slicing_feature_dtype: The expected dtype of the slicing feature. The
        values in the slicing spec are casted to this type if passed. If None,
        the dtype of the slicing feature is inferred based on the values in the
        slicing spec.
      name: The name of the wrapper metric. Defaults to `sliced_{metric.name}`.

    Raises:
      ValueError: if `slicing_spec` is empty, contains duplicate slicing
        values, or has slicing values of different (or unsupported) types.
    """
    super().__init__(name=name or f"sliced_{metric.name}", dtype=metric.dtype)

    if not slicing_spec:
      raise ValueError("The slicing spec must be a non-empty dictionary.")

    slice_names, slicing_values = zip(*slicing_spec.items())
    # `bool` is a subclass of `int`, so an `isinstance` check against the type
    # of the first value would accept mixes such as {"a": 2, "b": True} and
    # produce constants of different dtypes. Map every value to exactly one
    # supported kind and require that all values share it.
    value_kinds = {self._slicing_value_kind(v) for v in slicing_values}
    if None in value_kinds or len(value_kinds) != 1:
      raise ValueError(
          "All slicing values in the slicing spec must be one of `int`, "
          "`bool`, or `str`, and all values must have the same type. "
          f"Got types: {list(map(type, slicing_values))}."
      )

    if len(slicing_values) > len(set(slicing_values)):
      raise ValueError(
          "The slicing values passed to the slicing spec must be unique. Got "
          f"{slicing_values}."
      )

    # TODO(b/276811843): Look into validating whether `metric` accepts
    # `sample_weight` in its `update_state` method.

    # Instance fully owns a deep copy of the metric.
    self._metric = copy.deepcopy(metric)
    self._slice_names = list(slice_names)
    self._slicing_values = list(slicing_values)
    self._slicing_values_tensors = [
        tf.constant(v, slicing_feature_dtype) for v in slicing_values
    ]
    # All values share one kind (validated above), so the first constant's
    # dtype is the dtype of every slicing value.
    self._slicing_feature_dtype = self._slicing_values_tensors[0].dtype
    self._sliced_metrics = [copy.deepcopy(metric) for _ in self._slicing_values]

  @staticmethod
  def _slicing_value_kind(value) -> type | None:
    """Returns the supported kind (`bool`, `int`, `str`) of `value`, or None."""
    # `bool` is checked before `int` because `bool` is a subclass of `int`.
    for kind in (bool, int, str):
      if isinstance(value, kind):
        return kind
    return None

  def update_state(
      self,
      *args: tf.Tensor,
      sample_weight: tf.Tensor | None = None,
      slicing_feature: tf.Tensor,
      **kwargs,
  ):
    """Updates the state of the metrics for each slice.

    Args:
      *args: A variable amount of `tf.Tensor` instances that will be passed to
        the `update_state` method of each metric.
      sample_weight: An optional `tf.Tensor` used to weight the sample. Its
        dimensions must be broadcastable to the shape(s) of *args.
      slicing_feature: A `tf.Tensor` consisting of the feature to be sliced on.
        Its dimensions must be broadcastable to the shape(s) of *args.
      **kwargs: Keyword arguments that will be passed to the `update_state`
        method of each metric.

    Raises:
      ValueError: if the dtype of `slicing_feature` differs from the dtype of
        the slicing values.
    """
    # Accept array-likes (e.g. numpy arrays); tensor-likes such as ragged or
    # sparse tensors are passed through untouched.
    if not tf.is_tensor(slicing_feature):
      slicing_feature = tf.convert_to_tensor(slicing_feature)

    if slicing_feature.dtype != self._slicing_feature_dtype:
      raise ValueError(
          "The `slicing_feature` and slicing values in `slicing_spec` must "
          "have the same type. Got types: "
          f"{(slicing_feature.dtype, self._slicing_feature_dtype)}."
      )

    if sample_weight is not None:
      if not tf.is_tensor(sample_weight):
        sample_weight = tf.convert_to_tensor(sample_weight)

      # Append trailing unit dimensions to whichever of `sample_weight` and
      # `slicing_feature` has the lower rank so their product broadcasts.
      for _ in range(len(slicing_feature.shape) - len(sample_weight.shape)):
        sample_weight = tf.expand_dims(sample_weight, axis=-1)

      for _ in range(len(sample_weight.shape) - len(slicing_feature.shape)):
        slicing_feature = tf.expand_dims(slicing_feature, axis=-1)

    self._metric.update_state(*args, sample_weight=sample_weight, **kwargs)
    for slicing_val, metric in zip(
        self._slicing_values_tensors, self._sliced_metrics
    ):
      # 1.0 where the example belongs to this slice, 0.0 elsewhere.
      slice_mask = tf.cast(slicing_feature == slicing_val, dtype=tf.float32)
      if sample_weight is not None:
        weight = slice_mask * tf.cast(sample_weight, dtype=tf.float32)
      else:
        weight = slice_mask
      metric.update_state(*args, sample_weight=weight, **kwargs)

  def reset_state(self):
    """Resets the overall metric and every per-slice metric.

    Delegates to each wrapped metric's own `reset_state` so metrics with custom
    reset logic are reset correctly.
    """
    self._metric.reset_state()
    for metric in self._sliced_metrics:
      metric.reset_state()

  def result(self) -> dict[str, tf.Tensor]:
    """Aggregates all the metrics' results into a flattened dictionary.

    Returns:
      A dict with the overall result(s) under the metric's own key(s) and each
      slice's result(s) under `{key}/{slice_name}`.

    Raises:
      ValueError: if the wrapped metric returns neither a `tf.Tensor` nor a
        `dict[str, tf.Tensor]`.
    """
    metric_name = self._metric.name
    metric_result = self._metric.result()
    slice_results = [metric.result() for metric in self._sliced_metrics]

    if isinstance(metric_result, tf.Tensor):
      results = {metric_name: metric_result}
      for slice_name, slice_result in zip(self._slice_names, slice_results):
        results[f"{metric_name}/{slice_name}"] = slice_result
      return results

    if isinstance(metric_result, dict) and all(
        isinstance(result, tf.Tensor) for result in metric_result.values()
    ):
      results = dict(metric_result)
      for slice_name, slice_result in zip(self._slice_names, slice_results):
        # A comprehension (rather than `zip(*items())`) also handles metrics
        # that return an empty dict.
        results.update({
            f"{key}/{slice_name}": value for key, value in slice_result.items()
        })
      return results

    raise ValueError(
        "The output of the given metric must either be a `tf.Tensor` or "
        "a `dict[str, tf.Tensor]`, but got unsupported output: "
        f"{metric_result}."
    )

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this metric."""
    return {
        "name": self.name,
        "metric": tf_keras.metrics.serialize(self._metric),
        "slicing_spec": dict(zip(self._slice_names, self._slicing_values)),
        "slicing_feature_dtype": self._slicing_feature_dtype.name,
    }

  @classmethod
  def from_config(cls, config):
    """Rebuilds the metric from the output of `get_config`."""
    # Work on a copy so the caller's config dict is not mutated.
    config = dict(config)
    config["metric"] = tf_keras.metrics.deserialize(config["metric"])
    config["slicing_feature_dtype"] = tf.as_dtype(
        config["slicing_feature_dtype"]
    )
    return cls(**config)