Spaces:
Running
Running
File size: 1,034 Bytes
c583015 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
from torch import nn
import torch.nn.functional as F
from adain import mi, sigma
class Loss(nn.Module):
    """Weighted sum of a content loss and a statistics-matching style loss.

    ``forward`` returns ``loss_c + lamb * loss_s`` where:

    * ``loss_c`` is the mean-squared error between ``enc_out`` and ``t``;
    * ``loss_s`` sums, over paired layers, the MSE between the ``mi`` statistics
      and the MSE between the ``sigma`` statistics of the output and style
      activations (``mi``/``sigma`` come from the project's ``adain`` module —
      presumably per-channel mean and standard deviation; confirm there).

    After each ``forward`` call the two components are kept on ``self.loss_c``
    and ``self.loss_s`` (e.g. for logging).

    Args:
        lamb: Weight applied to the style loss term.
    """

    def __init__(self, lamb=8):
        super().__init__()
        self.lamb = lamb

    def content_loss(self, enc_out: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``enc_out`` and the target ``t``."""
        return F.mse_loss(enc_out, t)

    def style_loss(self, out_activations: dict, style_activations: dict) -> torch.Tensor:
        """Return the summed mean/std matching loss over paired layers.

        Layers are paired positionally, in the dicts' insertion order (keys are
        not compared), so both dicts must list the same layers in the same order.

        Raises:
            ValueError: if the two dicts hold a different number of layers —
                ``zip`` would otherwise silently drop the extra layers.
        """
        if len(out_activations) != len(style_activations):
            raise ValueError(
                f"activation dicts differ in length: {len(out_activations)} "
                f"output layers vs {len(style_activations)} style layers"
            )
        if not out_activations:
            # Keep the declared return type: no layers -> zero tensor, not int 0.
            return torch.zeros(())
        means, sds = 0, 0
        for out_act, style_act in zip(out_activations.values(), style_activations.values()):
            means += F.mse_loss(mi(out_act), mi(style_act))
            sds += F.mse_loss(sigma(out_act), sigma(style_act))
        return means + sds

    def forward(self, enc_out: torch.Tensor, t: torch.Tensor, out_activations: dict, style_activations: dict) -> torch.Tensor:
        """Compute, store, and combine the content and style losses.

        Returns:
            ``self.loss_c + self.lamb * self.loss_s``.
        """
        self.loss_c = self.content_loss(enc_out, t)
        self.loss_s = self.style_loss(out_activations, style_activations)
        return self.loss_c + self.lamb * self.loss_s
|