hr16 commited on
Commit
cc7e213
·
1 Parent(s): 44da1ce

Final result

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -43,7 +43,7 @@ def generate_images(G, args):
43
  if len(args.gpu) <= 1:
44
  for i, ref in enumerate(noise_reference):
45
  noise_tensors[i].append(
46
- torch.from_numpy(rnd.standard_normal(tuple([*ref.size()[1:]]))) #?
47
  )
48
  if label_size:
49
  labels.append(torch.tensor([rnd.integers(0, label_size)]))
@@ -79,13 +79,13 @@ def generate_images(G, args):
79
 
80
  #----------------------------------------------------------------------------
81
 
82
- def inference(seed):
83
  G = stylegan2.models.load(hf_hub_download("hr16/Gwern-TWDNEv3-pytorch_ckpt", "Gs.pth", use_auth_token=os.environ['MODEL_READING_TOKEN']))
84
  G.eval()
85
  return generate_images(
86
  G,
87
  SimpleNamespace(**{
88
- 'truncation_psi': 0.7, #It seems like 0.7 will give the best result for this model.
89
  'seeds': [seed],
90
  'batch_size': 1,
91
  'pixel_min': -1,
@@ -95,6 +95,6 @@ def inference(seed):
95
  )[0]
96
 
97
  title = "TWDNEv3 CPU Generator"
98
- description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port). To use it, simply put a any-bit-size unsigned integer seed."
99
  article = ""
100
- gr.Interface(inference, [gr.Number(precision=0)], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,allow_screenshot=False).launch()
 
43
  if len(args.gpu) <= 1:
44
  for i, ref in enumerate(noise_reference):
45
  noise_tensors[i].append(
46
+ torch.from_numpy(rnd.standard_normal(tuple([*ref.size()[1:]])))
47
  )
48
  if label_size:
49
  labels.append(torch.tensor([rnd.integers(0, label_size)]))
 
79
 
80
  #----------------------------------------------------------------------------
81
 
82
+ def inference(seed, truncation_psi):
83
  G = stylegan2.models.load(hf_hub_download("hr16/Gwern-TWDNEv3-pytorch_ckpt", "Gs.pth", use_auth_token=os.environ['MODEL_READING_TOKEN']))
84
  G.eval()
85
  return generate_images(
86
  G,
87
  SimpleNamespace(**{
88
+ 'truncation_psi': truncation_psi,
89
  'seeds': [seed],
90
  'batch_size': 1,
91
  'pixel_min': -1,
 
95
  )[0]
96
 
97
  title = "TWDNEv3 CPU Generator"
98
+ description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
99
  article = ""
100
+ gr.Interface(inference, [gr.Number(precision=0, label="Seed (any-bit-size unsigned int)"), gr.Slider(0, 2, step=0.1, default=0.7, label='Truncation psi (aka creative level, between 0 and 2)')], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,allow_screenshot=False).launch()