# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: running-average meters for tracking scalar metrics.

from enum import Enum

import torch
import torch.distributed as dist


class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: label printed before the values by ``__str__`` and ``summary``.
        fmt: format spec, including the leading ``':'``, applied to ``val``
            and ``avg`` in ``__str__`` (e.g. ``':.4f'``).
        summary_type: a ``Summary`` member choosing what ``summary`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear the latest value and all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n``.

        ``val`` contributes ``val * n`` to the running sum and ``n`` to the
        running count; ``avg`` is their ratio.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates were seen.
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across all ranks and recompute ``avg``.

        Calls ``torch.distributed.all_reduce`` synchronously, so it requires an
        initialized default process group and must be called on every rank.
        Afterwards ``sum`` and ``count`` are Python floats.
        """
        # The torch.distributed backends are documented for CPU and CUDA
        # tensors only; an MPS tensor would make the collective fail.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        # float64 keeps integer-valued sums/counts exact far beyond 2**24.
        total = torch.tensor([self.sum, self.count], dtype=torch.float64, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # Avoid ZeroDivisionError if no rank recorded any samples.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return a one-line summary of the statistic chosen by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)