Spaces:
Build error
Build error
File size: 9,093 Bytes
aee59a3 4d9d4d0 8bfa969 329abf4 aee59a3 6d18e2a aee59a3 6d18e2a aee59a3 ec13ff1 aee59a3 6d18e2a aee59a3 6d18e2a aee59a3 6d18e2a f3b0d73 6d18e2a aee59a3 db9ef5e a209af8 6d18e2a f3b0d73 6d18e2a aee59a3 db9ef5e f3b0d73 c575cfd db9ef5e c575cfd 6d18e2a 5d7bacb c575cfd db9ef5e 5d7bacb db9ef5e aee59a3 6d18e2a aee59a3 6d18e2a 490df31 6d18e2a 490df31 aee59a3 6d18e2a db9ef5e bc89e29 db9ef5e 6d18e2a db9ef5e 6d18e2a db9ef5e 6d18e2a a3751fe db9ef5e dbcf938 db9ef5e 6d18e2a db9ef5e 6d18e2a db9ef5e 1033953 5d7bacb db9ef5e 5d7bacb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import evaluate
import datasets
import numpy as np
from seametrics.horizon.utils import *
_CITATION = """\
@InProceedings{huggingface:module,
title = {Horizon Metrics},
authors={huggingface, Inc.},
year={2024}
}
"""
# TODO: Add description of the module here
_DESCRIPTION = """\
This metric is intended to calculate horizon prediction metrics."""
# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of predictions for each image. Each prediction
should be a nested array like this:
- [[x1, y1], [x2, y2]]
references: list of references for each image. Each reference
should be a nested array like this:
- [[x1, y1], [x2, y2]]
Returns:
dict containing following metrics:
'average_slope_error': Measures the average difference in slope between the predicted and ground truth horizon.
'average_midpoint_error': Calculates the average difference in midpoint position between the predicted and ground truth horizon.
'stddev_slope_error': Indicates the variability of errors in slope between the predicted and ground truth horizon.
'stddev_midpoint_error': Quantifies the variability of errors in midpoint position between the predicted and ground truth horizon.
'max_slope_error': Represents the maximum difference in slope between the predicted and ground truth horizon.
'max_midpoint_error': Indicates the maximum difference in midpoint position between the predicted and ground truth horizon.
'num_slope_error_jumps': Calculates the differences between errors in successive frames for the slope. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.
'num_midpoint_error_jumps': Calculates the differences between errors in successive frames for the midpoint. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.
Examples:
>>> ground_truth_points = [[[0.0, 0.5384765625], [1.0, 0.4931640625]],
[[0.0, 0.53796875], [1.0, 0.4928515625]],
[[0.0, 0.5374609375], [1.0, 0.4925390625]],
[[0.0, 0.536953125], [1.0, 0.4922265625]],
[[0.0, 0.5364453125], [1.0, 0.4919140625]]]
>>> prediction_points = [[[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
[[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]],
[[0.0, 0.5200016849393765], [1.0, 0.4728554579177664]],
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
>>> module = evaluate.load("SEA-AI/horizon-metrics", vertical_fov_degrees=25.6, height=512, roll_threshold=0.5, pitch_threshold=0.1)
>>> module.add(predictions=ground_truth_points, references=prediction_points)
>>> module.compute()
>>> {'average_slope_error': 0.014823194839790999,
'average_midpoint_error': 0.014285714285714301,
'stddev_slope_error': 0.01519178791378349,
'stddev_midpoint_error': 0.0022661781575342445,
'max_slope_error': 0.033526146567062376,
'max_midpoint_error': 0.018161272321428612,
'num_slope_error_jumps': 1,
'num_midpoint_error_jumps': 1}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION,
                                                _KWARGS_DESCRIPTION)
class HorizonMetrics(evaluate.Metric):
    """
    HorizonMetrics is a metric class that calculates horizon prediction metrics.

    Each ``add`` call replaces the stored sequence of predictions/references;
    ``compute`` evaluates the most recently added sequence.

    Args:
        vertical_fov_degrees (float): Vertical field of view in degrees.
        height (int): Height of the image.
        roll_threshold (float, optional): Threshold for roll angle. Defaults to 0.5.
        pitch_threshold (float, optional): Threshold for pitch angle. Defaults to 0.1.
        **kwargs: Additional keyword arguments forwarded to ``evaluate.Metric``.

    Attributes:
        slope_threshold (float): Threshold for slope calculated from roll threshold.
        midpoint_threshold (float): Threshold for midpoint calculated from pitch threshold.
        predictions (list): Predicted horizons from the most recent ``add`` call.
        ground_truth_det (list): Ground truth horizons from the most recent ``add`` call.
        slope_error_list (list): Per-frame slope errors from the most recent ``compute``.
        midpoint_error_list (list): Per-frame midpoint errors from the most recent ``compute``.
        height (int): Height of the image.
        vertical_fov_degrees (float): Vertical field of view in degrees.

    Methods:
        _info(): Returns the metric information.
        add(predictions, references, **kwargs): Updates the predictions and ground truth detections.
        _compute(predictions, references, **kwargs): Computes the horizon error across the sequence.
    """

    def __init__(self,
                 vertical_fov_degrees,
                 height,
                 roll_threshold=0.5,
                 pitch_threshold=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        # Convert the angular thresholds into the slope/midpoint units used
        # by the per-frame error computation.
        self.slope_threshold = roll_to_slope(roll_threshold)
        self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
                                                    vertical_fov_degrees)
        self.predictions = None
        self.ground_truth_det = None
        self.slope_error_list = []
        self.midpoint_error_list = []
        self.height = height
        self.vertical_fov_degrees = vertical_fov_degrees

    def _info(self):
        """
        Returns the metric information.

        Returns:
            MetricInfo: The metric information.
        """
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each example is a whole sequence: a list of frames, each frame a
            # list of [x, y] points (hence three nested Sequences).
            features=datasets.Features({
                'predictions':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
                'references':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
            }),
            # NOTE(review): placeholder URL from the module template —
            # replace with the real repository.
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"])

    def add(self, *, predictions, references, **kwargs):
        """
        Updates the predictions and ground truth detections.

        The whole sequence is registered as a single example and also stored
        on ``self``, replacing any previously added sequence.

        Parameters:
            predictions (list): List of predicted horizons.
            references (list): List of ground truth horizons.
            **kwargs: Additional keyword arguments.
        """
        # super(evaluate.Metric, self) starts the method lookup after
        # evaluate.Metric in the MRO, i.e. at its base class's ``add``.
        super(evaluate.Metric, self).add(prediction=predictions,
                                         references=references,
                                         **kwargs)
        self.predictions = predictions
        self.ground_truth_det = references

    def _compute(self, *, predictions, references, **kwargs):
        """
        Computes the horizon error across the sequence stored by ``add``.

        The ``predictions``/``references`` arguments supplied by the base
        class are not used; the sequence saved on ``self`` by the most recent
        ``add`` call is evaluated instead.

        Returns:
            dict: The metrics returned by
            ``calculate_horizon_error_across_sequence`` plus a
            ``'detection_rate'`` entry (number of non-None predictions divided
            by the number of non-None references, or NaN when there are no
            non-None references).

        Raises:
            ValueError: If ``add`` has not been called before computing.
        """
        if self.predictions is None or self.ground_truth_det is None:
            raise ValueError(
                "No horizons to evaluate: call `add(predictions=..., "
                "references=...)` before `compute()`.")

        # Start from empty lists so repeated compute() calls do not
        # accumulate errors from previously evaluated sequences.
        self.slope_error_list = []
        self.midpoint_error_list = []

        # Calculate per-frame errors; frames where either the annotation or
        # the prediction is missing are skipped.
        # NOTE(review): zip() silently truncates if the two sequences have
        # different lengths — confirm callers always pass aligned sequences.
        for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
                                                       self.predictions):
            if annotated_horizon is None or proposed_horizon is None:
                continue
            slope_error, midpoint_error = calculate_horizon_error(
                annotated_horizon, proposed_horizon)
            self.slope_error_list.append(slope_error)
            self.midpoint_error_list.append(midpoint_error)

        # Aggregate slope errors, midpoint errors and jumps.
        result = calculate_horizon_error_across_sequence(
            self.slope_error_list, self.midpoint_error_list,
            self.slope_threshold, self.midpoint_threshold,
            self.vertical_fov_degrees, self.height)

        # Calculate detection rate. Counting with `is not None` (rather than
        # list.count(None), which compares with ==) also works for frames
        # given as numpy arrays and for non-list sequences.
        # NOTE(review): predictions on frames without an annotation are also
        # counted, so this ratio can exceed 1 — confirm this is intended.
        detected_horizon_count = sum(
            horizon is not None for horizon in self.predictions)
        detected_gt_count = sum(
            horizon is not None for horizon in self.ground_truth_det)
        # The rate is undefined without any reference horizon; report NaN
        # instead of raising ZeroDivisionError.
        result['detection_rate'] = (detected_horizon_count / detected_gt_count
                                    if detected_gt_count else float('nan'))
        return result
|