Victoria Oberascher commited on
Commit
5d7bacb
·
1 Parent(s): 576a9f6

add detection rate metric

Browse files
Files changed (1) hide show
  1. horizon-metrics.py +35 -14
horizon-metrics.py CHANGED
@@ -65,7 +65,7 @@ Examples:
65
  [[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
66
 
67
 
68
- >>> module = evaluate.load("SEA-AI/horizon-metrics")
69
  >>> module.add(predictions=ground_truth_points, references=prediction_points)
70
  >>> module.compute()
71
  >>> {'average_slope_error': 0.014823194839790999,
@@ -111,16 +111,20 @@ class HorizonMetrics(evaluate.Metric):
111
  roll_threshold=0.5,
112
  pitch_threshold=0.1,
113
  vertical_fov_degrees=25.6,
 
114
  **kwargs):
115
 
116
  super().__init__(**kwargs)
 
117
  self.slope_threshold = roll_to_slope(roll_threshold)
118
  self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
119
  vertical_fov_degrees)
120
  self.predictions = None
121
  self.ground_truth_det = None
122
- self.slope_error_list = None
123
- self.midpoint_error_list = None
 
 
124
 
125
  def _info(self):
126
  """
@@ -163,15 +167,6 @@ class HorizonMetrics(evaluate.Metric):
163
 
164
  self.predictions = predictions
165
  self.ground_truth_det = references
166
- self.slope_error_list = []
167
- self.midpoint_error_list = []
168
-
169
- for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
170
- self.predictions):
171
- slope_error, midpoint_error = calculate_horizon_error(
172
- annotated_horizon, proposed_horizon)
173
- self.slope_error_list.append(slope_error)
174
- self.midpoint_error_list.append(midpoint_error)
175
 
176
  def _compute(self, *, predictions, references, **kwargs):
177
  """
@@ -180,6 +175,32 @@ class HorizonMetrics(evaluate.Metric):
180
  Returns:
181
  float: The computed horizon error.
182
  """
183
- return calculate_horizon_error_across_sequence(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  self.slope_error_list, self.midpoint_error_list,
185
- self.slope_threshold, self.midpoint_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  [[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]
66
 
67
 
68
+ >>> module = evaluate.load("SEA-AI/horizon-metrics", roll_threshold=0.5, pitch_threshold=0.1, vertical_fov_degrees=25.6, height=512)
69
  >>> module.add(predictions=ground_truth_points, references=prediction_points)
70
  >>> module.compute()
71
  >>> {'average_slope_error': 0.014823194839790999,
 
111
  roll_threshold=0.5,
112
  pitch_threshold=0.1,
113
  vertical_fov_degrees=25.6,
114
+ height=512,
115
  **kwargs):
116
 
117
  super().__init__(**kwargs)
118
+
119
  self.slope_threshold = roll_to_slope(roll_threshold)
120
  self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
121
  vertical_fov_degrees)
122
  self.predictions = None
123
  self.ground_truth_det = None
124
+ self.slope_error_list = []
125
+ self.midpoint_error_list = []
126
+ self.height = height
127
+ self.vertical_fov_degrees = vertical_fov_degrees
128
 
129
  def _info(self):
130
  """
 
167
 
168
  self.predictions = predictions
169
  self.ground_truth_det = references
 
 
 
 
 
 
 
 
 
170
 
171
  def _compute(self, *, predictions, references, **kwargs):
172
  """
 
175
  Returns:
176
  float: The computed horizon error.
177
  """
178
+
179
+ # calculate erros and store values in slope_error_list and midpoint_error_list
180
+ for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
181
+ self.predictions):
182
+
183
+ if annotated_horizon is None or proposed_horizon is None:
184
+ continue
185
+
186
+ slope_error, midpoint_error = calculate_horizon_error(
187
+ annotated_horizon, proposed_horizon)
188
+ self.slope_error_list.append(slope_error)
189
+ self.midpoint_error_list.append(midpoint_error)
190
+
191
+ # calculate slope errors, midpoint errors and jumps
192
+ result = calculate_horizon_error_across_sequence(
193
  self.slope_error_list, self.midpoint_error_list,
194
+ self.slope_threshold, self.midpoint_threshold,
195
+ self.vertical_fov_degrees, self.height)
196
+
197
+ # calulcate detection rate
198
+ detected_horizon_count = len(
199
+ self.predictions) - self.predictions.count(None)
200
+ detected_gt_count = len(
201
+ self.ground_truth_det) - self.ground_truth_det.count(None)
202
+
203
+ detection_rate = detected_horizon_count / detected_gt_count
204
+ result['detection_rate'] = detection_rate
205
+
206
+ return result