# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger
class StyleTextRec(nn.Layer):
    """Top-level StyleText predictor.

    Composes three sub-generators — text, background, fusion — and loads a
    pretrained dygraph checkpoint into each one at construction time.

    Args:
        config (dict): Full predictor configuration. ``config["Predictor"]``
            must contain "text_generator", "bg_generator" and
            "fusion_generator" sub-configs, each with a "pretrain" entry
            giving its checkpoint path.
    """

    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        predictor_config = config["Predictor"]
        self.text_generator = TextGenerator(predictor_config["text_generator"])
        self.bg_generator = BgGeneratorWithMask(predictor_config[
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(predictor_config[
            "fusion_generator"])
        # Load the pretrained weights for each sub-generator from its own
        # checkpoint path; one loop replaces three near-identical calls.
        for module, key in ((self.bg_generator, "bg_generator"),
                            (self.text_generator, "text_generator"),
                            (self.fusion_generator, "fusion_generator")):
            load_dygraph_pretrain(
                module,
                self.logger,
                path=predictor_config[key]["pretrain"],
                load_static_weights=False)

    def forward(self, style_input, text_input):
        """Run the full pipeline: generate text, background, then fuse them.

        Args:
            style_input: style image tensor.
            text_input: rendered standard-text image tensor.

        Returns:
            dict with keys "fake_fusion", "fake_text", "fake_sk", "fake_bg".
        """
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]
        # Only the reconstructed background is needed here; the intermediate
        # encoder/decoder features the bg generator also returns are unused
        # in this forward pass.
        fake_bg = self.bg_generator.forward(style_input)["fake_bg"]
        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        fake_fusion = fusion_gen_output["fake_fusion"]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }
class TextGenerator(nn.Layer):
    """Generates the foreground text image and its skeleton map.

    Two encoders (one for the style image, one for the rendered text image)
    feed a pair of decoders; a small middle net merges the decoded text with
    the skeleton into the final 3-channel text image.
    """

    def __init__(self, config):
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        # Convolutions carry a bias only when the norm layer is InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"
        # Keyword arguments shared by every encoder/decoder below.
        shared = dict(
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=config["conv_block_dropout"],
            conv_block_num=config["conv_block_num"],
            conv_block_dilation=config["conv_block_dilation"])
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            **shared)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            **shared)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=int(encode_dim / 2),
            out_conv_act="Tanh",
            out_conv_act_attr=None,
            **shared)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None,
            **shared)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=int(encode_dim / 2) + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        """Return {"fake_sk": skeleton map, "fake_text": text image}."""
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward(
            [text_feature, style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward(
            [text_feature, style_feature])["out_conv"]
        merged = paddle.concat((fake_c_temp, fake_sk), axis=1)
        fake_text = self.middle(merged)
        return {"fake_sk": fake_sk, "fake_text": fake_text}
class BgGeneratorWithMask(nn.Layer):
    """Generates a background image plus an auxiliary single-channel mask.

    An encoder produces multi-scale features; a U-Net-style single decoder
    reconstructs a raw background, a second decoder predicts a mask, and a
    middle net refines the concatenation of the two into the final output.
    """

    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        # Optional output scale; defaults to 1.0 and is only stored here
        # (not referenced elsewhere in this class).
        self.output_factor = config.get("output_factor", 1.0)
        # Convolutions carry a bias only when the norm layer is InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"
        # Keyword arguments shared by the encoder and both decoders.
        shared = dict(
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=config["conv_block_dropout"],
            conv_block_num=config["conv_block_num"],
            conv_block_dilation=config["conv_block_dilation"])
        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            **shared)
        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            out_conv_act="Tanh",
            out_conv_act_attr=None,
            **shared)
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None,
            **shared)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        """Return encode/decode features, the background, and its mask."""
        encoded = self.encoder_bg(style_input)
        # The single decoder consumes skip features from two encoder levels.
        decoded = self.decoder_bg(encoded["res_blocks"], encoded["down2"],
                                  encoded["down1"])
        raw_bg = decoded["out_conv"]
        mask = self.decoder_mask.forward(encoded["res_blocks"])["out_conv"]
        refined_bg = self.middle(paddle.concat((raw_bg, mask), axis=1))
        return {
            "bg_encode_feature": encoded["res_blocks"],
            "bg_decode_feature1": decoded["up1"],
            "bg_decode_feature2": decoded["up2"],
            "fake_bg": refined_bg,
            "fake_bg_mask": mask,
        }
class FusionGeneratorSimple(nn.Layer):
    """Fuses a generated text image with a generated background image.

    A small conv -> residual block -> conv pipeline: the two 3-channel
    inputs are concatenated to 6 channels, projected to ``encode_dim``,
    refined by one ResBlock, and reduced back to 3 channels.
    """

    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        # Convolutions carry a bias only when the norm layer is InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"
        self._conv = nn.Conv2D(
            in_channels=6,  # fake_text and fake_bg concatenated on channels
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)
        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=config["conv_block_dropout"],
            use_dilation=config["conv_block_dilation"],
            use_bias=use_bias)
        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        """Return {"fake_fusion": fused 3-channel image}."""
        merged = paddle.concat((fake_text, fake_bg), axis=1)
        hidden = self._conv(merged)
        hidden = self._res_block(hidden)
        return {"fake_fusion": self._reduce_conv(hidden)}