hichem-abdellali commited on
Commit
adaef8a
1 Parent(s): 2247036

update the user friendly metrics to logs into w&b (#4)

Browse files

- update the user friendly metrics to logs into w&b (c96c4e5195eb95ccd20cc07d19d03a6e4f66f472)

Files changed (1) hide show
  1. user-friendly-metrics.py +139 -24
user-friendly-metrics.py CHANGED
@@ -12,16 +12,15 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import evaluate
16
- import datasets
17
- import motmetrics as mm
18
- from motmetrics.metrics import (events_to_df_map,
19
- obj_frequencies,
20
- track_ratios)
21
- import numpy as np
22
 
 
 
23
  from seametrics.user_friendly.utils import calculate_from_payload
24
 
 
 
25
  _CITATION = """\
26
  @InProceedings{huggingface:module,
27
  title = {A great new module},
@@ -70,17 +69,19 @@ class UserFriendlyMetrics(evaluate.Metric):
70
  citation=_CITATION,
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- "predictions": datasets.Sequence(
75
- datasets.Sequence(datasets.Value("float"))
76
- ),
77
- "references": datasets.Sequence(
78
- datasets.Sequence(datasets.Value("float"))
79
- )
80
- }),
 
 
81
  # Additional links to the codebase or references
82
  codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
83
- reference_urls=["http://path.to.reference.url/new_module"]
84
  )
85
 
86
  def _download_and_prepare(self, dl_manager):
@@ -88,14 +89,128 @@ class UserFriendlyMetrics(evaluate.Metric):
88
  # TODO: Download external resources if needed
89
  pass
90
 
91
- def _compute(self,
92
- payload,
93
- max_iou: float = 0.5,
94
- filters = {},
95
- recognition_thresholds = [0.3, 0.5, 0.8],
96
- debug: bool = False):
 
 
97
  """Returns the scores"""
98
  # TODO: Compute the different scores of the module
99
- return calculate_from_payload(payload, max_iou, filters, recognition_thresholds, debug)
100
- #return calculate(predictions, references, max_iou)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import datetime
16
+ import os
 
 
 
 
 
17
 
18
+ import datasets
19
+ import evaluate
20
  from seametrics.user_friendly.utils import calculate_from_payload
21
 
22
+ import wandb
23
+
24
  _CITATION = """\
25
  @InProceedings{huggingface:module,
26
  title = {A great new module},
 
69
  citation=_CITATION,
70
  inputs_description=_KWARGS_DESCRIPTION,
71
  # This defines the format of each prediction and reference
72
+ features=datasets.Features(
73
+ {
74
+ "predictions": datasets.Sequence(
75
+ datasets.Sequence(datasets.Value("float"))
76
+ ),
77
+ "references": datasets.Sequence(
78
+ datasets.Sequence(datasets.Value("float"))
79
+ ),
80
+ }
81
+ ),
82
  # Additional links to the codebase or references
83
  codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
84
+ reference_urls=["http://path.to.reference.url/new_module"],
85
  )
86
 
87
  def _download_and_prepare(self, dl_manager):
 
89
  # TODO: Download external resources if needed
90
  pass
91
 
92
+ def _compute(
93
+ self,
94
+ payload,
95
+ max_iou: float = 0.5,
96
+ filters={},
97
+ recognition_thresholds=[0.3, 0.5, 0.8],
98
+ debug: bool = False,
99
+ ):
100
  """Returns the scores"""
101
  # TODO: Compute the different scores of the module
102
+ return calculate_from_payload(
103
+ payload, max_iou, filters, recognition_thresholds, debug
104
+ )
105
+ # return calculate(predictions, references, max_iou)
106
+
107
+ def wandb(
108
+ self,
109
+ results,
110
+ wandb_section: str = None,
111
+ wandb_project="user_friendly_metrics",
112
+ log_plots: bool = True,
113
+ debug: bool = False,
114
+ ):
115
+ """
116
+ Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for global metrics.
117
+
118
+ Args:
119
+ results (dict): Results dictionary with 'global' and 'per_sequence' keys.
120
+ wandb_section (str, optional): W&B section for metric grouping. Defaults to None.
121
+ wandb_project (str, optional): The name of the wandb project. Defaults to 'user_friendly_metrics'.
122
+ log_plots (bool, optional): Generates categorized bar charts for global metrics. Defaults to True.
123
+ debug (bool, optional): Logs detailed summaries and histories to the terminal console. Defaults to False.
124
+ """
125
+
126
+ current_datetime = datetime.datetime.now()
127
+ formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
128
+ wandb.login(key=os.getenv("WANDB_API_KEY"))
129
+
130
+ run = wandb.init(
131
+ project=wandb_project,
132
+ name=f"evaluation-{formatted_datetime}",
133
+ reinit=True,
134
+ settings=wandb.Settings(silent=not debug),
135
+ )
136
+
137
+ categories = {
138
+ "confusion_metrics": {"fp", "tp", "fn"},
139
+ "evaluation_metrics": {"f1", "recall", "precision"},
140
+ "recognition_metrics": {
141
+ "recognition_0.3",
142
+ "recognition_0.5",
143
+ "recognition_0.8",
144
+ "recognized_0.3",
145
+ "recognized_0.5",
146
+ "recognized_0.8",
147
+ },
148
+ }
149
+
150
+ chart_data = {key: [] for key in categories.keys()}
151
+
152
+ # Log global metrics
153
+ if "global" in results:
154
+ for global_key, global_metrics in results["global"].items():
155
+ for metric, value in global_metrics["all"].items():
156
+ log_key = (
157
+ f"{wandb_section}/global/{global_key}/{metric}"
158
+ if wandb_section
159
+ else f"global/{global_key}/{metric}"
160
+ )
161
+ run.log({log_key: value})
162
+
163
+ if debug:
164
+ print(f"Logged to W&B: {log_key} = {value}")
165
+
166
+ for category, metrics in categories.items():
167
+ if metric in metrics:
168
+ chart_data[category].append([metric, value])
169
+
170
+ if log_plots:
171
+ for category, data in chart_data.items():
172
+ if data:
173
+ table_data = [[label, value] for label, value in data]
174
+ table = wandb.Table(data=table_data, columns=["metrics", "value"])
175
+ run.log(
176
+ {
177
+ f"{category}_bar_chart": wandb.plot.bar(
178
+ table,
179
+ "metrics",
180
+ "value",
181
+ title=f"{category.replace('_', ' ').title()}",
182
+ )
183
+ }
184
+ )
185
+
186
+ if "per_sequence" in results:
187
+ sorted_sequences = sorted(
188
+ results["per_sequence"].items(),
189
+ key=lambda x: x[1]
190
+ .get("evaluation_metrics", {})
191
+ .get("f1", {})
192
+ .get("all", 0),
193
+ reverse=True,
194
+ )
195
+
196
+ for sequence_name, sequence_data in sorted_sequences:
197
+ for seq_key, seq_metrics in sequence_data.items():
198
+ for metric, value in seq_metrics["all"].items():
199
+ log_key = (
200
+ f"{wandb_section}/per_sequence/{sequence_name}/{seq_key}/{metric}"
201
+ if wandb_section
202
+ else f"per_sequence/{sequence_name}/{seq_key}/{metric}"
203
+ )
204
+ run.log({log_key: value})
205
+ if debug:
206
+ print(
207
+ f"Logged to W&B: {sequence_name} -> {log_key} = {value}"
208
+ )
209
+
210
+ if debug:
211
+ print("\nDebug Mode: Logging Summary and History")
212
+ print(f"Results Summary:\n{results}")
213
+ print(f"WandB Settings:\n{run.settings}")
214
+ print("All metrics have been logged.")
215
 
216
+ run.finish()