# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for instantiating frame interpolation evaluation metrics."""
from typing import Callable, Dict, Text
from ..losses import losses
import tensorflow as tf


class TrainLossMetric(tf.keras.metrics.Metric):
  """Compute training loss for our example and prediction format.

  The purpose of this is to ensure that we always include a loss that is
  exactly like the training loss in the evaluation, in order to detect
  possible overfitting.
  """

  def __init__(self, name='eval_loss', **kwargs):
    super(TrainLossMetric, self).__init__(name=name, **kwargs)
    self.acc = self.add_weight(name='train_metric_acc', initializer='zeros')
    self.count = self.add_weight(
        name='train_metric_count', initializer='zeros')

  def update_state(self,
                   batch,
                   predictions,
                   sample_weight=None,
                   checkpoint_step=0):
    loss_functions = losses.training_losses()
    loss_list = []
    for (loss_value, loss_weight) in loss_functions.values():
      loss_list.append(
          loss_value(batch, predictions) * loss_weight(checkpoint_step))
    loss = tf.add_n(loss_list)
    self.acc.assign_add(loss)
    self.count.assign_add(1)

  def result(self):
    return self.acc / self.count

  def reset_states(self):
    self.acc.assign(0)
    self.count.assign(0)
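
# Note: each metric in this file accumulates a per-batch loss value in `acc`
# and a batch count in `count`, so `result()` reports the mean over all
# batches evaluated since the last `reset_states()` call.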


class L1Metric(tf.keras.metrics.Metric):
  """Compute L1 over our training example and prediction format.

  The purpose of this is to ensure that we have at least one metric that is
  compatible across all eval sessions and allows us to quickly compare models
  against each other.
  """

  def __init__(self, name='eval_loss', **kwargs):
    super(L1Metric, self).__init__(name=name, **kwargs)
    self.acc = self.add_weight(name='l1_metric_acc', initializer='zeros')
    self.count = self.add_weight(name='l1_metric_count', initializer='zeros')

  def update_state(self, batch, prediction, sample_weight=None,
                   checkpoint_step=0):
    self.acc.assign_add(losses.l1_loss(batch, prediction))
    self.count.assign_add(1)

  def result(self):
    return self.acc / self.count

  def reset_states(self):
    self.acc.assign(0)
    self.count.assign(0)


class GenericLossMetric(tf.keras.metrics.Metric):
  """Metric based on any loss function."""

  def __init__(self, name: str, loss: Callable[..., tf.Tensor],
               weight: Callable[..., tf.Tensor], **kwargs):
    """Initializes a metric based on a loss function and a weight schedule.

    Args:
      name: The name of the metric.
      loss: The callable loss that calculates a loss value for a (prediction,
        target) pair.
      weight: The callable weight scheduling function that samples a weight
        based on iteration.
      **kwargs: Any additional keyword arguments to be passed.
    """
    super(GenericLossMetric, self).__init__(name=name, **kwargs)
    self.acc = self.add_weight(name='loss_metric_acc', initializer='zeros')
    self.count = self.add_weight(name='loss_metric_count', initializer='zeros')
    self.loss = loss
    self.weight = weight

  def update_state(self,
                   batch,
                   predictions,
                   sample_weight=None,
                   checkpoint_step=0):
    self.acc.assign_add(
        self.loss(batch, predictions) * self.weight(checkpoint_step))
    self.count.assign_add(1)

  def result(self):
    return self.acc / self.count

  def reset_states(self):
    self.acc.assign(0)
    self.count.assign(0)
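
# A minimal sketch of using GenericLossMetric on its own, assuming a loss
# callable with the same (batch, predictions) signature as the entries of
# losses.test_losses(); the dict keys 'y' and 'image' and the names
# `abs_diff_loss` and `constant_weight` below are assumptions for
# illustration only:
#
#   def abs_diff_loss(batch, predictions):
#     return tf.reduce_mean(tf.abs(batch['y'] - predictions['image']))
#
#   constant_weight = lambda checkpoint_step: 1.0
#   metric = GenericLossMetric(
#       name='abs_diff', loss=abs_diff_loss, weight=constant_weight)
#   metric.update_state(batch, predictions, checkpoint_step=1000)
#   mean_abs_diff = metric.result()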


def create_metrics_fn() -> Dict[Text, tf.keras.metrics.Metric]:
  """Create evaluation metrics.

  L1 and total training loss are added by default. The remaining metrics are
  configured via gin through losses.test_losses().

  Returns:
    A dictionary from metric name to Keras Metric object.
  """
  metrics = {}
  # L1 is explicitly added just so we always have some consistent numbers
  # around to compare across sessions.
  metrics['l1'] = L1Metric()
  # We also always include training loss for the eval set to detect
  # overfitting:
  metrics['training_loss'] = TrainLossMetric()
  test_losses = losses.test_losses()
  for loss_name, (loss_value, loss_weight) in test_losses.items():
    metrics[loss_name] = GenericLossMetric(
        name=loss_name, loss=loss_value, weight=loss_weight)
  return metrics
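
# A minimal usage sketch, assuming an evaluation dataset of example batches
# and a model that maps a batch to predictions; `eval_dataset`, `model`, and
# `step` are hypothetical names used for illustration only:
#
#   metrics = create_metrics_fn()
#   for batch in eval_dataset:
#     predictions = model(batch, training=False)
#     for metric in metrics.values():
#       metric.update_state(batch, predictions, checkpoint_step=step)
#   results = {name: metric.result() for name, metric in metrics.items()}
#   for metric in metrics.values():
#     metric.reset_states()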