culture commited on
Commit
f6e08a3
·
1 Parent(s): 21b1ebc

Delete tests/test_stylegan2_clean_arch.py

Browse files
Files changed (1) hide show
  1. tests/test_stylegan2_clean_arch.py +0 -52
tests/test_stylegan2_clean_arch.py DELETED
@@ -1,52 +0,0 @@
1
- import torch
2
-
3
- from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
4
-
5
-
6
- def test_stylegan2generatorclean():
7
- """Test arch: StyleGAN2GeneratorClean."""
8
-
9
- # model init and forward (gpu)
10
- if torch.cuda.is_available():
11
- net = StyleGAN2GeneratorClean(
12
- out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
13
- style = torch.rand((1, 512), dtype=torch.float32).cuda()
14
- output = net([style], input_is_latent=False)
15
- assert output[0].shape == (1, 3, 32, 32)
16
- assert output[1] is None
17
-
18
- # -------------------- with return_latents ----------------------- #
19
- output = net([style], input_is_latent=True, return_latents=True)
20
- assert output[0].shape == (1, 3, 32, 32)
21
- assert len(output[1]) == 1
22
- # check latent
23
- assert output[1][0].shape == (8, 512)
24
-
25
- # -------------------- with randomize_noise = False ----------------------- #
26
- output = net([style], randomize_noise=False)
27
- assert output[0].shape == (1, 3, 32, 32)
28
- assert output[1] is None
29
-
30
- # -------------------- with truncation = 0.5 and mixing----------------------- #
31
- output = net([style, style], truncation=0.5, truncation_latent=style)
32
- assert output[0].shape == (1, 3, 32, 32)
33
- assert output[1] is None
34
-
35
- # ------------------ test make_noise ----------------------- #
36
- out = net.make_noise()
37
- assert len(out) == 7
38
- assert out[0].shape == (1, 1, 4, 4)
39
- assert out[1].shape == (1, 1, 8, 8)
40
- assert out[2].shape == (1, 1, 8, 8)
41
- assert out[3].shape == (1, 1, 16, 16)
42
- assert out[4].shape == (1, 1, 16, 16)
43
- assert out[5].shape == (1, 1, 32, 32)
44
- assert out[6].shape == (1, 1, 32, 32)
45
-
46
- # ------------------ test get_latent ----------------------- #
47
- out = net.get_latent(style)
48
- assert out.shape == (1, 512)
49
-
50
- # ------------------ test mean_latent ----------------------- #
51
- out = net.mean_latent(2)
52
- assert out.shape == (1, 512)