from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torchmetrics import Metric


class NumTokens(Metric):
    """Keep track of how many tokens we've seen."""
    # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch
    # of the next epoch.
    # Right now the hack is that we override reset(), which would mess up the forward method.
    # We then override forward to do the right thing.

    is_differentiable = False
    higher_is_better = False
    full_state_update = False
    count: Tensor

    def __init__(self, **kwargs: Dict[str, Any]):
        super().__init__(**kwargs)
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum",
                       persistent=True)  # We want the count to be saved to the state dict

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        self.count += target.numel()

    def compute(self) -> Tensor:
        return self.count

    def reset(self):
        count = self.count
        super().reset()
        self.count = count

    # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py
    def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
        """Forward computation using a single call to `update` to calculate the metric value on the
        current batch and accumulate global state.

        This can be done when the global metric state is a simple reduction of batch states.
        """
        self.update(*args, **kwargs)
        return self.compute()
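

# A minimal usage sketch (not part of the original module; shapes and values are illustrative).
# It shows how NumTokens accumulates the token count across batches and how the overridden
# reset() preserves the count, e.g. across epoch boundaries.
if __name__ == "__main__":
    metric = NumTokens()
    preds = torch.randn(2, 8, 100)           # (batch, seq_len, vocab) logits; only the shape of target matters
    target = torch.randint(0, 100, (2, 8))   # (batch, seq_len) token ids
    metric.update(preds, target)
    print(metric.compute())                  # tensor(16): 2 * 8 tokens seen so far
    metric.reset()                           # count survives the reset by design
    print(metric.compute())                  # still tensor(16)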