Spaces:
Build error
Build error
Victoria Oberascher
commited on
Commit
·
5d7bacb
1
Parent(s):
576a9f6
add detection rate metric
Browse files- horizon-metrics.py +35 -14
horizon-metrics.py
CHANGED
@@ -65,7 +65,7 @@ Examples:
|
|
65 |
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
|
66 |
|
67 |
|
68 |
-
>>> module = evaluate.load("SEA-AI/horizon-metrics")
|
69 |
>>> module.add(predictions=ground_truth_points, references=prediction_points)
|
70 |
>>> module.compute()
|
71 |
>>> {'average_slope_error': 0.014823194839790999,
|
@@ -111,16 +111,20 @@ class HorizonMetrics(evaluate.Metric):
|
|
111 |
roll_threshold=0.5,
|
112 |
pitch_threshold=0.1,
|
113 |
vertical_fov_degrees=25.6,
|
|
|
114 |
**kwargs):
|
115 |
|
116 |
super().__init__(**kwargs)
|
|
|
117 |
self.slope_threshold = roll_to_slope(roll_threshold)
|
118 |
self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
|
119 |
vertical_fov_degrees)
|
120 |
self.predictions = None
|
121 |
self.ground_truth_det = None
|
122 |
-
self.slope_error_list =
|
123 |
-
self.midpoint_error_list =
|
|
|
|
|
124 |
|
125 |
def _info(self):
|
126 |
"""
|
@@ -163,15 +167,6 @@ class HorizonMetrics(evaluate.Metric):
|
|
163 |
|
164 |
self.predictions = predictions
|
165 |
self.ground_truth_det = references
|
166 |
-
self.slope_error_list = []
|
167 |
-
self.midpoint_error_list = []
|
168 |
-
|
169 |
-
for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
|
170 |
-
self.predictions):
|
171 |
-
slope_error, midpoint_error = calculate_horizon_error(
|
172 |
-
annotated_horizon, proposed_horizon)
|
173 |
-
self.slope_error_list.append(slope_error)
|
174 |
-
self.midpoint_error_list.append(midpoint_error)
|
175 |
|
176 |
def _compute(self, *, predictions, references, **kwargs):
|
177 |
"""
|
@@ -180,6 +175,32 @@ class HorizonMetrics(evaluate.Metric):
|
|
180 |
Returns:
|
181 |
float: The computed horizon error.
|
182 |
"""
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
self.slope_error_list, self.midpoint_error_list,
|
185 |
-
self.slope_threshold, self.midpoint_threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
[[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
|
66 |
|
67 |
|
68 |
+
>>> module = evaluate.load("SEA-AI/horizon-metrics", roll_threshold=0.5, pitch_threshold=0.1, vertical_fov_degrees=25.6, height=512)
|
69 |
>>> module.add(predictions=ground_truth_points, references=prediction_points)
|
70 |
>>> module.compute()
|
71 |
>>> {'average_slope_error': 0.014823194839790999,
|
|
|
111 |
roll_threshold=0.5,
|
112 |
pitch_threshold=0.1,
|
113 |
vertical_fov_degrees=25.6,
|
114 |
+
height=512,
|
115 |
**kwargs):
|
116 |
|
117 |
super().__init__(**kwargs)
|
118 |
+
|
119 |
self.slope_threshold = roll_to_slope(roll_threshold)
|
120 |
self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
|
121 |
vertical_fov_degrees)
|
122 |
self.predictions = None
|
123 |
self.ground_truth_det = None
|
124 |
+
self.slope_error_list = []
|
125 |
+
self.midpoint_error_list = []
|
126 |
+
self.height = height
|
127 |
+
self.vertical_fov_degrees = vertical_fov_degrees
|
128 |
|
129 |
def _info(self):
|
130 |
"""
|
|
|
167 |
|
168 |
self.predictions = predictions
|
169 |
self.ground_truth_det = references
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
def _compute(self, *, predictions, references, **kwargs):
|
172 |
"""
|
|
|
175 |
Returns:
|
176 |
float: The computed horizon error.
|
177 |
"""
|
178 |
+
|
179 |
+
# calculate erros and store values in slope_error_list and midpoint_error_list
|
180 |
+
for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
|
181 |
+
self.predictions):
|
182 |
+
|
183 |
+
if annotated_horizon is None or proposed_horizon is None:
|
184 |
+
continue
|
185 |
+
|
186 |
+
slope_error, midpoint_error = calculate_horizon_error(
|
187 |
+
annotated_horizon, proposed_horizon)
|
188 |
+
self.slope_error_list.append(slope_error)
|
189 |
+
self.midpoint_error_list.append(midpoint_error)
|
190 |
+
|
191 |
+
# calculate slope errors, midpoint errors and jumps
|
192 |
+
result = calculate_horizon_error_across_sequence(
|
193 |
self.slope_error_list, self.midpoint_error_list,
|
194 |
+
self.slope_threshold, self.midpoint_threshold,
|
195 |
+
self.vertical_fov_degrees, self.height)
|
196 |
+
|
197 |
+
# calulcate detection rate
|
198 |
+
detected_horizon_count = len(
|
199 |
+
self.predictions) - self.predictions.count(None)
|
200 |
+
detected_gt_count = len(
|
201 |
+
self.ground_truth_det) - self.ground_truth_det.count(None)
|
202 |
+
|
203 |
+
detection_rate = detected_horizon_count / detected_gt_count
|
204 |
+
result['detection_rate'] = detection_rate
|
205 |
+
|
206 |
+
return result
|