dinhdat1110 commited on
Commit
88aefd4
·
verified ·
1 Parent(s): 4aaca4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -17
app.py CHANGED
@@ -6,33 +6,50 @@ from torchvision import transforms
6
 
7
 
8
  parser = argparse.ArgumentParser()
9
- parser.add_argument("--ckpt_path", type=str, default="./checkpoints/model/mnist.ckpt")
10
  parser.add_argument("--map_location", type=str, default="cpu")
11
  parser.add_argument("--share", action='store_true')
12
  args = parser.parse_args()
13
 
14
  if __name__ == "__main__":
15
- model = diffusion.DiffusionModel.load_from_checkpoint(
16
- args.ckpt_path, in_channels=1, map_location=args.map_location, num_classes=10
 
 
 
17
  )
18
  to_pil = transforms.ToPILImage()
19
 
20
- def reset(image):
21
- image = to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8))
22
- return image
23
-
24
- # def noising(image):
25
- # for i in range(100):
26
 
27
  def denoise(label, timesteps):
28
- labels = torch.tensor([label]).to(model.device)
29
- for img in model.sampling(labels=labels, demo=True, mode="ddim", timesteps=timesteps):
30
  image = to_pil(img[0])
31
  yield image
32
 
33
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
34
  gr.Markdown("# Simple Diffusion Model")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  gr.Markdown("## MNIST")
37
  with gr.Row():
38
  with gr.Column(scale=2):
@@ -42,17 +59,17 @@ if __name__ == "__main__":
42
  choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
43
  value=0
44
  )
45
- timesteps = gr.Radio(label='Timestep', choices=[10, 20, 50, 100, 200, 1000])
 
 
46
  with gr.Row():
47
- sample_btn = gr.Button("Sampling")
48
- reset_btn = gr.Button("Reset")
49
  output = gr.Image(
50
  value=to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8)),
51
- scale=2,
52
  image_mode="L",
53
  type='pil',
54
  )
55
- sample_btn.click(denoise, [label, timesteps], outputs=output)
56
- reset_btn.click(reset, [output], outputs=output)
57
 
58
  demo.launch(share=args.share)
 
6
 
7
 
8
  parser = argparse.ArgumentParser()
 
9
  parser.add_argument("--map_location", type=str, default="cpu")
10
  parser.add_argument("--share", action='store_true')
11
  args = parser.parse_args()
12
 
13
  if __name__ == "__main__":
14
+ model_mnist = diffusion.DiffusionModel.load_from_checkpoint(
15
+ "./checkpoints/model/mnist.ckpt"
16
+ )
17
+ model_celeba = diffusion.DiffusionModel.load_from_checkpoint(
18
+ "./checkpoints/model/celebahq.ckpt"
19
  )
20
  to_pil = transforms.ToPILImage()
21
 
22
+ def denoise_celeb(timesteps):
23
+ for img in model_celeba.sampling(demo=True, mode="ddim", timesteps=timesteps, n_samples=1):
24
+ image = to_pil(img[0])
25
+ yield image
 
 
26
 
27
  def denoise(label, timesteps):
28
+ labels = torch.tensor([label]).to(model_mnist.device)
29
+ for img in model_mnist.sampling(labels=labels, demo=True, mode="ddim", timesteps=timesteps):
30
  image = to_pil(img[0])
31
  yield image
32
 
33
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
34
  gr.Markdown("# Simple Diffusion Model")
35
 
36
+ gr.Markdown("## CelebA")
37
+ with gr.Row():
38
+ with gr.Column(scale=2):
39
+ timesteps_celeb = gr.Radio(
40
+ label='Timestep', choices=[10, 20, 50, 100, 200, 1000]
41
+ )
42
+ sample_celeb_btn = gr.Button("Sample")
43
+
44
+ output = gr.Image(
45
+ value=to_pil((torch.randn(3, 64, 64)*255).type(torch.uint8)),
46
+ scale=1,
47
+ image_mode="RGB",
48
+ type='pil',
49
+ )
50
+
51
+ sample_celeb_btn.click(denoise_celeb, [timesteps_celeb], outputs=output)
52
+
53
  gr.Markdown("## MNIST")
54
  with gr.Row():
55
  with gr.Column(scale=2):
 
59
  choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
60
  value=0
61
  )
62
+ timesteps = gr.Radio(
63
+ label='Timestep', choices=[10, 20, 50, 100, 200, 1000]
64
+ )
65
  with gr.Row():
66
+ sample_mnist_btn = gr.Button("Sample")
 
67
  output = gr.Image(
68
  value=to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8)),
69
+ scale=1,
70
  image_mode="L",
71
  type='pil',
72
  )
73
+ sample_mnist_btn.click(denoise, [label, timesteps], outputs=output)
 
74
 
75
  demo.launch(share=args.share)