# horizon-metrics / horizon-metrics.py
# Last commit f3b0d73 by Victoria Oberascher:
#   "remove default values for height and fov"
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import evaluate
import datasets
import numpy as np
from seametrics.horizon.utils import *
_CITATION = """\
@InProceedings{huggingface:module,
title = {Horizon Metrics},
authors={huggingface, Inc.},
year={2024}
}
"""
# TODO: Add description of the module here
_DESCRIPTION = """\
This metric is intended to calculate horizon prediction metrics."""
# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of predictions for each image. Each prediction
should be a nested array like this:
- [[x1, y1], [x2, y2]]
references: list of references for each image. Each reference
should be a nested array like this:
- [[x1, y1], [x2, y2]]
Returns:
dict containing following metrics:
'average_slope_error': Measures the average difference in slope between the predicted and ground truth horizon.
'average_midpoint_error': Calculates the average difference in midpoint position between the predicted and ground truth horizon.
'stddev_slope_error': Indicates the variability of errors in slope between the predicted and ground truth horizon.
'stddev_midpoint_error': Quantifies the variability of errors in midpoint position between the predicted and ground truth horizon.
'max_slope_error': Represents the maximum difference in slope between the predicted and ground truth horizon.
'max_midpoint_error': Indicates the maximum difference in midpoint position between the predicted and ground truth horizon.
'num_slope_error_jumps': Calculates the differences between errors in successive frames for the slope. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.
'num_midpoint_error_jumps': Calculates the differences between errors in successive frames for the midpoint. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.
Examples:
>>> ground_truth_points = [[[0.0, 0.5384765625], [1.0, 0.4931640625]],
[[0.0, 0.53796875], [1.0, 0.4928515625]],
[[0.0, 0.5374609375], [1.0, 0.4925390625]],
[[0.0, 0.536953125], [1.0, 0.4922265625]],
[[0.0, 0.5364453125], [1.0, 0.4919140625]]]
>>> prediction_points = [[[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
[[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]],
[[0.0, 0.5200016849393765], [1.0, 0.4728554579177664]],
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
>>> module = evaluate.load("SEA-AI/horizon-metrics", vertical_fov_degrees=25.6, height=512, roll_threshold=0.5, pitch_threshold=0.1)
>>> module.add(predictions=ground_truth_points, references=prediction_points)
>>> module.compute()
>>> {'average_slope_error': 0.014823194839790999,
'average_midpoint_error': 0.014285714285714301,
'stddev_slope_error': 0.01519178791378349,
'stddev_midpoint_error': 0.0022661781575342445,
'max_slope_error': 0.033526146567062376,
'max_midpoint_error': 0.018161272321428612,
'num_slope_error_jumps': 1,
'num_midpoint_error_jumps': 1}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION,
                                                _KWARGS_DESCRIPTION)
class HorizonMetrics(evaluate.Metric):
    """
    HorizonMetrics is a metric class that calculates horizon prediction metrics.

    Args:
        vertical_fov_degrees (float): Vertical field of view in degrees.
        height (int): Height of the image.
        roll_threshold (float, optional): Threshold for roll angle. Defaults to 0.5.
        pitch_threshold (float, optional): Threshold for pitch angle. Defaults to 0.1.
        **kwargs: Additional keyword arguments forwarded to ``evaluate.Metric``.

    Attributes:
        slope_threshold (float): Threshold for slope calculated from roll threshold.
        midpoint_threshold (float): Threshold for midpoint calculated from pitch threshold.
        predictions (list): Predicted horizons stored by the latest ``add`` call.
        ground_truth_det (list): Ground truth horizons stored by the latest ``add`` call.
        slope_error_list (list): Per-frame slope errors from the latest ``compute`` call.
        midpoint_error_list (list): Per-frame midpoint errors from the latest ``compute`` call.
        height (int): Height of the image.
        vertical_fov_degrees (float): Vertical field of view in degrees.

    Methods:
        _info(): Returns the metric information.
        add(predictions, references, **kwargs): Updates the predictions and ground truth detections.
        _compute(predictions, references, **kwargs): Computes the horizon error across the sequence.
    """

    def __init__(self,
                 vertical_fov_degrees,
                 height,
                 roll_threshold=0.5,
                 pitch_threshold=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        # Convert the angular thresholds into the slope / midpoint space in
        # which the per-frame errors are expressed.
        self.slope_threshold = roll_to_slope(roll_threshold)
        self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
                                                    vertical_fov_degrees)
        self.predictions = None
        self.ground_truth_det = None
        self.slope_error_list = []
        self.midpoint_error_list = []
        self.height = height
        self.vertical_fov_degrees = vertical_fov_degrees

    def _info(self):
        """
        Returns the metric information.

        Returns:
            MetricInfo: The metric information.
        """
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference: one
            # example is a whole sequence of horizons, each horizon being a
            # list of points, each point a list of floats.
            features=datasets.Features({
                'predictions':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
                'references':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
            }),
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"])

    def add(self, *, predictions, references, **kwargs):
        """
        Updates the predictions and ground truth detections.

        The sequence is passed to the ``evaluate`` backend as a single example
        and is also kept on the instance for use by ``_compute``. Each call
        replaces the previously stored sequence.

        Parameters:
            predictions (list): List of predicted horizons.
            references (list): List of ground truth horizons.
            **kwargs: Additional keyword arguments.
        """
        # The base-class ``add`` takes the singular ``prediction`` /
        # ``reference`` keywords (one example); here that example is the
        # whole sequence, matching the nested features declared in _info.
        super(evaluate.Metric, self).add(prediction=predictions,
                                         reference=references,
                                         **kwargs)
        self.predictions = predictions
        self.ground_truth_det = references

    def _compute(self, *, predictions, references, **kwargs):
        """
        Computes the horizon error across the sequence.

        The sequence stored by the most recent ``add`` call is evaluated; the
        ``predictions`` / ``references`` supplied by the ``evaluate`` framework
        and any extra ``kwargs`` are accepted for interface compatibility but
        are not read.

        Returns:
            dict: The metrics returned by
            ``calculate_horizon_error_across_sequence`` plus a
            ``'detection_rate'`` entry (non-None predictions divided by
            non-None references; ``nan`` when there is no non-None reference).
        """
        # Start from empty error lists so that repeated compute() calls do not
        # mix in errors from previously evaluated sequences.
        self.slope_error_list = []
        self.midpoint_error_list = []

        # Calculate per-frame errors, skipping frames where either the
        # annotation or the prediction is missing.
        for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
                                                       self.predictions):
            if annotated_horizon is None or proposed_horizon is None:
                continue
            slope_error, midpoint_error = calculate_horizon_error(
                annotated_horizon, proposed_horizon)
            self.slope_error_list.append(slope_error)
            self.midpoint_error_list.append(midpoint_error)

        # Calculate aggregate slope errors, midpoint errors and jumps.
        result = calculate_horizon_error_across_sequence(
            self.slope_error_list, self.midpoint_error_list,
            self.slope_threshold, self.midpoint_threshold,
            self.vertical_fov_degrees, self.height)

        # Calculate the detection rate. Counting with ``is not None`` (rather
        # than ``list.count(None)``) works for any sequence and does not raise
        # when elements are numpy arrays, whose ``==`` is element-wise.
        detected_horizon_count = sum(
            1 for horizon in self.predictions if horizon is not None)
        detected_gt_count = sum(
            1 for horizon in self.ground_truth_det if horizon is not None)
        # The rate is undefined without any reference; avoid ZeroDivisionError.
        if detected_gt_count:
            detection_rate = detected_horizon_count / detected_gt_count
        else:
            detection_rate = float("nan")
        result['detection_rate'] = detection_rate
        return result