Spaces:
Runtime error
Runtime error
File size: 406 Bytes
2a92dc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
class AdaIN(nn.Module):
    """Adaptive instance normalization conditioned on a style tensor.

    Each channel of ``x`` is normalized to zero mean / unit std over dims 2
    and 3, then rescaled by ``sigma`` and shifted by ``mu``. Both are taken
    from ``y``: the first half of ``y``'s dim-1 entries is the scale, the
    second half is the bias.

    Args:
        eps: lower bound applied to the per-channel std before dividing, so
            constant (zero-variance) feature maps yield ``mu`` instead of NaN.
            Outputs are unaffected wherever the std already exceeds ``eps``.
    """

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Apply AdaIN to ``x`` using style parameters ``y``.

        Args:
            x: feature tensor, normalized over dims 2 and 3; presumably
                (B, C, H, W) -- TODO confirm against callers.
            y: style tensor whose dim 1 holds the scales followed by the
                biases; presumably (B, 2C) -- TODO confirm against callers.

        Returns:
            ``sigma * (x - mean) / std + mu``, broadcast against ``x``.

        Raises:
            ValueError: if ``y.size(1)`` is odd, so it cannot be split into
                equal scale and bias halves.
        """
        ch = y.size(1)
        if ch % 2:
            raise ValueError(
                f"AdaIN expects an even number of style channels, got {ch}"
            )
        half = ch // 2
        # Add two trailing singleton dims so sigma/mu broadcast over dims 2, 3
        # (assuming y is 2-D).
        sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [half, half], dim=1)
        x_mu = x.mean(dim=(2, 3), keepdim=True)
        # torch's default (unbiased) std, as before; the clamp keeps constant
        # channels (e.g. all-zero activations) from producing 0/0 = NaN.
        x_sigma = x.std(dim=(2, 3), keepdim=True).clamp(min=self.eps)
        return sigma * ((x - x_mu) / x_sigma) + mu
|