ddpm-mnist / app.py
stevhliu's picture
stevhliu HF Staff
Create app.py
b260fbd
raw
history blame
1.02 kB
from diffusers import DDPMPipeline
import torch
import PIL.Image
import gradio as gr
import random
import numpy as np
# Hub repository id of the pretrained DDPM checkpoint used by this app.
MODEL_ID = "stevhliu/ddpm-butterflies-128"

# Load the pipeline once at module import; `predict` reads this module-level
# object on every call instead of reloading the weights.
pipeline = DDPMPipeline.from_pretrained(MODEL_ID)
def predict(steps, seed):
    """Yield one generated image per inference-step budget from 1 up to ``steps``.

    The pipeline is run ``steps`` times, first with 1 inference step, then 2,
    and so on up to and including ``steps``, yielding the first image of each
    run so the caller can display them as they arrive.

    Args:
        steps: Highest number of inference steps to run (inclusive). Coerced
            to ``int`` because UI sliders may deliver floats.
        seed: Seed for torch's default random generator. Coerced to ``int``.

    Yields:
        The first element of the pipeline output's ``"sample"`` entry for each
        step budget.
    """
    max_steps = int(steps)
    # torch.manual_seed seeds and returns torch's default generator. The same
    # generator object is passed to every pipeline call below, so its state
    # carries over from one call to the next rather than being reset.
    generator = torch.manual_seed(int(seed))

    # Inclusive upper bound: the original range(1, steps) stopped one short of
    # the requested step count and produced no output at all for steps == 1.
    for num_steps in range(1, max_steps + 1):
        # NOTE(review): indexing the output with "sample" matches the diffusers
        # API this file was written against; confirm against the installed
        # diffusers version before upgrading.
        output = pipeline(generator=generator, num_inference_steps=num_steps)
        yield output["sample"][0]
# Largest value offered by the seed slider (2**31 - 1); shared by the random
# default below and the slider's maximum so the two cannot drift apart.
MAX_SEED = 2147483647

# Draw a fresh default seed each time the app starts.
random_seed = random.randint(0, MAX_SEED)

# Build the UI around `predict` and start serving. `predict` is a generator,
# so queuing is enabled before launch (gradio requires the queue for
# generator functions).
gr.Interface(
    predict,
    inputs=[
        gr.inputs.Slider(1, 100, label='Inference Steps', default=5, step=1),
        gr.inputs.Slider(0, MAX_SEED, label='Seed', default=random_seed, step=1),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
    title="Unconditional butterflies",
    # The original literal was never closed (missing closing quote and comma),
    # which made the whole file a SyntaxError. The text itself is unchanged;
    # it is only split across adjacent literals for readability.
    description=(
        "A DDPM scheduler and UNet model trained on a subset of the "
        "<a href=\"https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset\">"
        "Smithsonian Butterflies</a> dataset for unconditional image generation."
    ),
).queue().launch()