File size: 1,034 Bytes
c583015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch import nn
import torch.nn.functional as F

from adain import mi, sigma


class Loss(nn.Module):
    """Weighted sum of a content loss and a statistics-matching style loss.

    ``forward`` returns ``content_loss + lamb * style_loss`` and also stores
    the two components on ``self.loss_c`` and ``self.loss_s`` so callers can
    inspect or log them after each call.
    """

    def __init__(self, lamb=8):
        """
        Args:
            lamb: Weight applied to the style term in ``forward``.
        """
        super().__init__()
        self.lamb = lamb

    def content_loss(self, enc_out: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the mean-squared error between ``enc_out`` and ``t``.

        Presumably ``enc_out`` is the encoding of the generated image and ``t``
        is the target features it should match; TODO confirm against the
        caller. Both are expected to have the same shape.
        """
        return F.mse_loss(enc_out, t)

    def style_loss(self, out_activations: dict, style_activations: dict) -> torch.Tensor:
        """Sum, over paired layers, the MSE between ``mi`` and ``sigma`` statistics.

        Layers are paired positionally in dict iteration (insertion) order, not
        by key. Both dicts must therefore list the same layers in the same
        order. ``mi`` and ``sigma`` are imported from ``adain``. Presumably they
        are per-channel mean and standard deviation; TODO confirm.

        Args:
            out_activations: Per-layer activations, presumably of the
                generated image.
            style_activations: Per-layer activations, presumably of the
                style image.

        Returns:
            The sum over layers of ``mse(mi(out), mi(style))`` plus the sum
            over layers of ``mse(sigma(out), sigma(style))``. If both dicts
            are empty, this is the Python int ``0``.

        Raises:
            ValueError: If the two dicts contain a different number of layers.
        """
        # zip() truncates to the shorter input, so a layer-count mismatch would
        # silently drop style terms. Reject it explicitly instead.
        if len(out_activations) != len(style_activations):
            raise ValueError(
                "style_loss expects the same number of layers, got "
                f"{len(out_activations)} output vs "
                f"{len(style_activations)} style activations"
            )

        # Accumulate mean and std terms separately, then combine them at the end.
        means, sds = 0, 0
        for out_act, style_act in zip(out_activations.values(), style_activations.values()):
            means += F.mse_loss(mi(out_act), mi(style_act))
            sds += F.mse_loss(sigma(out_act), sigma(style_act))

        return means + sds

    def forward(self, enc_out: torch.Tensor, t: torch.Tensor, out_activations: dict, style_activations: dict) -> torch.Tensor:
        """Return ``content_loss(enc_out, t) + lamb * style_loss(...)``.

        Side effect: the two components are stored on ``self.loss_c`` (the
        content term) and ``self.loss_s`` (the unweighted style term). Each
        call overwrites them.
        """
        self.loss_c = self.content_loss(enc_out, t)
        self.loss_s = self.style_loss(out_activations, style_activations)

        return self.loss_c + self.lamb * self.loss_s