File size: 19,731 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Meters."""
import datetime
import numpy as np
import os
from collections import defaultdict, deque
import torch
from fvcore.common.timer import Timer
from sklearn.metrics import average_precision_score
import timesformer.utils.logging as logging
import timesformer.utils.metrics as metrics
import timesformer.utils.misc as misc
# Module-level logger obtained from timesformer.utils.logging.
logger = logging.get_logger(__name__)
class TestMeter(object):
    """
    Perform the multi-view ensemble for testing: each video with a unique index
    will be sampled with multiple clips, and the predictions of the clips will
    be aggregated to produce the final prediction for the video.
    The accuracy is calculated with the given ground truth labels.
    """

    def __init__(
        self,
        num_videos,
        num_clips,
        num_cls,
        overall_iters,
        multi_label=False,
        ensemble_method="sum",
    ):
        """
        Construct tensors to store the predictions and labels. Expect to get
        num_clips predictions from each video, and calculate the metrics on
        num_videos videos.
        Args:
            num_videos (int): number of videos to test.
            num_clips (int): number of clips sampled from each video for
                aggregating the final prediction for the video.
            num_cls (int): number of classes for each prediction.
            overall_iters (int): overall iterations for testing.
            multi_label (bool): if True, use map as the metric.
            ensemble_method (str): method to perform the ensemble, options
                include "sum", and "max".
        """
        self.iter_timer = Timer()
        self.data_timer = Timer()
        self.net_timer = Timer()
        self.num_clips = num_clips
        self.overall_iters = overall_iters
        self.multi_label = multi_label
        self.ensemble_method = ensemble_method
        # Allocate tensors; their initial values are set by reset() below.
        self.video_preds = torch.zeros((num_videos, num_cls))
        # Multi-label: one multi-hot row per video; otherwise one class index.
        self.video_labels = (
            torch.zeros((num_videos, num_cls))
            if multi_label
            else torch.zeros((num_videos)).long()
        )
        self.clip_count = torch.zeros((num_videos)).long()
        self.topk_accs = []
        self.stats = {}
        # Reset metric.
        self.reset()

    def reset(self):
        """
        Reset the metric.

        Clip counts and labels are zeroed. Predictions are zeroed too, except
        for the multi-label "max" ensemble, where they start at -1e10 so the
        element-wise max taken in ``update_stats`` is not floored at 0.
        """
        self.clip_count.zero_()
        self.video_preds.zero_()
        # Only the "max" ensemble needs a very negative starting value. Under
        # the "sum" ensemble the offset would swamp the accumulated scores
        # (float32 spacing near 1e10 is 1024, so -1e10 + p == -1e10 for small
        # p), making every video's prediction identical.
        if self.multi_label and self.ensemble_method == "max":
            self.video_preds -= 1e10
        self.video_labels.zero_()

    def update_stats(self, preds, labels, clip_ids):
        """
        Collect the predictions from the current batch and perform on-the-fly
        ensembling ("sum" or "max") into the per-video predictions.
        Args:
            preds (tensor): predictions from the current batch. Dimension is
                N x C where N is the batch size and C is the channel size
                (num_cls).
            labels (tensor): the corresponding labels of the current batch.
                Dimension is N (or N x C when multi_label is True).
            clip_ids (tensor): clip indexes of the current batch, dimension is
                N.
        Raises:
            NotImplementedError: if ``ensemble_method`` is neither "sum" nor
                "max".
        """
        for ind in range(preds.shape[0]):
            # Clips of the same video have consecutive ids.
            vid_id = int(clip_ids[ind]) // self.num_clips
            # Once a non-zero label is stored for a video, later clips of that
            # video must carry the same label.
            if self.video_labels[vid_id].sum() > 0:
                assert torch.equal(
                    self.video_labels[vid_id].type(torch.FloatTensor),
                    labels[ind].type(torch.FloatTensor),
                )
            self.video_labels[vid_id] = labels[ind]
            if self.ensemble_method == "sum":
                self.video_preds[vid_id] += preds[ind]
            elif self.ensemble_method == "max":
                self.video_preds[vid_id] = torch.max(
                    self.video_preds[vid_id], preds[ind]
                )
            else:
                raise NotImplementedError(
                    "Ensemble Method {} is not supported".format(
                        self.ensemble_method
                    )
                )
            self.clip_count[vid_id] += 1

    def log_iter_stats(self, cur_iter):
        """
        Log the stats.
        Args:
            cur_iter (int): the current iteration of testing.
        """
        # ETA extrapolates the last iteration's duration over what remains.
        eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        stats = {
            "split": "test_iter",
            "cur_iter": "{}".format(cur_iter + 1),
            "eta": eta,
            "time_diff": self.iter_timer.seconds(),
        }
        logging.log_json_stats(stats)

    def iter_tic(self):
        """
        Start to record time.
        """
        self.iter_timer.reset()
        self.data_timer.reset()

    def iter_toc(self):
        """
        Stop to record time.
        """
        self.iter_timer.pause()
        self.net_timer.pause()

    def data_toc(self):
        """
        Pause the data timer and reset the network timer.
        """
        self.data_timer.pause()
        self.net_timer.reset()

    def finalize_metrics(self, ks=(1, 5)):
        """
        Calculate and log the final ensembled metrics.
        Args:
            ks (tuple): list of top-k values for topk_accuracies. For example,
                ks = (1, 5) corresponds to top-1 and top-5 accuracy.
        """
        if not all(self.clip_count == self.num_clips):
            logger.warning(
                "clip count {} != num clips {}".format(
                    ", ".join(
                        [
                            "{}: {}".format(i, k)
                            for i, k in enumerate(self.clip_count.tolist())
                        ]
                    ),
                    self.num_clips,
                )
            )

        self.stats = {"split": "test_final"}
        if self.multi_label:
            mean_ap = get_map(
                self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy()
            )
            self.stats["map"] = mean_ap
        else:
            num_topks_correct = metrics.topks_correct(
                self.video_preds, self.video_labels, ks
            )
            # Convert correct counts to percentages over all videos.
            topks = [
                (x / self.video_preds.size(0)) * 100.0
                for x in num_topks_correct
            ]
            assert len({len(ks), len(topks)}) == 1
            for k, topk in zip(ks, topks):
                self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format(
                    topk, prec=2
                )
        logging.log_json_stats(self.stats)
class ScalarMeter(object):
    """
    Keep track of a stream of scalar values.

    The newest ``window_size`` values live in a bounded deque that backs the
    windowed median and mean. A running sum and count over every value added
    back the global mean.
    """

    def __init__(self, window_size):
        """
        Args:
            window_size (int): capacity of the sliding window (the deque's
                maxlen).
        """
        self.deque = deque(maxlen=window_size)
        self.total, self.count = 0.0, 0

    def reset(self):
        """
        Empty the window and zero the global sum and count.
        """
        self.deque.clear()
        self.total, self.count = 0.0, 0

    def add_value(self, value):
        """
        Push ``value`` into the window and fold it into the global stats.
        """
        window = self.deque
        window.append(value)
        self.count += 1
        self.total += value

    def get_win_median(self):
        """
        Return the median of the values currently in the window.
        """
        window_values = self.deque
        return np.median(window_values)

    def get_win_avg(self):
        """
        Return the mean of the values currently in the window.
        """
        window_values = self.deque
        return np.mean(window_values)

    def get_global_avg(self):
        """
        Return the mean of every value added since construction or the last
        reset.
        """
        running_sum, n_seen = self.total, self.count
        return running_sum / n_seen
class TrainMeter(object):
    """
    Measure training stats.
    """

    def __init__(self, epoch_iters, cfg):
        """
        Args:
            epoch_iters (int): the overall number of iterations of one epoch.
            cfg (CfgNode): configs.
        """
        self._cfg = cfg
        self.epoch_iters = epoch_iters
        # Despite its name, this is the total number of iterations over the
        # whole schedule (epochs * iterations per epoch); used for the ETA.
        self.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.data_timer = Timer()
        self.net_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window).
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples.
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
        self.output_dir = cfg.OUTPUT_DIR
        # Extra named stats: windowed meters and sample-weighted totals.
        self.extra_stats = {}
        self.extra_stats_total = {}
        self.log_period = cfg.LOG_PERIOD

    def reset(self):
        """
        Reset the Meter.
        """
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
        for key in self.extra_stats.keys():
            self.extra_stats[key].reset()
            self.extra_stats_total[key] = 0.0

    def iter_tic(self):
        """
        Start to record time.
        """
        self.iter_timer.reset()
        self.data_timer.reset()

    def iter_toc(self):
        """
        Stop to record time.
        """
        self.iter_timer.pause()
        self.net_timer.pause()

    def data_toc(self):
        """
        Pause the data timer and reset the network timer.
        """
        self.data_timer.pause()
        self.net_timer.reset()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size, stats=None):
        """
        Update the current stats.
        Args:
            top1_err (float): top1 error rate.
            top5_err (float): top5 error rate.
            loss (float): loss value.
            lr (float): learning rate.
            mb_size (int): mini batch size.
            stats (dict, optional): extra named scalar stats for this
                minibatch. Each one is tracked in a windowed meter and in a
                running total weighted by ``mb_size``.
        """
        # Avoid a shared mutable default argument.
        if stats is None:
            stats = {}
        self.loss.add_value(loss)
        self.lr = lr
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

        if not self._cfg.DATA.MULTI_LABEL:
            # Current minibatch stats
            self.mb_top1_err.add_value(top1_err)
            self.mb_top5_err.add_value(top5_err)
            # Aggregate stats
            self.num_top1_mis += top1_err * mb_size
            self.num_top5_mis += top5_err * mb_size

        for key, value in stats.items():
            if key not in self.extra_stats:
                self.extra_stats[key] = ScalarMeter(self.log_period)
                self.extra_stats_total[key] = 0.0
            self.extra_stats[key].add_value(value)
            self.extra_stats_total[key] += value * mb_size

    def log_iter_stats(self, cur_epoch, cur_iter):
        """
        log the stats of the current iteration.
        Args:
            cur_epoch (int): the number of current epoch.
            cur_iter (int): the number of current iteration.
        """
        if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0:
            return
        # ETA extrapolates the last iteration's duration over the remaining
        # iterations of the whole schedule.
        eta_sec = self.iter_timer.seconds() * (
            self.MAX_EPOCH - (cur_epoch * self.epoch_iters + cur_iter + 1)
        )
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        stats = {
            "_type": "train_iter",
            "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "dt": self.iter_timer.seconds(),
            "dt_data": self.data_timer.seconds(),
            "dt_net": self.net_timer.seconds(),
            "eta": eta,
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        }
        if not self._cfg.DATA.MULTI_LABEL:
            stats["top1_err"] = self.mb_top1_err.get_win_median()
            stats["top5_err"] = self.mb_top5_err.get_win_median()
        # Extra stats are reported as running sample-weighted averages.
        for key in self.extra_stats.keys():
            stats[key] = self.extra_stats_total[key] / self.num_samples
        logging.log_json_stats(stats)

    def log_epoch_stats(self, cur_epoch):
        """
        Log the stats of the current epoch.
        Args:
            cur_epoch (int): the number of current epoch.
        """
        eta_sec = self.iter_timer.seconds() * (
            self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters
        )
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        stats = {
            "_type": "train_epoch",
            "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
            "dt": self.iter_timer.seconds(),
            "dt_data": self.data_timer.seconds(),
            "dt_net": self.net_timer.seconds(),
            "eta": eta,
            "lr": self.lr,
            "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
            "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
        }
        if not self._cfg.DATA.MULTI_LABEL:
            top1_err = self.num_top1_mis / self.num_samples
            top5_err = self.num_top5_mis / self.num_samples
            stats["top1_err"] = top1_err
            stats["top5_err"] = top5_err
        # The loss is accumulated in update_stats regardless of
        # DATA.MULTI_LABEL, so its epoch average is reported in both cases.
        avg_loss = self.loss_total / self.num_samples
        stats["loss"] = avg_loss
        for key in self.extra_stats.keys():
            stats[key] = self.extra_stats_total[key] / self.num_samples
        logging.log_json_stats(stats)
class ValMeter(object):
    """
    Measures validation stats.
    """

    def __init__(self, max_iter, cfg):
        """
        Args:
            max_iter (int): the max number of iteration of the current epoch.
            cfg (CfgNode): configs.
        """
        self._cfg = cfg
        self.max_iter = max_iter
        self.iter_timer = Timer()
        self.data_timer = Timer()
        self.net_timer = Timer()
        # Current minibatch errors (smoothed over a window).
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full val set). Not cleared by reset(), so they
        # track the best value seen across epochs.
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples.
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
        # Per-batch predictions/labels, concatenated for mAP at epoch end.
        self.all_preds = []
        self.all_labels = []
        self.output_dir = cfg.OUTPUT_DIR
        # Extra named stats: windowed meters and sample-weighted totals.
        self.extra_stats = {}
        self.extra_stats_total = {}
        self.log_period = cfg.LOG_PERIOD

    def reset(self):
        """
        Reset the Meter.
        """
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
        self.all_preds = []
        self.all_labels = []
        for key in self.extra_stats.keys():
            self.extra_stats[key].reset()
            self.extra_stats_total[key] = 0.0

    def iter_tic(self):
        """
        Start to record time.
        """
        self.iter_timer.reset()
        self.data_timer.reset()

    def iter_toc(self):
        """
        Stop to record time.
        """
        self.iter_timer.pause()
        self.net_timer.pause()

    def data_toc(self):
        """
        Pause the data timer and reset the network timer.
        """
        self.data_timer.pause()
        self.net_timer.reset()

    def update_stats(self, top1_err, top5_err, mb_size, stats=None):
        """
        Update the current stats.
        Args:
            top1_err (float): top1 error rate.
            top5_err (float): top5 error rate.
            mb_size (int): mini batch size.
            stats (dict, optional): extra named scalar stats for this
                minibatch. Each one is tracked in a windowed meter and in a
                running total weighted by ``mb_size``.
        """
        # Avoid a shared mutable default argument.
        if stats is None:
            stats = {}
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size
        for key, value in stats.items():
            if key not in self.extra_stats:
                self.extra_stats[key] = ScalarMeter(self.log_period)
                self.extra_stats_total[key] = 0.0
            self.extra_stats[key].add_value(value)
            self.extra_stats_total[key] += value * mb_size

    def update_predictions(self, preds, labels):
        """
        Update predictions and labels.
        Args:
            preds (tensor): model output predictions.
            labels (tensor): labels.
        """
        # TODO: merge update_prediction with update_stats.
        self.all_preds.append(preds)
        self.all_labels.append(labels)

    def log_iter_stats(self, cur_epoch, cur_iter):
        """
        log the stats of the current iteration.
        Args:
            cur_epoch (int): the number of current epoch.
            cur_iter (int): the number of current iteration.
        """
        if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0:
            return
        eta_sec = self.iter_timer.seconds() * (self.max_iter - cur_iter - 1)
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        stats = {
            "_type": "val_iter",
            "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_diff": self.iter_timer.seconds(),
            "eta": eta,
            "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        }
        if not self._cfg.DATA.MULTI_LABEL:
            stats["top1_err"] = self.mb_top1_err.get_win_median()
            stats["top5_err"] = self.mb_top5_err.get_win_median()
        for key in self.extra_stats.keys():
            stats[key] = self.extra_stats[key].get_win_median()
        logging.log_json_stats(stats)

    def log_epoch_stats(self, cur_epoch):
        """
        Log the stats of the current epoch.
        Args:
            cur_epoch (int): the number of current epoch.
        """
        stats = {
            "_type": "val_epoch",
            "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
            "time_diff": self.iter_timer.seconds(),
            "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
            "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
        }
        if self._cfg.DATA.MULTI_LABEL:
            stats["map"] = get_map(
                torch.cat(self.all_preds).cpu().numpy(),
                torch.cat(self.all_labels).cpu().numpy(),
            )
        else:
            top1_err = self.num_top1_mis / self.num_samples
            top5_err = self.num_top5_mis / self.num_samples
            self.min_top1_err = min(self.min_top1_err, top1_err)
            self.min_top5_err = min(self.min_top5_err, top5_err)

            stats["top1_err"] = top1_err
            stats["top5_err"] = top5_err
            stats["min_top1_err"] = self.min_top1_err
            stats["min_top5_err"] = self.min_top5_err

        for key in self.extra_stats.keys():
            stats[key] = self.extra_stats_total[key] / self.num_samples

        logging.log_json_stats(stats)
def get_map(preds, labels):
    """
    Compute mAP for multi-label case.

    Columns (classes) whose labels are all zero have no positives to rank and
    are dropped before scoring.

    Args:
        preds (numpy tensor): num_examples x num_classes.
        labels (numpy tensor): num_examples x num_classes.
    Returns:
        mean_ap (float): final mAP score. Falls back to 0 when
            ``average_precision_score`` raises ``ValueError``.
    """
    logger.info("Getting mAP for {} examples".format(preds.shape[0]))

    # Keep only classes with at least one non-zero label (mask computed once).
    valid_classes = ~np.all(labels == 0, axis=0)
    preds = preds[:, valid_classes]
    labels = labels[:, valid_classes]
    aps = [0]
    try:
        aps = average_precision_score(labels, preds, average=None)
    except ValueError as err:
        # Report through the module logger (not stdout) and keep the cause.
        logger.warning(
            "Average precision requires a sufficient number of samples "
            "in a batch which are missing in this sample: {}".format(err)
        )

    mean_ap = np.mean(aps)
    return mean_ap
|