culture commited on
Commit
a7622b9
·
1 Parent(s): 58b9e16

Upload tests/test_gfpgan_model.py

Browse files
Files changed (1) hide show
  1. tests/test_gfpgan_model.py +132 -0
tests/test_gfpgan_model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import torch
3
+ import yaml
4
+ from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
5
+ from basicsr.data.paired_image_dataset import PairedImageDataset
6
+ from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
7
+
8
+ from gfpgan.archs.arcface_arch import ResNetArcFace
9
+ from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
10
+ from gfpgan.models.gfpgan_model import GFPGANModel
11
+
12
+
13
+ def test_gfpgan_model():
14
+ with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
15
+ opt = yaml.load(f, Loader=yaml.FullLoader)
16
+
17
+ # build model
18
+ model = GFPGANModel(opt)
19
+ # test attributes
20
+ assert model.__class__.__name__ == 'GFPGANModel'
21
+ assert isinstance(model.net_g, GFPGANv1) # generator
22
+ assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
23
+ # facial component discriminators
24
+ assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
25
+ assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
26
+ assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
27
+ # identity network
28
+ assert isinstance(model.network_identity, ResNetArcFace)
29
+ # losses
30
+ assert isinstance(model.cri_pix, L1Loss)
31
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
32
+ assert isinstance(model.cri_gan, GANLoss)
33
+ assert isinstance(model.cri_l1, L1Loss)
34
+ # optimizer
35
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
36
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
37
+
38
+ # prepare data
39
+ gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
40
+ lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
41
+ loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
42
+ loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
43
+ loc_mouth = torch.rand((1, 4), dtype=torch.float32)
44
+ data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
45
+ model.feed_data(data)
46
+ # check data shape
47
+ assert model.lq.shape == (1, 3, 512, 512)
48
+ assert model.gt.shape == (1, 3, 512, 512)
49
+ assert model.loc_left_eyes.shape == (1, 4)
50
+ assert model.loc_right_eyes.shape == (1, 4)
51
+ assert model.loc_mouths.shape == (1, 4)
52
+
53
+ # ----------------- test optimize_parameters -------------------- #
54
+ model.feed_data(data)
55
+ model.optimize_parameters(1)
56
+ assert model.output.shape == (1, 3, 512, 512)
57
+ assert isinstance(model.log_dict, dict)
58
+ # check returned keys
59
+ expected_keys = [
60
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
61
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
62
+ 'l_d_right_eye', 'l_d_mouth'
63
+ ]
64
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
65
+
66
+ # ----------------- remove pyramid_loss_weight-------------------- #
67
+ model.feed_data(data)
68
+ model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
69
+ assert model.output.shape == (1, 3, 512, 512)
70
+ assert isinstance(model.log_dict, dict)
71
+ # check returned keys
72
+ expected_keys = [
73
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
74
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
75
+ 'l_d_right_eye', 'l_d_mouth'
76
+ ]
77
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
78
+
79
+ # ----------------- test save -------------------- #
80
+ with tempfile.TemporaryDirectory() as tmpdir:
81
+ model.opt['path']['models'] = tmpdir
82
+ model.opt['path']['training_states'] = tmpdir
83
+ model.save(0, 1)
84
+
85
+ # ----------------- test the test function -------------------- #
86
+ model.test()
87
+ assert model.output.shape == (1, 3, 512, 512)
88
+ # delete net_g_ema
89
+ model.__delattr__('net_g_ema')
90
+ model.test()
91
+ assert model.output.shape == (1, 3, 512, 512)
92
+ assert model.net_g.training is True # should back to training mode after testing
93
+
94
+ # ----------------- test nondist_validation -------------------- #
95
+ # construct dataloader
96
+ dataset_opt = dict(
97
+ name='Demo',
98
+ dataroot_gt='tests/data/gt',
99
+ dataroot_lq='tests/data/gt',
100
+ io_backend=dict(type='disk'),
101
+ scale=4,
102
+ phase='val')
103
+ dataset = PairedImageDataset(dataset_opt)
104
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
105
+ assert model.is_train is True
106
+ with tempfile.TemporaryDirectory() as tmpdir:
107
+ model.opt['path']['visualization'] = tmpdir
108
+ model.nondist_validation(dataloader, 1, None, save_img=True)
109
+ assert model.is_train is True
110
+ # check metric_results
111
+ assert 'psnr' in model.metric_results
112
+ assert isinstance(model.metric_results['psnr'], float)
113
+
114
+ # validation
115
+ with tempfile.TemporaryDirectory() as tmpdir:
116
+ model.opt['is_train'] = False
117
+ model.opt['val']['suffix'] = 'test'
118
+ model.opt['path']['visualization'] = tmpdir
119
+ model.opt['val']['pbar'] = True
120
+ model.nondist_validation(dataloader, 1, None, save_img=True)
121
+ # check metric_results
122
+ assert 'psnr' in model.metric_results
123
+ assert isinstance(model.metric_results['psnr'], float)
124
+
125
+ # if opt['val']['suffix'] is None
126
+ model.opt['val']['suffix'] = None
127
+ model.opt['name'] = 'demo'
128
+ model.opt['path']['visualization'] = tmpdir
129
+ model.nondist_validation(dataloader, 1, None, save_img=True)
130
+ # check metric_results
131
+ assert 'psnr' in model.metric_results
132
+ assert isinstance(model.metric_results['psnr'], float)