# ddpm-mnist / app.py
# Author: 1aurent · last commit 0c19bde ("Update app.py") · 974 bytes
from diffusers import DiffusionPipeline
import spaces
import torch
import PIL.Image
import gradio as gr
import gradio.components as grc
import numpy as np
# Load the pretrained diffusion pipeline published as "1aurent/ddpm-mnist".
pipeline = DiffusionPipeline.from_pretrained("1aurent/ddpm-mnist")
# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipeline = pipeline.to(device=device)
@spaces.GPU
def predict(steps, seed):
    """Stream images from progressively longer runs of the diffusion pipeline.

    Calls the module-level ``pipeline`` once for every inference-step count
    from 1 up to and including ``steps`` and yields each result. The UI can
    therefore show how the output changes as more steps are used.

    Args:
        steps: Largest number of inference steps to run. This is a Gradio
            slider value, so it is coerced to ``int`` (``range`` rejects
            floats).
        seed: Seed for the random generator, also coerced to ``int``.

    Yields:
        ``.images[0]`` of each pipeline call, one per step count.
    """
    steps = int(steps)
    seed = int(seed)
    # ``range``'s stop is exclusive. The ``+ 1`` ensures the requested step
    # count itself is rendered; previously steps=1 produced no output at all.
    # Note: each iteration is a complete pipeline call, so the total work grows
    # as 1 + 2 + ... + steps.
    for num_steps in range(1, steps + 1):
        # Use a fresh, private generator per run for two reasons:
        # - every preview starts from the same random state for this seed, so
        #   frames differ only in step count, not in sampled noise;
        # - torch's global RNG is not reseeded as a side effect of a request.
        generator = torch.Generator().manual_seed(seed)
        yield pipeline(generator=generator, num_inference_steps=num_steps).images[0]
# Gradio UI: two sliders (step count, seed) feeding ``predict``, which streams
# one output image per step count.
demo = gr.Interface(
    predict,
    inputs=[
        # ``value=`` sets a slider's initial value. The previous ``default=``
        # keyword is not a Slider parameter: Gradio 3 ignores it with an
        # unused-kwarg warning and Gradio 4 rejects it. As a result the
        # intended 12 / 69420 starting values never applied.
        grc.Slider(1, 100, value=12, step=1, label="Inference Steps"),
        grc.Slider(0, 2147483647, value=69420, step=1, label="Seed"),
    ],
    # ``shape=`` was dropped: in Gradio 3 it only crops/resizes *input* images
    # before they reach the function, so it did nothing on an output, and
    # Gradio 4 no longer accepts it. The display size comes from the CSS below.
    outputs=gr.Image(type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
    title="Unconditional MNIST",
    description="A DDIM scheduler and UNet model trained on the MNIST dataset for unconditional image generation.",
)
# The queue lets the generator-based ``predict`` stream its yielded images.
demo.queue().launch()