|
from abc import ABC, abstractmethod

import torch
from torch import nn

from singleVis.backend import convert_distance_to_probability, compute_cross_entropy

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

"""Loss modules for preserving four properties"""


class Loss(nn.Module, ABC):
    """Abstract base class for the loss modules below."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass
|
|
|
class UmapLoss(nn.Module):
    def __init__(self, negative_sample_rate, device, _a=1.0, _b=1.0, repulsion_strength=1.0):
        super(UmapLoss, self).__init__()

        self._negative_sample_rate = negative_sample_rate
        self._a = _a
        self._b = _b
        self._repulsion_strength = repulsion_strength
        self.DEVICE = torch.device(device)

    @property
    def a(self):
        return self._a

    @property
    def b(self):
        return self._b

    def forward(self, embedding_to, embedding_from):
        batch_size = embedding_to.shape[0]

        # Build negative pairs: repeat each point negative_sample_rate times
        # and shuffle the "from" side so the pairings are (approximately) random.
        embedding_neg_to = torch.repeat_interleave(embedding_to, self._negative_sample_rate, dim=0)
        repeat_neg = torch.repeat_interleave(embedding_from, self._negative_sample_rate, dim=0)
        randperm = torch.randperm(repeat_neg.shape[0])
        embedding_neg_from = repeat_neg[randperm]

        # Distances of positive pairs followed by negative pairs.
        distance_embedding = torch.cat(
            (
                torch.norm(embedding_to - embedding_from, dim=1),
                torch.norm(embedding_neg_to - embedding_neg_from, dim=1),
            ),
            dim=0,
        )
        probabilities_distance = convert_distance_to_probability(
            distance_embedding, self.a, self.b
        )
        probabilities_distance = probabilities_distance.to(self.DEVICE)

        # Ground-truth membership: 1 for positive pairs, 0 for negative samples.
        probabilities_graph = torch.cat(
            (torch.ones(batch_size), torch.zeros(batch_size * self._negative_sample_rate)), dim=0,
        )
        probabilities_graph = probabilities_graph.to(device=self.DEVICE)

        (_, _, ce_loss) = compute_cross_entropy(
            probabilities_graph,
            probabilities_distance,
            repulsion_strength=self._repulsion_strength,
        )

        return torch.mean(ce_loss)
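
# A minimal usage sketch for UmapLoss (illustrative values only; assumes
# convert_distance_to_probability follows the standard UMAP curve
# 1 / (1 + a * d^(2b))):
#
#   criterion = UmapLoss(negative_sample_rate=5, device="cpu")
#   z_to, z_from = torch.randn(8, 2), torch.randn(8, 2)
#   loss = criterion(z_to, z_from)  # scalar tensor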
|
|
class ReconstructionLoss(nn.Module):
    def __init__(self, beta=1.0, alpha=0.5):
        super(ReconstructionLoss, self).__init__()
        self._beta = beta
        # alpha is kept for API compatibility; it is unused in this loss.
        self._alpha = alpha

    def forward(self, edge_to, edge_from, recon_to, recon_from, a_to, a_from):
        # Per-dimension squared error, re-weighted by (1 + a)^beta so samples
        # with larger weights a contribute more; averaged over both endpoints.
        loss1 = torch.mean(torch.mean(torch.multiply(torch.pow((1 + a_to), self._beta), torch.pow(edge_to - recon_to, 2)), 1))
        loss2 = torch.mean(torch.mean(torch.multiply(torch.pow((1 + a_from), self._beta), torch.pow(edge_from - recon_from, 2)), 1))
        return (loss1 + loss2) / 2
|
|
class SmoothnessLoss(nn.Module):
    def __init__(self, margin=0.0):
        super(SmoothnessLoss, self).__init__()
        self._margin = margin

    def forward(self, embedding, target, coefficient):
        # Hinge-style penalty: only displacement beyond the margin is
        # penalized, weighted per sample by the given coefficient.
        loss = torch.mean(coefficient * torch.clamp(torch.norm(embedding - target, dim=1) - self._margin, min=0))
        return loss
|
|
class SingleVisLoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, lambd):
        super(SingleVisLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.lambd = lambd

    def forward(self, edge_to, edge_from, a_to, a_from, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from)
        umap_l = self.umap_loss(embedding_to, embedding_from)

        # Total objective: UMAP loss plus weighted reconstruction loss.
        loss = umap_l + self.lambd * recon_l

        return umap_l, recon_l, loss
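
# A minimal composition sketch (illustrative names; assumes a model whose
# forward pass returns {"umap": (z_to, z_from), "recon": (recon_to, recon_from)}):
#
#   criterion = SingleVisLoss(UmapLoss(5, "cpu"), ReconstructionLoss(), lambd=1.0)
#   umap_l, recon_l, loss = criterion(x_to, x_from, a_to, a_from, outputs)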
|
|
|
class HybridLoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, smooth_loss, lambd1, lambd2):
        super(HybridLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.smooth_loss = smooth_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2

    def forward(self, edge_to, edge_from, a_to, a_from, embeded_to, coeff, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from)
        umap_l = self.umap_loss(embedding_to, embedding_from)
        smooth_l = self.smooth_loss(embedding_to, embeded_to, coeff)

        loss = umap_l + self.lambd1 * recon_l + self.lambd2 * smooth_l

        return umap_l, recon_l, smooth_l, loss
|
|
class TemporalLoss(nn.Module):
    def __init__(self, prev_w, device) -> None:
        super(TemporalLoss, self).__init__()
        self.prev_w = prev_w
        self.device = device
        for param_name in self.prev_w.keys():
            self.prev_w[param_name] = self.prev_w[param_name].to(device=self.device, dtype=torch.float32)

    def forward(self, curr_module):
        # Squared L2 distance between the current parameters and the snapshot
        # of the previous model's parameters.
        loss = torch.tensor(0., requires_grad=True).to(self.device)
        for name, curr_param in curr_module.named_parameters():
            prev_param = self.prev_w[name]
            loss = loss + torch.sum(torch.square(curr_param - prev_param))
        return loss
|
|
class DummyTemporalLoss(nn.Module):
    def __init__(self, device) -> None:
        super(DummyTemporalLoss, self).__init__()
        self.device = device

    def forward(self, curr_module):
        # Constant zero; keeps the training-loop interface uniform when no
        # temporal regularization is wanted.
        loss = torch.tensor(0., requires_grad=True).to(self.device)
        return loss
|
|
class PositionRecoverLoss(nn.Module):
    def __init__(self, device) -> None:
        super(PositionRecoverLoss, self).__init__()
        self.device = device

    def forward(self, position, recover_position):
        mse_loss = nn.MSELoss().to(self.device)
        loss = mse_loss(position, recover_position)
        return loss
|
|
class DVILoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, temporal_loss, lambd1, lambd2, device):
        super(DVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.device = device

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(embedding_to, embedding_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        loss = umap_l + self.lambd1 * recon_l + self.lambd2 * temporal_l

        # Return the weighted components so callers can log each term.
        return umap_l, self.lambd1 * recon_l, self.lambd2 * temporal_l, loss
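
# A minimal composition sketch for DVILoss (illustrative values; assumes the
# same outputs dict shape as in the SingleVisLoss sketch above):
#
#   criterion = DVILoss(UmapLoss(5, "cpu"), ReconstructionLoss(),
#                       DummyTemporalLoss("cpu"),
#                       lambd1=1.0, lambd2=1.0, device="cpu")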
|
|
|
class MINE(nn.Module):
    """Mutual Information Neural Estimation (Belghazi et al., 2018)."""

    def __init__(self):
        super(MINE, self).__init__()
        # Statistics network T(x, y); as defined, it expects x and y to have
        # one feature each (concatenated to 2 input features).
        self.network = nn.Sequential(
            nn.Linear(2, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x, y):
        # Donsker-Varadhan bound: MI >= E_joint[T] - log E_marginal[exp(T)].
        # The marginal is approximated by shuffling y within the batch.
        joint = torch.cat((x, y), dim=1)
        marginal = torch.cat((x, y[torch.randperm(x.size(0))]), dim=1)
        t_joint = self.network(joint)
        t_marginal = self.network(marginal)

        mi = torch.mean(t_joint) - torch.log(torch.mean(torch.exp(t_marginal)))
        # Negated so that minimizing this loss maximizes the MI estimate.
        return -mi
|
|
class TVILoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, temporal_loss, MI_loss, lambd1, lambd2, lambd3, device):
        super(TVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.MI_loss = MI_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.lambd3 = lambd3
        self.device = device

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(embedding_to, embedding_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        # MI_loss must accept inputs of the dimensionality it is given here:
        # the low-dimensional embeddings as well as the high-dimensional edge
        # vectors.
        MI_l_embedding = self.MI_loss(embedding_to, embedding_from).to(self.device)
        MI_l_edge = self.MI_loss(edge_to, edge_from).to(self.device)

        MI_l = (MI_l_embedding + MI_l_edge) / 2
        loss = umap_l + self.lambd1 * recon_l + self.lambd2 * temporal_l + self.lambd3 * MI_l

        # Return all weighted components, including the MI term, so callers
        # can log every part of the objective.
        return umap_l, self.lambd1 * recon_l, self.lambd2 * temporal_l, self.lambd3 * MI_l, loss
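
if __name__ == "__main__":
    # A minimal, self-contained sketch exercising MINE on correlated 1-D
    # variables. The data and hyperparameters below are illustrative only,
    # not part of the library.
    torch.manual_seed(0)
    x = torch.randn(256, 1)
    y = x + 0.1 * torch.randn(256, 1)  # y strongly depends on x

    mine = MINE()
    optimizer = torch.optim.Adam(mine.parameters(), lr=1e-3)
    for _ in range(200):
        optimizer.zero_grad()
        loss = mine(x, y)  # negative Donsker-Varadhan MI estimate
        loss.backward()
        optimizer.step()
    print("estimated MI lower bound:", -loss.item())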
|