Leonard Bruns
Add Vista example
d323598
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Tuple
import torch
from torch import nn
from ....modules.distributions.distributions import DiagonalGaussianDistribution
class AbstractRegularizer(nn.Module):
    """Base interface for latent regularizers.

    Subclasses override ``forward`` to turn a tensor ``z`` into a
    ``(tensor, log_dict)`` pair, and ``get_trainable_parameters`` to expose
    whatever parameters they own.

    NOTE(review): ``@abstractmethod`` is not enforced here because this class
    does not use ``abc.ABCMeta``. The class can still be instantiated, and
    the base implementations simply raise ``NotImplementedError`` when called.
    """

    def __init__(self):
        super(AbstractRegularizer, self).__init__()

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Regularize ``z``; subclasses must override this."""
        raise NotImplementedError()

    @abstractmethod
    def get_trainable_parameters(self) -> Any:
        """Return this regularizer's trainable parameters; subclasses must override this."""
        raise NotImplementedError()
class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularizer that treats its input as diagonal-Gaussian parameters.

    ``forward`` wraps ``z`` in a ``DiagonalGaussianDistribution``. It then
    returns either a sample or the mode of that distribution, together with
    a log dict holding the distribution's KL term normalized by the leading
    dimension.

    NOTE(review): the expected layout of ``z`` (e.g. mean and log-variance
    concatenated along some axis) is defined by ``DiagonalGaussianDistribution``,
    which is not visible here. Confirm it against that class.
    """

    def __init__(self, sample: bool = True):
        super().__init__()
        # True: forward() returns posterior.sample(); False: posterior.mode().
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Yield nothing: this regularizer owns no trainable parameters."""
        yield from []

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Regularize ``z`` and report the KL term.

        Args:
            z: Tensor of distribution parameters passed straight to
                ``DiagonalGaussianDistribution``.

        Returns:
            A pair ``(latent, log)``. ``latent`` is the posterior sample (or
            mode when ``self.sample`` is False). ``log`` maps ``"kl_loss"``
            to the summed KL tensor divided by its size along dim 0
            (presumably the batch size; confirm).
        """
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()

        kl_values = posterior.kl()
        leading_dim = kl_values.shape[0]
        kl_loss = torch.sum(kl_values) / leading_dim

        return latent, {"kl_loss": kl_loss}