|
|
|
|
|
|
|
|
|
from typing import Union

import torch

from src.efficientvit.apps.utils.dist import sync_tensor
|
|
|
# Public names exported by `from <this module> import *`.
__all__ = ["AverageMeter"]
|
|
|
|
|
class AverageMeter:
    """Computes and stores a running (weighted) average.

    Each call to :meth:`update` adds ``val * delta_n`` to the running sum and
    ``delta_n`` to the running count. When ``is_distributed`` is true, both
    increments are first passed through ``sync_tensor(..., reduce="sum")``
    (presumably a cross-process sum -- see ``sync_tensor``) before being
    accumulated.
    """

    def __init__(self, is_distributed=True):
        # Whether increments are summed across processes before accumulating.
        self.is_distributed = is_distributed
        # Running weighted sum of values and running total weight.
        self.sum = 0
        self.count = 0

    @staticmethod
    def _to_python(value):
        """Return ``value.item()`` for a single-element tensor, else ``value`` unchanged."""
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]:
        """Sum-reduce ``val`` via ``sync_tensor`` when distributed; otherwise return it as-is."""
        return sync_tensor(val, reduce="sum") if self.is_distributed else val

    def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
        """Accumulate ``val`` with weight ``delta_n``.

        Args:
            val: The (per-sample average) value to add; a tensor or Python number.
            delta_n: Weight of this observation, e.g. the number of samples
                ``val`` was averaged over. Defaults to 1.
        """
        self.count += self._sync(delta_n)
        # Store the weighted contribution so that avg == sum / count is a
        # properly weighted mean across updates of different sizes.
        self.sum += self._sync(val * delta_n)

    def get_count(self) -> Union[torch.Tensor, int, float]:
        """Return the accumulated weight, unwrapped to a Python number if it is a 1-element tensor."""
        return self._to_python(self.count)

    @property
    def avg(self):
        """Weighted average ``sum / count``.

        Returns ``-1`` as a sentinel when nothing has been accumulated yet
        (``count == 0``). A single-element tensor result is unwrapped to a
        Python number.
        """
        avg = -1 if self.count == 0 else self.sum / self.count
        return self._to_python(avg)
|
|