hichem-abdellali committed
Commit a637b07 · verified · 1 Parent(s): 74f9f8f

update the _compute function to use the seametrics library


The user-friendly metrics have been migrated into the seametrics library, which gives us tighter control over the code and makes the implementation testable.
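For reference, the payload layout the metric consumes is unchanged by this commit; the sketch below is reconstructed from the removed in-repo implementation shown in the diff (field names come from the old code, the values are made up):

```python
# Illustrative payload, reconstructed from the removed implementation below;
# field names are from the old code, values are invented for the example.
payload = {
    "gt_field_name": "ground_truth",   # key of the ground-truth track per sequence
    "models": ["my_model"],            # model keys to evaluate
    "sequence_list": ["seq_01"],
    "sequences": {
        "seq_01": {
            # one list of detections per frame; frame ids are assigned by order
            "ground_truth": [
                [{"index": 1, "bounding_box": [10.0, 20.0, 50.0, 80.0]}],  # frame 1
            ],
            "my_model": [
                [{"index": 1, "bounding_box": [12.0, 22.0, 48.0, 78.0]}],  # frame 1
            ],
        }
    },
}
```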

Files changed (1)
  1. user-friendly-metrics.py +2 -216
user-friendly-metrics.py CHANGED
@@ -20,6 +20,8 @@ from motmetrics.metrics import (events_to_df_map,
                                 track_ratios)
 import numpy as np
 
+from seametrics.user_friendly.utils import calculate_from_payload
+
 _CITATION = """\
 @InProceedings{huggingface:module,
     title = {A great new module},
@@ -97,219 +99,3 @@ class UserFriendlyMetrics(evaluate.Metric):
         return calculate_from_payload(payload, max_iou, filters, recognition_thresholds, debug)
         #return calculate(predictions, references, max_iou)
 
-def recognition(track_ratios, th=0.5):
-    """Number of objects tracked for at least the given fraction of their lifespan."""
-    return track_ratios[track_ratios >= th].count()
-
-def num_gt_ids(df):
-    """Number of unique ground-truth ids."""
-    return df.full["OId"].dropna().unique().shape[0]
-
-def calculate(predictions,
-              references,
-              max_iou: float = 0.5,
-              recognition_thresholds: list = [0.3, 0.5, 0.8]):
-    """Returns the scores."""
-
-    try:
-        np_predictions = np.array(predictions)
-    except Exception:
-        raise ValueError("The predictions should be a list of np.arrays in the format [frame number, object id, bb_left, bb_top, bb_width, bb_height, confidence]")
-
-    try:
-        np_references = np.array(references)
-    except Exception:
-        raise ValueError("The references should be a list of np.arrays in the format [frame number, object id, bb_left, bb_top, bb_width, bb_height]")
-
-    if np_predictions.shape[1] != 7:
-        raise ValueError("The predictions should be a list of np.arrays in the format [frame number, object id, bb_left, bb_top, bb_width, bb_height, confidence]")
-    if np_references.shape[1] != 6:
-        raise ValueError("The references should be a list of np.arrays in the format [frame number, object id, bb_left, bb_top, bb_width, bb_height]")
-
-    if np_predictions[:, 0].min() <= 0:
-        raise ValueError("The frame number in the predictions should be a positive integer")
-    if np_references[:, 0].min() <= 0:
-        raise ValueError("The frame number in the references should be a positive integer")
-
-    num_frames = int(max(np_references[:, 0].max(), np_predictions[:, 0].max()))
-
-    acc = mm.MOTAccumulator(auto_id=True)
-    for i in range(1, num_frames + 1):
-        preds = np_predictions[np_predictions[:, 0] == i, 1:6]
-        refs = np_references[np_references[:, 0] == i, 1:6]
-        # motmetrics' max_iou is a distance cutoff (1 - IoU), so invert the threshold
-        C = mm.distances.iou_matrix(refs[:, 1:], preds[:, 1:], max_iou=1 - max_iou)
-        acc.update(refs[:, 0].astype('int').tolist(), preds[:, 0].astype('int').tolist(), C)
-
-    mh = mm.metrics.create()
-    summary = mh.compute(acc, metrics=['num_misses', 'num_false_positives', 'num_detections']).to_dict()
-
-    df = events_to_df_map(acc.events)
-    tr_ratios = track_ratios(df, obj_frequencies(df))
-    unique_gt_ids = num_gt_ids(df)
-
-    namemap = {"num_misses": "fn",
-               "num_false_positives": "fp",
-               "num_detections": "tp"}
-
-    for key in list(summary.keys()):
-        if key in namemap:
-            summary[namemap[key]] = float(summary[key][0])
-            summary.pop(key)
-        else:
-            summary[key] = float(summary[key][0])
-
-    summary["num_gt_ids"] = unique_gt_ids
-
-    for th in recognition_thresholds:
-        recognized = recognition(tr_ratios, th)
-        summary[f'recognized_{th}'] = int(recognized)
-
-    return summary
-
-def build_metrics_template(models, filters):
-    metrics_dict = {}
-    for model in models:
-        metrics_dict[model] = {}
-        metrics_dict[model]["all"] = {}
-        for filter, filter_ranges in filters.items():
-            metrics_dict[model][filter] = {}
-            for filter_range in filter_ranges:
-                filter_range_name = filter_range[0]
-                metrics_dict[model][filter][filter_range_name] = {}
-    return metrics_dict
-
-
-def calculate_from_payload(payload: dict,
-                           max_iou: float = 0.5,
-                           filters={},
-                           recognition_thresholds=[0.3, 0.5, 0.8],
-                           debug: bool = False):
-
-    if not isinstance(payload, dict):
-        try:
-            payload = payload.to_dict()
-        except Exception as e:
-            raise ValueError(
-                "The payload should be a dictionary or a compatible object"
-            ) from e
-    gt_field_name = payload['gt_field_name']
-    models = payload['models']
-    sequence_list = payload['sequence_list']
-
-    if debug:
-        print("gt_field_name: ", gt_field_name)
-        print("models: ", models)
-        print("sequence_list: ", sequence_list)
-
-    metrics_per_sequence = {}
-    metrics_global = build_metrics_template(models, filters)
-
-    for sequence in sequence_list:
-        metrics_per_sequence[sequence] = {}
-        frames = payload['sequences'][sequence][gt_field_name]
-
-        all_formatted_references = {"all": []}
-        for filter, filter_ranges in filters.items():
-            all_formatted_references[filter] = {}
-            for filter_range in filter_ranges:
-                filter_range_name = filter_range[0]
-                all_formatted_references[filter][filter_range_name] = []
-
-        for frame_id, frame in enumerate(frames):
-            for detection in frame:
-                index = detection['index']
-                x, y, w, h = detection['bounding_box']
-                all_formatted_references["all"].append([frame_id + 1, index, x, y, w, h])
-
-                for filter, filter_ranges in filters.items():
-                    filter_value = detection[filter]
-                    for filter_range in filter_ranges:
-                        filter_range_name, filter_range_limits = filter_range[0], filter_range[1]
-                        if filter_range_limits[0] <= filter_value <= filter_range_limits[1]:
-                            all_formatted_references[filter][filter_range_name].append([frame_id + 1, index, x, y, w, h])
-
-        metrics_per_sequence[sequence] = build_metrics_template(models, filters)
-
-        for model in models:
-            frames = payload['sequences'][sequence][model]
-            formatted_predictions = []
-
-            for frame_id, frame in enumerate(frames):
-                for detection in frame:
-                    index = detection['index']
-                    x, y, w, h = detection['bounding_box']
-                    confidence = 1
-                    formatted_predictions.append([frame_id + 1, index, x, y, w, h, confidence])
-
-            if debug:
-                print("sequence/model: ", sequence, model)
-                print("formatted_predictions: ", formatted_predictions)
-                print("formatted_references: ", all_formatted_references)
-
-            if len(formatted_predictions) == 0:
-                metrics_per_sequence[sequence][model] = "Model had no predictions."
-            elif len(all_formatted_references["all"]) == 0:
-                metrics_per_sequence[sequence][model] = "No ground truth."
-            else:
-                sequence_metrics = calculate(formatted_predictions, all_formatted_references["all"], max_iou=max_iou, recognition_thresholds=recognition_thresholds)
-                sequence_metrics = realize_metrics(sequence_metrics, recognition_thresholds)
-                metrics_per_sequence[sequence][model]["all"] = sequence_metrics
-
-                metrics_global[model]["all"] = sum_dicts(metrics_global[model]["all"], sequence_metrics)
-                metrics_global[model]["all"] = realize_metrics(metrics_global[model]["all"], recognition_thresholds)
-
-                for filter, filter_ranges in filters.items():
-                    for filter_range in filter_ranges:
-                        filter_range_name = filter_range[0]
-                        sequence_metrics = calculate(formatted_predictions, all_formatted_references[filter][filter_range_name], max_iou=max_iou, recognition_thresholds=recognition_thresholds)
-                        sequence_metrics = realize_metrics(sequence_metrics, recognition_thresholds)
-                        metrics_per_sequence[sequence][model][filter][filter_range_name] = sequence_metrics
-
-                        metrics_global[model][filter][filter_range_name] = sum_dicts(metrics_global[model][filter][filter_range_name], sequence_metrics)
-                        metrics_global[model][filter][filter_range_name] = realize_metrics(metrics_global[model][filter][filter_range_name], recognition_thresholds)
-
-    output = {"global": metrics_global, "per_sequence": metrics_per_sequence}
-
-    return output
-
-def sum_dicts(dict1, dict2):
-    """Recursively sums the numerical values in two nested dictionaries."""
-    result = {}
-    for key in dict1.keys() | dict2.keys():  # union of keys from both dictionaries
-        val1 = dict1.get(key, 0)
-        val2 = dict2.get(key, 0)
-        if isinstance(val1, dict) and isinstance(val2, dict):
-            # if both values are dictionaries, sum them recursively
-            result[key] = sum_dicts(val1, val2)
-        elif isinstance(val1, (int, float)) and isinstance(val2, (int, float)):
-            # if both are numbers, add them
-            result[key] = val1 + val2
-        else:
-            # if only one dictionary has the key, take the non-zero value
-            result[key] = val1 if val1 != 0 else val2
-    return result
-
-def realize_metrics(metrics_dict, recognition_thresholds):
-    """Derives precision, recall, f1, and recognition rates from the raw counts."""
-    metrics_dict["precision"] = metrics_dict["tp"] / (metrics_dict["tp"] + metrics_dict["fp"])
-    metrics_dict["recall"] = metrics_dict["tp"] / (metrics_dict["tp"] + metrics_dict["fn"])
-
-    metrics_dict["f1"] = 2 * metrics_dict["precision"] * metrics_dict["recall"] / (metrics_dict["precision"] + metrics_dict["recall"] + 1e-6)
-
-    for th in recognition_thresholds:
-        metrics_dict[f"recognition_{th}"] = metrics_dict[f"recognized_{th}"] / metrics_dict["num_gt_ids"]
-
-    return metrics_dict
 
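Since the module body is now a thin wrapper, the same numbers should be reproducible by calling the library directly. A minimal sketch: the import path and the positional signature are taken from the diff above, while the output shape is an assumption carried over from the removed in-repo implementation.

```python
# Minimal sketch: import path and positional arguments match the diff above;
# the {"global": ..., "per_sequence": ...} result shape is an assumption
# inherited from the old in-repo code, not a confirmed seametrics contract.
from seametrics.user_friendly.utils import calculate_from_payload

results = calculate_from_payload(
    payload,          # dict shaped like the payload sketch above
    0.5,              # max_iou: IoU threshold for prediction/GT matching
    {},               # filters: optional {attribute: [(range_name, (lo, hi)), ...]}
    [0.3, 0.5, 0.8],  # recognition thresholds (fractions of track lifespan)
    False,            # debug printing
)
print(results["global"])
print(results["per_sequence"])
```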