Spaces:
Paused
Paused
from torch import nn | |
from .encoder import Encoder | |
from .styledecoder import Synthesis | |
class Generator(nn.Module):
    """Image generator composed of an ``Encoder`` and a ``Synthesis`` decoder.

    ``forward`` encodes a (source, driving) image pair with ``self.enc``,
    which returns three values, and passes those three values to ``self.dec``
    to produce the output image.
    """

    def __init__(self, size, style_dim=512, motion_dim=20, channel_multiplier=1, blur_kernel=None):
        """Build the encoder and decoder sub-modules.

        Args:
            size: Forwarded unchanged to both ``Encoder`` and ``Synthesis``
                (presumably the image resolution -- confirm against those classes).
            style_dim: Forwarded to both ``Encoder`` and ``Synthesis``.
            motion_dim: Forwarded to both ``Encoder`` and ``Synthesis``.
            channel_multiplier: Forwarded to ``Synthesis`` only.
            blur_kernel: Forwarded to ``Synthesis``. ``None`` (the default)
                means ``[1, 3, 3, 1]``. A new list is created on every call,
                so instances never share a single mutable default object.
        """
        super().__init__()
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.enc = Encoder(size, style_dim, motion_dim)
        # NOTE: Synthesis takes blur_kernel *before* channel_multiplier,
        # the reverse of this constructor's parameter order.
        self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)

    def get_direction(self):
        """Return ``self.dec.direction(None)``.

        NOTE(review): presumably the decoder's learned motion-direction
        basis; confirm against ``Synthesis.direction``.
        """
        return self.dec.direction(None)

    def synthesis(self, wa, alpha, feat):
        """Decode the encoder outputs ``(wa, alpha, feat)`` into an image via ``self.dec``."""
        img = self.dec(wa, alpha, feat)
        return img

    def forward(self, img_source, img_drive, h_start=None):
        """Reconstruct an image from a source/driving pair.

        Args:
            img_source: Passed as the first argument to ``self.enc``.
            img_drive: Passed as the second argument to ``self.enc``.
            h_start: Optional third argument to ``self.enc`` (default ``None``).

        Returns:
            Whatever ``self.dec(wa, alpha, feats)`` returns, where
            ``(wa, alpha, feats)`` is the 3-tuple produced by the encoder.
        """
        wa, alpha, feats = self.enc(img_source, img_drive, h_start)
        img_recon = self.dec(wa, alpha, feats)
        return img_recon