Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
import torch
|
2 |
import PIL.Image
|
3 |
import gradio as gr
|
|
|
4 |
import numpy as np
|
5 |
|
6 |
from pipeline import DDIMPipelineCustom
|
7 |
|
8 |
pipeline = DDIMPipelineCustom.from_pretrained("1aurent/ddpm-mnist-conditional")
|
|
|
|
|
9 |
|
|
|
10 |
def predict(steps, seed, value, guidance):
|
11 |
generator = torch.manual_seed(seed)
|
12 |
for i in range(1,steps):
|
@@ -20,13 +24,13 @@ def predict(steps, seed, value, guidance):
|
|
20 |
gr.Interface(
|
21 |
predict,
|
22 |
inputs=[
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
],
|
28 |
-
outputs=gr.Image(
|
29 |
-
css="#output_image
|
30 |
title="Conditional MNIST",
|
31 |
description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
|
32 |
).queue().launch()
|
|
|
1 |
import torch
|
2 |
import PIL.Image
|
3 |
import gradio as gr
|
4 |
+
import gradio.components as grc
|
5 |
import numpy as np
|
6 |
|
7 |
from pipeline import DDIMPipelineCustom
|
8 |
|
9 |
pipeline = DDIMPipelineCustom.from_pretrained("1aurent/ddpm-mnist-conditional")
|
10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
pipeline = pipeline.to(device=device)
|
12 |
|
13 |
+
@spaces.GPU
|
14 |
def predict(steps, seed, value, guidance):
|
15 |
generator = torch.manual_seed(seed)
|
16 |
for i in range(1,steps):
|
|
|
24 |
gr.Interface(
|
25 |
predict,
|
26 |
inputs=[
|
27 |
+
grc.Slider(1, 100, label='Inference Steps', value=20, step=1),
|
28 |
+
grc.Slider(0, 2147483647, label='Seed', value=69420, step=1),
|
29 |
+
grc.Slider(0, 9, label='Value', value=5, step=1),
|
30 |
+
grc.Slider(-2.5, 2.5, label='Guidance Factor', value=1),
|
31 |
],
|
32 |
+
outputs=gr.Image(height=28, width=28, type="pil", elem_id="output_image"),
|
33 |
+
css="#output_image{width: 256px !important; height: 256px !important;}",
|
34 |
title="Conditional MNIST",
|
35 |
description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
|
36 |
).queue().launch()
|