1aurent commited on
Commit
6860d8c
·
verified ·
1 Parent(s): 195980a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -1,12 +1,16 @@
1
  import torch
2
  import PIL.Image
3
  import gradio as gr
 
4
  import numpy as np
5
 
6
  from pipeline import DDIMPipelineCustom
7
 
8
  pipeline = DDIMPipelineCustom.from_pretrained("1aurent/ddpm-mnist-conditional")
 
 
9
 
 
10
  def predict(steps, seed, value, guidance):
11
  generator = torch.manual_seed(seed)
12
  for i in range(1,steps):
@@ -20,13 +24,13 @@ def predict(steps, seed, value, guidance):
20
  gr.Interface(
21
  predict,
22
  inputs=[
23
- gr.components.Slider(1, 100, label='Inference Steps', value=20, step=1),
24
- gr.components.Slider(0, 2147483647, label='Seed', value=69420, step=1),
25
- gr.components.Slider(0, 9, label='Value', value=5, step=1),
26
- gr.components.Slider(-2.5, 2.5, label='Guidance Factor', value=1),
27
  ],
28
- outputs=gr.Image(shape=[28,28], type="pil", elem_id="output_image"),
29
- css="#output_image img {width: 256px}",
30
  title="Conditional MNIST",
31
  description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
32
  ).queue().launch()
 
1
  import torch
2
  import PIL.Image
3
  import gradio as gr
4
+ import gradio.components as grc
5
  import numpy as np
6
 
7
  from pipeline import DDIMPipelineCustom
8
 
9
  pipeline = DDIMPipelineCustom.from_pretrained("1aurent/ddpm-mnist-conditional")
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ pipeline = pipeline.to(device=device)
12
 
13
+ @spaces.GPU
14
  def predict(steps, seed, value, guidance):
15
  generator = torch.manual_seed(seed)
16
  for i in range(1,steps):
 
24
  gr.Interface(
25
  predict,
26
  inputs=[
27
+ grc.Slider(1, 100, label='Inference Steps', value=20, step=1),
28
+ grc.Slider(0, 2147483647, label='Seed', value=69420, step=1),
29
+ grc.Slider(0, 9, label='Value', value=5, step=1),
30
+ grc.Slider(-2.5, 2.5, label='Guidance Factor', value=1),
31
  ],
32
+ outputs=gr.Image(height=28, width=28, type="pil", elem_id="output_image"),
33
+ css="#output_image{width: 256px !important; height: 256px !important;}",
34
  title="Conditional MNIST",
35
  description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
36
  ).queue().launch()