|
import tempfile |
|
import torch |
|
import yaml |
|
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator |
|
from basicsr.data.paired_image_dataset import PairedImageDataset |
|
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss |
|
|
|
from gfpgan.archs.arcface_arch import ResNetArcFace |
|
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1 |
|
from gfpgan.models.gfpgan_model import GFPGANModel |
|
|
|
|
|
# Keys that one GFPGANModel.optimize_parameters() step is expected to record in
# ``model.log_dict`` for the config in tests/data/test_gfpgan_model.yml.
_EXPECTED_LOG_KEYS = frozenset({
    'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
    'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
    'l_d_right_eye', 'l_d_mouth'
})


def _check_optimize_step(model, data, current_iter):
    """Feed ``data``, run one optimize step at ``current_iter``, and check output and logged losses."""
    model.feed_data(data)
    model.optimize_parameters(current_iter)
    assert model.output.shape == (1, 3, 512, 512)
    assert isinstance(model.log_dict, dict)
    missing = _EXPECTED_LOG_KEYS.difference(model.log_dict)
    assert not missing, f'missing log_dict keys: {sorted(missing)}'


def _assert_psnr_recorded(model):
    """Check that the last validation run stored a float PSNR in ``model.metric_results``."""
    assert 'psnr' in model.metric_results
    assert isinstance(model.metric_results['psnr'], float)


def test_gfpgan_model():
    """Smoke-test GFPGANModel end to end: build, train steps, save, test() and validation."""
    with open('tests/data/test_gfpgan_model.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # ----------------- build model and check its components -------------------- #
    model = GFPGANModel(opt)
    assert model.__class__.__name__ == 'GFPGANModel'
    assert isinstance(model.net_g, GFPGANv1)
    assert isinstance(model.net_d, StyleGAN2Discriminator)
    # facial component discriminators
    assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
    # identity network
    assert isinstance(model.network_identity, ResNetArcFace)
    # losses
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.cri_l1, L1Loss)
    # optimizers
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # ----------------- feed_data -------------------- #
    gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_mouth = torch.rand((1, 4), dtype=torch.float32)
    data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
    model.feed_data(data)
    # note the plural attribute names the model stores the locations under
    assert model.lq.shape == (1, 3, 512, 512)
    assert model.gt.shape == (1, 3, 512, 512)
    assert model.loc_left_eyes.shape == (1, 4)
    assert model.loc_right_eyes.shape == (1, 4)
    assert model.loc_mouths.shape == (1, 4)

    # ----------------- optimize_parameters -------------------- #
    _check_optimize_step(model, data, current_iter=1)
    # A very late iteration; presumably meant to cross an iteration threshold in
    # the yml config (e.g. a loss switched off after N iters) -- confirm there.
    _check_optimize_step(model, data, current_iter=100000)

    # ----------------- save -------------------- #
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['models'] = tmpdir
        model.opt['path']['training_states'] = tmpdir
        model.save(0, 1)

    # ----------------- test() -------------------- #
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    # Run test() again without the EMA generator (presumably exercising the
    # net_g path); net_g must be left in training mode afterwards.
    del model.net_g_ema
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    assert model.net_g.training is True

    # ----------------- nondist_validation -------------------- #
    # lq and gt both read from tests/data/gt; only a paired val dataset is needed here.
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/gt',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)

    assert model.is_train is True
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        # validation must not leave the model out of training mode
        assert model.is_train is True
        _assert_psnr_recorded(model)

    # Same validation with opt['is_train'] disabled, a custom suffix and a progress bar,
    # then with no suffix. Both runs save images under the same live temp dir.
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['is_train'] = False
        model.opt['val']['suffix'] = 'test'
        model.opt['path']['visualization'] = tmpdir
        model.opt['val']['pbar'] = True
        model.nondist_validation(dataloader, 1, None, save_img=True)
        _assert_psnr_recorded(model)

        model.opt['val']['suffix'] = None
        model.opt['name'] = 'demo'
        model.nondist_validation(dataloader, 1, None, save_img=True)
        _assert_psnr_recorded(model)
|
|