Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
class BaseLosses(nn.Module):
    """Accumulate named losses across steps and report their averages.

    One scalar buffer is registered per loss name, plus a ``"total"`` buffer
    and a ``"count"`` buffer. ``_update_loss`` adds each evaluated loss to its
    own buffer; ``compute`` divides every buffer by ``count``, formats the
    results for logging and resets all buffers to 0.

    NOTE(review): nothing in this class increments ``count`` or ``total``;
    presumably a subclass does so once per step — confirm against subclasses.

    Args:
        cfg: Configuration object; not used by this base class.
        losses: Names of the losses to track. The list is copied, so the
            caller's list is never modified. ``"total"`` is added if absent.
        params: Mapping from loss name to the weight applied to that loss in
            ``_update_loss``.
        losses_func: Mapping from loss name to a loss class/factory; each one
            is instantiated with ``reduction='mean'``.
        num_joints: Stored as ``self.num_joints``; not used by this class.
        **kwargs: Ignored.
    """

    def __init__(self, cfg, losses, params, losses_func, num_joints, **kwargs):
        super().__init__()
        # Save parameters
        self.num_joints = num_joints
        self._params = params

        # Work on a copy so the caller's list is not mutated, and make sure
        # the aggregate "total" entry is tracked alongside individual losses.
        losses = list(losses)
        if "total" not in losses:
            losses.append("total")

        # One accumulator buffer per loss plus the step counter; as buffers
        # they move with the module on .to(device) and appear in state_dict.
        for loss in losses:
            self.register_buffer(loss, torch.tensor(0.0))
        self.register_buffer("count", torch.tensor(0.0))
        self.losses = losses

        # "total" has no loss function of its own. Filtering by name (instead
        # of slicing off the last element) stays correct wherever "total"
        # appears in the list.
        self._losses_func = {
            loss: losses_func[loss](reduction="mean")
            for loss in losses
            if loss != "total"
        }

    def _update_loss(self, loss: str, outputs, inputs):
        """Evaluate one loss, accumulate it, and return its weighted value.

        Args:
            loss: Name of the loss; must be a key of both the loss functions
                built in ``__init__`` and ``params``.
            outputs: First argument passed to the loss function.
            inputs: Second argument passed to the loss function.

        Returns:
            ``params[loss] * value``, still attached to the autograd graph.
        """
        val = self._losses_func[loss](outputs, inputs)
        # Accumulate a detached copy so the running sum keeps no graph alive.
        getattr(self, loss).add_(val.detach())
        # The weighted loss keeps its gradient for the caller to optimize.
        return self._params[loss] * val

    def reset(self):
        """Reset every loss accumulator and the step counter to 0."""
        # zeros_like keeps each buffer's current device and dtype; assigning
        # a tensor to a registered buffer name keeps it registered.
        for name in (*self.losses, "count"):
            setattr(self, name, torch.zeros_like(getattr(self, name)))

    def compute(self, split):
        """Average the accumulated losses, reset them, and return log values.

        Args:
            split: Split name appended to every log key.

        Returns:
            Dict mapping log names (see ``loss2logname``) to Python floats.
            Entries whose average is NaN (e.g. 0/0 when nothing was
            accumulated and ``count`` is 0) are omitted.
        """
        count = self.count
        # Average each accumulator over the number of counted steps.
        loss_dict = {loss: getattr(self, loss) / count for loss in self.losses}
        # Format the losses for logging, skipping undefined (NaN) averages.
        log_dict = {
            self.loss2logname(loss, split): value.item()
            for loss, value in loss_dict.items()
            if not torch.isnan(value)
        }
        # Start the next accumulation window from zero.
        self.reset()
        return log_dict

    def loss2logname(self, loss: str, split: str):
        """Convert a loss name to a hierarchical log name.

        ``"total"`` becomes ``"total/<split>"`` and ``"<type>_<name>"`` becomes
        ``"<type>/<name>/<split>"``. Only the first underscore separates the
        type, so ``"recons_joints_vel"`` becomes
        ``"recons/joints_vel/<split>"``. A name without any underscore becomes
        ``"<loss>/<split>"``.
        """
        loss_type, sep, name = loss.partition("_")
        if loss == "total" or not sep:
            return f"{loss}/{split}"
        return f"{loss_type}/{name}/{split}"