dinhdat1110 commited on
Commit
f301a24
·
verified ·
1 Parent(s): 9c682fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import argparse
3
  import gradio as gr
4
  import diffusion
@@ -6,7 +7,7 @@ from torchvision import transforms
6
 
7
 
8
  parser = argparse.ArgumentParser()
9
- parser.add_argument("--ckpt_path", type=str, default="./checkpoints/mnist.ckpt")
10
  parser.add_argument("--map_location", type=str, default="cpu")
11
  parser.add_argument("--share", action='store_true')
12
  args = parser.parse_args()
@@ -21,9 +22,12 @@ if __name__ == "__main__":
21
  image = to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8))
22
  return image
23
 
24
- def denoise(label):
 
 
 
25
  labels = torch.tensor([label]).to(model.device)
26
- for img in model.sampling_demo(labels=labels):
27
  image = to_pil(img[0])
28
  yield image
29
 
@@ -33,11 +37,13 @@ if __name__ == "__main__":
33
  gr.Markdown("## MNIST")
34
  with gr.Row():
35
  with gr.Column(scale=2):
36
- label = gr.Dropdown(
37
- label='Label',
38
- choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
39
- value=0
40
- )
 
 
41
  with gr.Row():
42
  sample_btn = gr.Button("Sampling")
43
  reset_btn = gr.Button("Reset")
@@ -47,7 +53,7 @@ if __name__ == "__main__":
47
  image_mode="L",
48
  type='pil',
49
  )
50
- sample_btn.click(denoise, [label], outputs=output)
51
  reset_btn.click(reset, [output], outputs=output)
52
 
53
  demo.launch(share=args.share)
 
1
  import torch
2
+ import numpy as np
3
  import argparse
4
  import gradio as gr
5
  import diffusion
 
7
 
8
 
9
  parser = argparse.ArgumentParser()
10
+ parser.add_argument("--ckpt_path", type=str, default="./checkpoints/model/mnist.ckpt")
11
  parser.add_argument("--map_location", type=str, default="cpu")
12
  parser.add_argument("--share", action='store_true')
13
  args = parser.parse_args()
 
22
  image = to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8))
23
  return image
24
 
25
+ # def noising(image):
26
+ # for i in range(100):
27
+
28
+ def denoise(label, timesteps):
29
  labels = torch.tensor([label]).to(model.device)
30
+ for img in model.sampling(labels=labels, demo=True, mode="ddim", timesteps=timesteps):
31
  image = to_pil(img[0])
32
  yield image
33
 
 
37
  gr.Markdown("## MNIST")
38
  with gr.Row():
39
  with gr.Column(scale=2):
40
+ with gr.Row():
41
+ label = gr.Dropdown(
42
+ label='Label',
43
+ choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
44
+ value=0
45
+ )
46
+ timesteps = gr.Radio(label='Timestep', choices=[10, 20, 50, 100, 200, 1000])
47
  with gr.Row():
48
  sample_btn = gr.Button("Sampling")
49
  reset_btn = gr.Button("Reset")
 
53
  image_mode="L",
54
  type='pil',
55
  )
56
+ sample_btn.click(denoise, [label, timesteps], outputs=output)
57
  reset_btn.click(reset, [output], outputs=output)
58
 
59
  demo.launch(share=args.share)