ddpm-mnist / app.py
stevhliu's picture
stevhliu HF staff
Update app.py
b84ee6b
raw
history blame
1.04 kB
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import gradio as gr
import random
import numpy as np
# Build the unconditional DDPM pipeline a single time, at import, so that every
# request served by the app reuses this one loaded model object.
# NOTE(review): presumably fetched from the Hugging Face Hub on first run — confirm.
pipeline = DiffusionPipeline.from_pretrained(
    "johnowhitaker/ddpm-butterflies-32px"
)
def predict(steps, seed):
    """Yield samples generated with 1, 2, ..., ``steps`` inference steps.

    Each yielded image comes from a separate pipeline call, so the Gradio output
    updates progressively as the step count grows. Note that the total work is
    1 + 2 + ... + steps denoising steps.

    Args:
        steps: Final number of inference steps (from the "Inference Steps" slider).
        seed: RNG seed (from the "Seed" slider). It is re-applied before every
            call, so each image starts from the same initial noise and differs
            only in the number of denoising steps.

    Yields:
        Element 0 of each pipeline call's ``"sample"`` output (presumably a PIL
        image, matching the ``type="pil"`` output component — confirm).
    """
    # Cast in case the UI delivers slider values as floats; range() needs an int.
    steps = int(steps)
    seed = int(seed)
    # Inclusive upper bound: the last image must use exactly `steps` steps, and
    # steps == 1 must still produce one image.
    for num_steps in range(1, steps + 1):
        # Reseed on every call. Sharing one generator across calls would give each
        # call different starting noise, making the images incomparable and the
        # final one irreproducible from (seed, steps) alone.
        generator = torch.manual_seed(seed)
        yield pipeline(generator=generator, num_inference_steps=num_steps)["sample"][0]
# Default value for the Seed slider. It is drawn once when the module runs, so
# every visitor sees the same initial seed until they move the slider.
random_seed = random.randint(0, 2147483647)

# Use the Gradio component API (gr.Slider with value=) consistently with the
# gr.Image output below. The legacy gr.inputs namespace and its default= keyword
# are deprecated.
gr.Interface(
    predict,
    inputs=[
        gr.Slider(minimum=1, maximum=100, value=5, step=1, label="Inference Steps"),
        gr.Slider(
            minimum=0,
            maximum=2147483647,
            value=random_seed,
            step=1,
            label="Seed",
        ),
    ],
    outputs=gr.Image(shape=[128, 128], type="pil", elem_id="output_image"),
    # Display the small generated image at a fixed 256px width.
    css="#output_image{width: 256px}",
    title="Unconditional butterflies",
    description="A DDPM scheduler and UNet model trained on a subset of the <a href=\"https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset\">Smithsonian Butterflies</a> dataset for unconditional image generation.",
    # The queue is enabled because predict is a generator that streams images;
    # Gradio 3 requires the queue for generator outputs.
).queue().launch()