Commit 7885ecf by Cropinky
Parent(s): 37f46ae

Files changed (2):
  1. app.py +66 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,6 +4,72 @@ import torch
 import matplotlib.pyplot as plt
 import torchvision
 from networks_fastgan import MyGenerator
+import click
+import PIL
+
+@click.command()
+@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', default = 10-15, required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
+@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+def generate_images(
+    seeds: List[int],
+    truncation_psi: float,
+    noise_mode: str,
+    outdir: str,
+    translate: Tuple[float,float],
+    rotate: float,
+    class_idx: Optional[int]
+):
+    """Generate images using pretrained network pickle.
+
+    Examples:
+
+    \b
+    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
+        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+    \b
+    # Generate uncurated images with truncation using the MetFaces-U dataset
+    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
+    """
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    G = MyGenerator.from_pretrained("Cropinky/projected_gan_impressionism")
+    print("network loaded")
+    # Labels.
+    label = torch.zeros([1, G.c_dim], device=device)
+    if G.c_dim != 0:
+        if class_idx is None:
+            raise click.ClickException('Must specify class label with --class when using a conditional network')
+        label[:, class_idx] = 1
+    else:
+        if class_idx is not None:
+            print('warn: --class=lbl ignored when running on an unconditional network')
+
+    # Generate images.
+    for seed_idx, seed in enumerate(seeds):
+        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
+
+        # Construct an inverse rotation/translation matrix and pass to the generator. The
+        # generator expects this matrix as an inverse to avoid potentially failing numerical
+        # operations in the network.
+        if hasattr(G.synthesis, 'input'):
+            m = make_transform(translate, rotate)
+            m = np.linalg.inv(m)
+            G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+
 
 def image_generation(model, number_of_images=1):
     G = MyGenerator.from_pretrained("Cropinky/projected_gan_impressionism")
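
Note on the added code: the new generate_images command is adapted from StyleGAN3's gen_images.py, but this hunk does not add everything it relies on. parse_range, parse_vec2, make_transform, np (NumPy), and the typing names List, Optional, Tuple are referenced in app.py without being defined or imported, so the module would fail with a NameError as committed. The sketch below is an assumption about the intended helpers, modeled on the StyleGAN3 originals; it is not part of this commit.

# Sketch of the imports and helpers the new command appears to expect
# (assumed to match NVIDIA's StyleGAN3 gen_images.py; not in this commit).
import re
from typing import List, Optional, Tuple, Union

import numpy as np


def parse_range(s: Union[str, List[int]]) -> List[int]:
    """Parse a comma-separated list of numbers or ranges, e.g. '1,2,5-10' -> [1, 2, 5, ..., 10]."""
    if isinstance(s, list):
        return s
    ranges: List[int] = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges


def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    """Parse a 2-vector of the form 'a,b' into a (float, float) tuple."""
    if isinstance(s, tuple):
        return s
    parts = s.split(',')
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f'cannot parse 2-vector {s}')


def make_transform(translate: Tuple[float, float], angle: float) -> np.ndarray:
    """Build a 3x3 rotation/translation matrix from an XY offset and an angle in degrees."""
    m = np.eye(3)
    s = np.sin(angle / 360.0 * np.pi * 2)
    c = np.cos(angle / 360.0 * np.pi * 2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m

Separately, default = 10-15 on the --seeds option is evaluated by Python as the integer -5 rather than a seed range; the intended default was presumably the string '10-15', which parse_range would expand to [10, 11, 12, 13, 14, 15].
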
requirements.txt CHANGED
@@ -2,3 +2,4 @@ gradio
 torchvision
 matplotlib
 torch
+dnnlib