Spaces:
Build error
Build error
File size: 2,406 Bytes
b5ed368 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from torch import nn
from models.stylegan2.model import PixelNorm
from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential, Module, Conv2d, GroupNorm
class TextModulationModule(Module):
    """Modulate a convolutional feature map with a text embedding.

    The input is passed through a 3x3 convolution and group normalisation,
    then scaled and shifted per channel by values predicted from the
    embedding::

        out = leaky_relu(norm(conv(x)) * (1 + exp(g(e))) + b(e))

    where ``g`` and ``b`` are two independent MLPs of identical layout.

    Args:
        in_channels: Channel count of the incoming feature map. The output
            has the same number of channels.
        embedding_dim: Size of the last dimension of ``embedding``.
        hidden_dim: Width of the hidden layer of the gamma/beta MLPs.
        num_groups: Number of groups used by ``GroupNorm``; ``in_channels``
            must be divisible by it.
    """

    def __init__(self, in_channels, embedding_dim=512, hidden_dim=512, num_groups=32):
        super(TextModulationModule, self).__init__()

        def build_head():
            # Both heads share this layout: embedding -> hidden -> channels.
            return Sequential(
                Linear(embedding_dim, hidden_dim),
                LayerNorm([hidden_dim]),
                LeakyReLU(),
                Linear(hidden_dim, in_channels),
            )

        # Attribute names and construction order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.conv = Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=False)
        self.norm = GroupNorm(num_groups, in_channels)
        self.gamma_function = build_head()
        self.beta_function = build_head()
        self.leakyrelu = LeakyReLU()

    def forward(self, x, embedding):
        """Return the modulated feature map.

        Args:
            x: Feature map fed to ``Conv2d``.
                NOTE(review): presumably (N, in_channels, H, W) - confirm
                against the caller.
            embedding: Conditioning tensor; cast to float32 before use.
                NOTE(review): presumably (N, embedding_dim), so that the
                two unsqueezes below yield (N, in_channels, 1, 1).

        Returns:
            Tensor produced by broadcasting the per-channel scale and shift
            over the normalised feature map, followed by LeakyReLU.
        """
        x = self.norm(self.conv(x))

        # Cast once and reuse the result for both heads.
        embedding = embedding.float()

        # The head predicts log(gamma); exp keeps gamma positive, so the
        # effective multiplier (1 + gamma) never drops below 1.
        gamma = self.gamma_function(embedding).exp().unsqueeze(2).unsqueeze(3)
        beta = self.beta_function(embedding).unsqueeze(2).unsqueeze(3)

        return self.leakyrelu(x * (1 + gamma) + beta)
class SubTextMapper(Module):
    """Normalise a feature map and run it through text modulation blocks.

    The input first goes through ``PixelNorm`` (imported from
    ``models.stylegan2.model``; its exact behaviour is not visible here) and
    then through every ``TextModulationModule`` held in
    ``modulation_module_list`` - currently exactly one.

    Args:
        opts: Options object; stored on the instance, not read by this class.
        in_channels: Channel count handed to each ``TextModulationModule``.
    """

    def __init__(self, opts, in_channels):
        super().__init__()
        self.opts = opts
        self.pixelnorm = PixelNorm()
        # A single modulation block, kept in a ModuleList so it is registered
        # under the key "modulation_module_list.0".
        blocks = [TextModulationModule(in_channels)]
        self.modulation_module_list = nn.ModuleList(blocks)

    def forward(self, x, embedding):
        """Return ``x`` after pixel normalisation and sequential modulation.

        Args:
            x: Feature map passed to ``PixelNorm`` and then to each block.
            embedding: Conditioning tensor forwarded unchanged to each block.
        """
        hidden = self.pixelnorm(x)
        for block in self.modulation_module_list:
            hidden = block(hidden, embedding)
        return hidden
class CLIPAdapter(Module):
    """Apply optional text-conditioned mappers to three feature levels.

    Up to three ``SubTextMapper`` instances are created, one per level,
    unless the matching ``opts.no_*_mapper`` flag is truthy:

    * ``coarse_mapping`` - 512 channels, applied to the third feature
    * ``medium_mapping`` - 256 channels, applied to the second feature
    * ``fine_mapping``   - 128 channels, applied to the first feature

    Args:
        opts: Options object exposing ``no_coarse_mapper``,
            ``no_medium_mapper`` and ``no_fine_mapper``. It is stored and
            consulted again on every forward pass.
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts

        # (disable flag on opts, attribute to register, channel count)
        mapper_specs = (
            ("no_coarse_mapper", "coarse_mapping", 512),
            ("no_medium_mapper", "medium_mapping", 256),
            ("no_fine_mapper", "fine_mapping", 128),
        )
        for disable_flag, attr_name, channels in mapper_specs:
            if getattr(opts, disable_flag):
                continue
            setattr(self, attr_name, SubTextMapper(opts, channels))

    def forward(self, features, txt_embed):
        """Return the three features, each mapped unless its level is disabled.

        Args:
            features: Iterable of exactly three tensors, ordered
                (fine, medium, coarse) as consumed by the mappers above.
            txt_embed: Text embedding; detached so no gradient flows back
                into whatever produced it.

        Returns:
            A 3-tuple in the same order as ``features``. Entries whose mapper
            is disabled are passed through untouched.
        """
        text = txt_embed.detach()
        fine, medium, coarse = features
        opts = self.opts

        if not opts.no_coarse_mapper:
            coarse = self.coarse_mapping(coarse, text)
        if not opts.no_medium_mapper:
            medium = self.medium_mapping(medium, text)
        if not opts.no_fine_mapper:
            fine = self.fine_mapping(fine, text)

        return fine, medium, coarse