1aurent's picture
Update app.py
6860d8c verified
raw
history blame
1.26 kB
# Hugging Face ZeroGPU expects `spaces` to be imported before CUDA is touched;
# it provides the `@spaces.GPU` decorator used on `predict`.
import spaces
import torch
import PIL.Image
import gradio as gr
import gradio.components as grc
import numpy as np

from pipeline import DDIMPipelineCustom
# Choose the target device up front, then load the conditional MNIST DDIM
# pipeline (presumably a Hugging Face Hub repo id) and move it onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = DDIMPipelineCustom.from_pretrained(
    "1aurent/ddpm-mnist-conditional"
).to(device=device)
@spaces.GPU
def predict(steps, seed, value, guidance):
    """Stream progressively refined samples of the requested digit.

    Runs the pipeline once for every inference-step count from 1 up to and
    including ``steps``, yielding each result so the Gradio output shows the
    sample as more DDIM steps are used. The generator is re-seeded before
    every run, so the last yielded image is exactly what ``seed`` produces
    with ``steps`` steps.

    Args:
        steps: maximum number of inference steps (UI slider, 1-100).
        seed: RNG seed passed to ``torch.manual_seed``.
        value: digit to condition on (UI slider, 0-9).
        guidance: guidance factor forwarded unchanged to the pipeline.

    Yields:
        The first entry of ``.images`` from each pipeline call (the output
        component is declared ``type="pil"``).
    """
    # Slider values may arrive as floats; ``range`` needs an int.
    steps = int(steps)
    # The condition is a class label, so keep it an integer tensor (what
    # ``torch.tensor([5])`` yields for an int slider value).
    # NOTE(review): this tensor stays on CPU; presumably the pipeline moves
    # it to its own device -- confirm.
    condition = torch.tensor([int(value)])
    # Inclusive upper bound: ``steps=1`` must still produce an image.
    for num_steps in range(1, steps + 1):
        # Re-seed each run so every preview starts from the same RNG state
        # instead of continuing from the previous run's generator state.
        generator = torch.manual_seed(seed)
        yield pipeline(
            generator=generator,
            condition=condition,
            guidance=guidance,
            num_inference_steps=num_steps,
        ).images[0]
# UI controls, in the same order as predict(steps, seed, value, guidance).
_inputs = [
    grc.Slider(1, 100, label='Inference Steps', value=20, step=1),
    grc.Slider(0, 2147483647, label='Seed', value=69420, step=1),
    grc.Slider(0, 9, label='Value', value=5, step=1),
    grc.Slider(-2.5, 2.5, label='Guidance Factor', value=1),
]

# The output image is declared at MNIST's native 28x28 and enlarged to
# 256x256 purely through CSS targeting its elem_id.
demo = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=gr.Image(height=28, width=28, type="pil", elem_id="output_image"),
    css="#output_image{width: 256px !important; height: 256px !important;}",
    title="Conditional MNIST",
    description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
)

# Queue requests so predict's yielded previews can stream, then serve.
demo.queue().launch()