# CLIPInverter / adapter / clipadapter.py
# Author: Canberk Baykal
# (Hugging Face viewer metadata: app.py, commit b5ed368, 2.41 kB)
from torch import nn
from models.stylegan2.model import PixelNorm
from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential, Module, Conv2d, GroupNorm
class TextModulationModule(Module):
    """FiLM-style modulation of a conv feature map by a 512-d text embedding.

    The feature map is passed through a 3x3 conv + GroupNorm, then scaled and
    shifted channel-wise by factors predicted from the embedding:
    out = LeakyReLU(h * (1 + exp(log_gamma)) + beta).
    """

    def __init__(self, in_channels):
        super(TextModulationModule, self).__init__()

        # Small MLP mapping the 512-d embedding to a per-channel vector.
        def mlp(out_dim):
            return Sequential(Linear(512, 512), LayerNorm([512]), LeakyReLU(), Linear(512, out_dim))

        self.conv = Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=False)
        self.norm = GroupNorm(32, in_channels)
        self.gamma_function = mlp(in_channels)  # predicts log of the scale
        self.beta_function = mlp(in_channels)   # predicts the shift
        self.leakyrelu = LeakyReLU()

    def forward(self, x, embedding):
        """Return the text-modulated feature map (same shape as ``x``)."""
        feats = self.norm(self.conv(x))
        emb = embedding.float()
        # Broadcast (B, C) -> (B, C, 1, 1) over spatial dims.
        scale = self.gamma_function(emb).exp()[..., None, None]
        shift = self.beta_function(emb)[..., None, None]
        return self.leakyrelu(feats * (1 + scale) + shift)
class SubTextMapper(Module):
    """Pixel-normalizes a feature map, then applies a chain of text
    modulation modules (currently a single ``TextModulationModule``)."""

    def __init__(self, opts, in_channels):
        super(SubTextMapper, self).__init__()
        self.opts = opts
        self.pixelnorm = PixelNorm()
        # Kept as a ModuleList so the chain length can be extended later.
        self.modulation_module_list = nn.ModuleList([TextModulationModule(in_channels)])

    def forward(self, x, embedding):
        """Run ``x`` through pixel-norm and every modulation module in order."""
        out = self.pixelnorm(x)
        for stage in self.modulation_module_list:
            out = stage(out, embedding)
        return out
class CLIPAdapter(Module):
    """Adapts a triple of encoder feature maps with a CLIP text embedding.

    One ``SubTextMapper`` per resolution level (coarse/medium/fine); each can
    be disabled individually via the corresponding ``opts.no_*_mapper`` flag,
    in which case that level's features pass through unchanged.
    """

    # (opts flag, attribute name, channel width) per resolution level.
    _LEVELS = (
        ("no_coarse_mapper", "coarse_mapping", 512),
        ("no_medium_mapper", "medium_mapping", 256),
        ("no_fine_mapper", "fine_mapping", 128),
    )

    def __init__(self, opts):
        super(CLIPAdapter, self).__init__()
        self.opts = opts
        for flag, attr, width in self._LEVELS:
            if not getattr(opts, flag):
                setattr(self, attr, SubTextMapper(opts, width))

    def forward(self, features, txt_embed):
        """Return (c1, c2, c3) with each enabled level modulated by the text.

        The embedding is detached so gradients never flow into the text encoder.
        """
        txt_embed = txt_embed.detach()
        c1, c2, c3 = features
        if not self.opts.no_fine_mapper:
            c1 = self.fine_mapping(c1, txt_embed)
        if not self.opts.no_medium_mapper:
            c2 = self.medium_mapping(c2, txt_embed)
        if not self.opts.no_coarse_mapper:
            c3 = self.coarse_mapping(c3, txt_embed)
        return (c1, c2, c3)