Spaces:
No application file
No application file
culture
commited on
Commit
·
a7622b9
1
Parent(s):
58b9e16
Upload tests/test_gfpgan_model.py
Browse files- tests/test_gfpgan_model.py +132 -0
tests/test_gfpgan_model.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import torch
|
3 |
+
import yaml
|
4 |
+
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
|
5 |
+
from basicsr.data.paired_image_dataset import PairedImageDataset
|
6 |
+
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
|
7 |
+
|
8 |
+
from gfpgan.archs.arcface_arch import ResNetArcFace
|
9 |
+
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
|
10 |
+
from gfpgan.models.gfpgan_model import GFPGANModel
|
11 |
+
|
12 |
+
|
13 |
+
def test_gfpgan_model():
|
14 |
+
with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
|
15 |
+
opt = yaml.load(f, Loader=yaml.FullLoader)
|
16 |
+
|
17 |
+
# build model
|
18 |
+
model = GFPGANModel(opt)
|
19 |
+
# test attributes
|
20 |
+
assert model.__class__.__name__ == 'GFPGANModel'
|
21 |
+
assert isinstance(model.net_g, GFPGANv1) # generator
|
22 |
+
assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
|
23 |
+
# facial component discriminators
|
24 |
+
assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
|
25 |
+
assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
|
26 |
+
assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
|
27 |
+
# identity network
|
28 |
+
assert isinstance(model.network_identity, ResNetArcFace)
|
29 |
+
# losses
|
30 |
+
assert isinstance(model.cri_pix, L1Loss)
|
31 |
+
assert isinstance(model.cri_perceptual, PerceptualLoss)
|
32 |
+
assert isinstance(model.cri_gan, GANLoss)
|
33 |
+
assert isinstance(model.cri_l1, L1Loss)
|
34 |
+
# optimizer
|
35 |
+
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
36 |
+
assert isinstance(model.optimizers[1], torch.optim.Adam)
|
37 |
+
|
38 |
+
# prepare data
|
39 |
+
gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
40 |
+
lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
41 |
+
loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
|
42 |
+
loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
|
43 |
+
loc_mouth = torch.rand((1, 4), dtype=torch.float32)
|
44 |
+
data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
|
45 |
+
model.feed_data(data)
|
46 |
+
# check data shape
|
47 |
+
assert model.lq.shape == (1, 3, 512, 512)
|
48 |
+
assert model.gt.shape == (1, 3, 512, 512)
|
49 |
+
assert model.loc_left_eyes.shape == (1, 4)
|
50 |
+
assert model.loc_right_eyes.shape == (1, 4)
|
51 |
+
assert model.loc_mouths.shape == (1, 4)
|
52 |
+
|
53 |
+
# ----------------- test optimize_parameters -------------------- #
|
54 |
+
model.feed_data(data)
|
55 |
+
model.optimize_parameters(1)
|
56 |
+
assert model.output.shape == (1, 3, 512, 512)
|
57 |
+
assert isinstance(model.log_dict, dict)
|
58 |
+
# check returned keys
|
59 |
+
expected_keys = [
|
60 |
+
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
61 |
+
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
62 |
+
'l_d_right_eye', 'l_d_mouth'
|
63 |
+
]
|
64 |
+
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
65 |
+
|
66 |
+
# ----------------- remove pyramid_loss_weight-------------------- #
|
67 |
+
model.feed_data(data)
|
68 |
+
model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
|
69 |
+
assert model.output.shape == (1, 3, 512, 512)
|
70 |
+
assert isinstance(model.log_dict, dict)
|
71 |
+
# check returned keys
|
72 |
+
expected_keys = [
|
73 |
+
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
74 |
+
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
75 |
+
'l_d_right_eye', 'l_d_mouth'
|
76 |
+
]
|
77 |
+
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
78 |
+
|
79 |
+
# ----------------- test save -------------------- #
|
80 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
81 |
+
model.opt['path']['models'] = tmpdir
|
82 |
+
model.opt['path']['training_states'] = tmpdir
|
83 |
+
model.save(0, 1)
|
84 |
+
|
85 |
+
# ----------------- test the test function -------------------- #
|
86 |
+
model.test()
|
87 |
+
assert model.output.shape == (1, 3, 512, 512)
|
88 |
+
# delete net_g_ema
|
89 |
+
model.__delattr__('net_g_ema')
|
90 |
+
model.test()
|
91 |
+
assert model.output.shape == (1, 3, 512, 512)
|
92 |
+
assert model.net_g.training is True # should back to training mode after testing
|
93 |
+
|
94 |
+
# ----------------- test nondist_validation -------------------- #
|
95 |
+
# construct dataloader
|
96 |
+
dataset_opt = dict(
|
97 |
+
name='Demo',
|
98 |
+
dataroot_gt='tests/data/gt',
|
99 |
+
dataroot_lq='tests/data/gt',
|
100 |
+
io_backend=dict(type='disk'),
|
101 |
+
scale=4,
|
102 |
+
phase='val')
|
103 |
+
dataset = PairedImageDataset(dataset_opt)
|
104 |
+
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
105 |
+
assert model.is_train is True
|
106 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
107 |
+
model.opt['path']['visualization'] = tmpdir
|
108 |
+
model.nondist_validation(dataloader, 1, None, save_img=True)
|
109 |
+
assert model.is_train is True
|
110 |
+
# check metric_results
|
111 |
+
assert 'psnr' in model.metric_results
|
112 |
+
assert isinstance(model.metric_results['psnr'], float)
|
113 |
+
|
114 |
+
# validation
|
115 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
116 |
+
model.opt['is_train'] = False
|
117 |
+
model.opt['val']['suffix'] = 'test'
|
118 |
+
model.opt['path']['visualization'] = tmpdir
|
119 |
+
model.opt['val']['pbar'] = True
|
120 |
+
model.nondist_validation(dataloader, 1, None, save_img=True)
|
121 |
+
# check metric_results
|
122 |
+
assert 'psnr' in model.metric_results
|
123 |
+
assert isinstance(model.metric_results['psnr'], float)
|
124 |
+
|
125 |
+
# if opt['val']['suffix'] is None
|
126 |
+
model.opt['val']['suffix'] = None
|
127 |
+
model.opt['name'] = 'demo'
|
128 |
+
model.opt['path']['visualization'] = tmpdir
|
129 |
+
model.nondist_validation(dataloader, 1, None, save_img=True)
|
130 |
+
# check metric_results
|
131 |
+
assert 'psnr' in model.metric_results
|
132 |
+
assert isinstance(model.metric_results['psnr'], float)
|