import numpy as np
from typing import Optional
import torch
from torch import nn
from torch_utils import persistence
from segmentation_models_pytorch.base import SegmentationModel
from segmentation_models_pytorch.encoders.resnet import resnet_encoders
from models.deeplabv3.decoder import DeepLabV3Decoder
from models.unet.openaimodel import Upsample, Downsample
from models.mix_transformer.mix_transformer import OverlapPatchEmbed, Block, BlockCross
from timm.models.layers import trunc_normal_
encoders = {}
encoders.update(resnet_encoders)
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, norm_layer=None, **kwargs):
    try:
        Encoder = encoders[name]["encoder"]
    except KeyError:
        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
    # Copy the registry params so repeated calls do not mutate the shared dict.
    params = dict(encoders[name]["params"])
    params.update(depth=depth)
    params.update(norm_layer=norm_layer)
    encoder = Encoder(**params)
    if weights is not None:
        # `weights` is a local checkpoint path; it is loaded directly rather
        # than resolved through the pretrained-settings registry.
        print(f"Loading encoder weights from {weights}")
        state_dict = torch.load(weights, map_location="cpu")
        encoder.load_state_dict(state_dict, strict=False)
    # Adapt the first convolution to `in_channels`, reusing the existing RGB
    # filters when patching the input layer.
    encoder.set_in_channels(in_channels, pretrained=True)
    if output_stride != 32:
        encoder.make_dilated(output_stride)
    return encoder
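# A minimal usage sketch (the 5-channel input and shapes are assumptions drawn
# from this file's defaults): build the BN-free ResNet-34 backbone used below
# and run a coordinate-augmented input through it.
def _example_get_encoder():
    enc = get_encoder("resnet34", in_channels=5, depth=5, weights=None,
                      output_stride=8, norm_layer=nn.Identity)
    feats = enc(torch.randn(1, 5, 256, 256))  # one feature map per stage
    return [f.shape for f in feats]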
def AddCoord(im):
    """Append normalized (x, y) coordinate channels in [-1, 1] (CoordConv-style)."""
B, C, H, W = im.shape
y, x = torch.meshgrid(torch.linspace(-1, 1, H, dtype=torch.float32, device=im.device),
torch.linspace(-1, 1, W, dtype=torch.float32, device=im.device), indexing='ij')
xy = torch.stack([x, y], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # (B,H,W,2)
xy = xy.permute(0, 3, 1, 2) # (B,2,H,W)
ret = torch.cat([im, xy], dim=1)
return ret
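# Shape sketch: AddCoord turns a (B, 3, H, W) image into (B, 5, H, W), which is
# why the encoders in this file default to in_channels=5.
def _example_add_coord():
    out = AddCoord(torch.randn(2, 3, 64, 64))
    assert out.shape == (2, 5, 64, 64)
    return out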
# Global appearance encoder. Following Live3DPortrait (https://arxiv.org/abs/2305.02310),
# all BN layers are removed, the input dimension is changed to 5, and the segmentation head is dropped.
@persistence.persistent_class
class EncoderGlobal(SegmentationModel):
"""DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"
Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
to extract features of different spatial resolution
encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_channels: A number of convolution filters in ASPP module. Default is 256
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
Returns:
``torch.nn.Module``: **DeepLabV3**
.. _DeeplabV3:
https://arxiv.org/abs/1706.05587
"""
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_channels: int = 256,
in_channels: int = 5,
activation: Optional[str] = None,
aux_params: Optional[dict] = None,
norm_layer: nn.Module = nn.Identity
):
super().__init__()
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
norm_layer=norm_layer
)
self.decoder = DeepLabV3Decoder(
in_channels=self.encoder.out_channels[-1],
out_channels=decoder_channels,
)
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
x = AddCoord(x)
# self.check_input_shape(x)
features = self.encoder(x)
decoder_output = self.decoder(*features)
return decoder_output
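# Sketch (the input size is an assumption): forward adds the coordinate
# channels itself, so the input is plain RGB. With output_stride=8 a 512x512
# image yields a (B, 256, 64, 64) feature map, matching the `x` input of
# EncoderCanonical below.
def _example_encoder_global():
    enc = EncoderGlobal(encoder_weights=None)  # avoid torch.load("imagenet")
    return enc(torch.randn(1, 3, 512, 512))  # (1, 256, 64, 64)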
# Detail appearance encoder
@persistence.persistent_class
class EncoderDetail(nn.Module):
def __init__(
self,
in_channels: int = 5
):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
)
for m in self.encoder.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
def forward(self, x):
x = AddCoord(x)
output = self.encoder(x)
return output
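# Sketch: only the first convolution strides, so the detail encoder returns a
# 96-channel map at half the input resolution (a 512x512 input is assumed here).
def _example_encoder_detail():
    enc = EncoderDetail()
    return enc(torch.randn(1, 3, 512, 512))  # (1, 96, 256, 256)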
# Canonicalization and reenactment module
@persistence.persistent_class
class EncoderCanonical(nn.Module):
def __init__(self, img_size=64, patch_size=3, in_chans=512, embed_dims=1024, mot_dims=512, mot_dims_hidden=512,
H_y=8, W_y=8, num_heads=4, mlp_ratios=2, qkv_bias=True, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., num_blocks_neutral=3, num_blocks_motion=3,
norm_layer=nn.LayerNorm,
sr_ratios=1, mapping_layers=0):
super().__init__()
self.num_blocks_neutral = num_blocks_neutral
self.num_blocks_motion = num_blocks_motion
self.mapping_layers = mapping_layers
self.H_y = H_y
self.W_y = W_y
# mapping net for motion feature
if mapping_layers > 0:
self.maps = nn.ModuleList([])
for i in range(mapping_layers):
in_dims = mot_dims if i == 0 else mot_dims_hidden
self.maps.append(nn.Linear(in_dims, mot_dims_hidden, bias=True))
self.maps.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
self.maps_neutral = nn.ModuleList([])
for i in range(mapping_layers):
in_dims = mot_dims if i == 0 else mot_dims_hidden
self.maps_neutral.append(nn.Linear(in_dims, mot_dims_hidden, bias=True))
self.maps_neutral.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
else:
self.maps = None
self.maps_neutral = None
mot_dims_hidden = mot_dims
self.proj_y_neutral = nn.Linear(mot_dims_hidden, H_y * W_y * embed_dims, bias=True)
self.proj_y = nn.Linear(mot_dims_hidden, H_y * W_y * embed_dims, bias=True)
# patch_embed
self.patch_embed = OverlapPatchEmbed(img_size=img_size, patch_size=patch_size, stride=2, in_chans=in_chans,
embed_dim=embed_dims)
# canonicalization blocks
self.trans_blocks_neutral = nn.ModuleList([BlockCross(
dim=embed_dims, dim_y=mot_dims_hidden, H_y=H_y, W_y=W_y, num_heads=num_heads, mlp_ratio=mlp_ratios,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer,
sr_ratio=sr_ratios)
for i in range(num_blocks_neutral)])
# reenactment blocks
self.trans_blocks_motion = nn.ModuleList([BlockCross(
dim=embed_dims, dim_y=mot_dims_hidden, H_y=H_y, W_y=W_y, num_heads=num_heads, mlp_ratio=mlp_ratios,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer,
sr_ratio=sr_ratios)
for i in range(num_blocks_motion)])
self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
self.convs_1 = nn.Sequential(
nn.Conv2d(3 * 128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(96, 256, kernel_size=3, stride=1, padding=1)
)
self.convs = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, np.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def freeze_patch_emb(self):
self.patch_embed.requires_grad = False
@torch.jit.ignore
def no_weight_decay(self):
        return {'pos_embed'}  # exclude positional embeddings from weight decay
def forward_features(self, x, x_r, y, y_x, scale=1.):
B = x.shape[0]
outs = []
if self.maps is not None:
for layer in self.maps:
y = layer(y)
if self.maps_neutral is not None:
for layer in self.maps_neutral:
y_x = layer(y_x)
y = self.proj_y(y).reshape(B, self.H_y * self.W_y, -1)
y_x = self.proj_y_neutral(y_x).reshape(B, self.H_y * self.W_y, -1)
        # fuse appearance and auxiliary features, then patch-embed for the blocks
        x_r = self.convs_1(x_r)
        x = torch.cat([x, x_r], dim=1)
        x, H, W = self.patch_embed(x)
# neutralize the face
for i, blk in enumerate(self.trans_blocks_neutral):
x = blk(x, y_x, H, W, scale=scale)
# animate the face
for i, blk in enumerate(self.trans_blocks_motion):
x = blk(x, y, H, W, scale=scale)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
x = self.pixelshuffle(x)
outs.append(x)
# conv blocks
        x = self.convs(x)  # (B, 96, 256, 256) for the default 64x64 input
outs.append(x)
return outs
def forward(self, x, x_r, y, y_x, scale=1.):
x = self.forward_features(x, x_r, y, y_x, scale=scale)
return x
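# Shape sketch for the canonicalization/reenactment module. The channel counts
# are inferred from the defaults (patch_embed takes 512 = 256 + 256 channels,
# convs_1 takes 3 x 128); what each tensor carries in the real pipeline is an
# assumption.
def _example_encoder_canonical():
    net = EncoderCanonical()
    x = torch.randn(1, 256, 64, 64)    # global appearance features
    x_r = torch.randn(1, 384, 64, 64)  # auxiliary features (3 x 128 channels)
    y = torch.randn(1, 512)            # driving motion code
    y_x = torch.randn(1, 512)          # source motion code (to neutralize)
    coarse, fine = net(x, x_r, y, y_x)
    return coarse.shape, fine.shape    # (1, 256, 64, 64), (1, 96, 256, 256)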
# Triplane decoder
@persistence.persistent_class
class DecoderTriplane(nn.Module):
def __init__(self, img_size=256, patch_size=3, embed_dims=1024,
num_heads=2, mlp_ratios=2, qkv_bias=True, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., num_blocks=1, norm_layer=nn.LayerNorm,
sr_ratios=2):
super().__init__()
self.num_blocks = num_blocks
self.convs1 = nn.Sequential(
nn.Conv2d(192, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
)
# patch_embed
self.patch_embed = OverlapPatchEmbed(img_size=img_size, patch_size=patch_size, stride=2, in_chans=128,
embed_dim=embed_dims)
# transformer encoder
self.trans_blocks = nn.ModuleList([Block(
dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer,
sr_ratio=sr_ratios)
for i in range(num_blocks)])
self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
self.convs2 = nn.Sequential(
nn.Conv2d(352, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
)
self.conv_last = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)
self.convs3 = nn.Sequential(
nn.Conv2d(192, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
)
self.conv_last_2 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)
self.apply(self._init_weights)
self.conv_last.apply(self._init_weights_last)
def _init_weights_last(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.15)
if m.bias is not None:
m.bias.data.zero_()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, np.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def freeze_patch_emb(self):
self.patch_embed.requires_grad = False
@torch.jit.ignore
def no_weight_decay(self):
        return {'pos_embed'}  # exclude positional embeddings from weight decay
def forward(self, x_global, x_detail, x_tri):
x = torch.cat([x_global, x_detail], dim=1) # [B,C,H,W]
B = x.shape[0]
# convs1
x = self.convs1(x)
# trans blocks
x, H, W = self.patch_embed(x)
for i, blk in enumerate(self.trans_blocks):
x = blk(x, H, W)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
x = self.pixelshuffle(x)
x = torch.cat([x, x_global], dim=1)
# convs2
x = self.convs2(x)
x = self.conv_last(x)
# conv_final
x = torch.cat([x, x_tri], dim=1)
x = self.convs3(x)
x = self.conv_last_2(x)
return x
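# Shape sketch for the triplane decoder. Channel counts follow the conv
# definitions above (convs1 expects 96 + 96 = 192 input channels, convs3
# expects a 96-channel x_tri); the semantics of each input are assumptions.
def _example_decoder_triplane():
    dec = DecoderTriplane()
    x_global = torch.randn(1, 96, 256, 256)  # reenacted appearance features
    x_detail = torch.randn(1, 96, 256, 256)  # detail-encoder features
    x_tri = torch.randn(1, 96, 256, 256)     # reference triplane features
    tri = dec(x_global, x_detail, x_tri)
    return tri.shape  # (1, 96, 256, 256), i.e. 3 planes x 32 channels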
@persistence.persistent_class
class EncoderBG(nn.Module):
def __init__(
self,
in_channels=5,
out_channels=32,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.input_conv = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
) # 128
self.down_conv_1 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Downsample(64, True, dims=2, out_channels=128)
) # 64
self.down_conv_2 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Downsample(128, True, dims=2, out_channels=256)
) # 32
self.down_conv_3 = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Downsample(256, True, dims=2, out_channels=512)
) # 16
self.middle = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
)
self.up_conv_1 = nn.Sequential(
nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Upsample(256, True, dims=2, out_channels=256)
) # 32
self.up_conv_2 = nn.Sequential(
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Upsample(128, True, dims=2, out_channels=128)
) # 64
self.up_conv_3 = nn.Sequential(
nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Upsample(64, True, dims=2, out_channels=64)
) # 128
self.up_conv_4 = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
Upsample(64, True, dims=2, out_channels=64)
) # 256
self.out_conv = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, np.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = AddCoord(x)
h1 = self.input_conv(x)
h2 = self.down_conv_1(h1)
h3 = self.down_conv_2(h2)
h4 = self.down_conv_3(h3)
h = self.middle(h4)
x = self.up_conv_1(torch.cat([h, h4], 1))
x = self.up_conv_2(torch.cat([x, h3], 1))
x = self.up_conv_3(torch.cat([x, h2], 1))
x = self.up_conv_4(torch.cat([x, h1], 1))
output = self.out_conv(x)
return output
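# Sketch of the UNet-style background encoder: four stride-2 downsamples with
# skip connections back up, so the output keeps the input resolution
# (256x256 assumed, matching the resolution comments above).
def _example_encoder_bg():
    bg = EncoderBG()
    return bg(torch.randn(1, 3, 256, 256)).shape  # (1, 32, 256, 256)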