Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Copyright (c) XiMing Xing. All rights reserved. | |
# Author: XiMing Xing | |
# Description: | |
from enum import Enum | |
import torch | |
import torch.distributed as dist | |
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary`` reports.

    Members, in definition order:

    * ``NONE`` (0)    -- ``summary()`` returns an empty string.
    * ``AVERAGE`` (1) -- report the meter's ``avg`` attribute.
    * ``SUM`` (2)     -- report the meter's ``sum`` attribute.
    * ``COUNT`` (3)   -- report the meter's ``count`` attribute.
    """

    # Bound left-to-right, so values and iteration order are 0..3 as listed.
    NONE, AVERAGE, SUM, COUNT = range(4)
class AverageMeter(object):
    """Computes and stores the average and current value.

    Keeps the most recent value passed to :meth:`update` (``val``), the
    weighted running total (``sum``), the accumulated weight (``count``) and
    their ratio (``avg``; 0 while ``count`` is 0).

    Args:
        name: Label printed first by :meth:`__str__` and :meth:`summary`.
        fmt: Format spec, including the leading colon (e.g. ``':6.2f'``),
            applied to ``val`` and ``avg`` in :meth:`__str__`.
        summary_type: A ``Summary`` member choosing what :meth:`summary`
            reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def _safe_avg(self):
        """Return ``sum / count``, or 0 while nothing has been counted."""
        # Avoids ZeroDivisionError for update(x, n=0) on a fresh meter and
        # for an all_reduce in which every process contributed nothing.
        return self.sum / self.count if self.count else 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        Args:
            val: New value; stored as ``self.val``.
            n: Weight of ``val``: ``val * n`` is added to ``sum`` and ``n``
                is added to ``count``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self._safe_avg()

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across processes and recompute ``avg``.

        This is a collective ``dist.all_reduce`` call, so every process in
        the default process group must call it. Afterwards ``sum`` and
        ``count`` are Python floats (produced by ``Tensor.tolist``).
        """
        # torch.backends.mps is missing from older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        # float64 keeps large sums/counts accurate; MPS has no float64
        # support, so fall back to float32 there.
        dtype = torch.float32 if device.type == "mps" else torch.float64
        total = torch.tensor([self.sum, self.count], dtype=dtype, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self._safe_avg()

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` with both numbers in ``fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the statistic selected by ``summary_type`` as text.

        Returns:
            ``''`` for ``Summary.NONE``; otherwise ``'<name> <value>'`` with
            ``avg``, ``sum`` or ``count`` rendered to three decimals.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)