# DigitDreamer / app.py
# Page header captured with this file: author karanravindra,
# commit f6a41bd ("make demo", unverified), size 2.03 kB.
import torch
import gradio as gr
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
from huggingface_hub import hf_hub_download
from digitdreamer import Autoencoder, DiT
from digitdreamer.modules import RF
from PIL.Image import Image
# --- Model setup -----------------------------------------------------------
# Fetch the pretrained weights from the Hugging Face Hub and build the
# inference pipeline once, at import time, so every request reuses it.

REPO_ID = "karanravindra/digitdreamer"


def _fetch_weights(filename: str) -> str:
    """Download ``models/<filename>`` from the Hub and return its local path."""
    # Use the path reported by hf_hub_download rather than re-deriving it,
    # so the load below cannot drift from where the file was actually stored.
    return hf_hub_download(REPO_ID, filename, subfolder="models", local_dir=".")


def _load_weights(filename: str):
    """Return the state dict stored in the given checkpoint file."""
    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; weights_only=True avoids unpickling arbitrary objects.
    return torch.load(_fetch_weights(filename), map_location="cpu", weights_only=True)


# This script only runs inference, so autograd bookkeeping is disabled globally.
torch.set_grad_enabled(False)

decoder = Autoencoder().decoder
dit = DiT()

decoder.load_state_dict(_load_weights("ft-decoder.pth"))
dit.load_state_dict(_load_weights("diffusion.pth"))

# Put both networks in evaluation mode so that any train-time-only layers
# (e.g. dropout or batch-norm, if present) behave deterministically.
decoder.eval()
dit.eval()

rf = RF(dit)
def generate(choice: str, images: int, steps: int, cfg: float):
    """Sample a batch of digit images and return them as one PIL grid image.

    Args:
        choice: A digit ``"0"``-``"9"`` to generate, or ``"Random"`` to draw an
            independent random class for every image.
        images: Number of images to generate (grid is laid out 10 per row).
        steps: Number of sampling steps forwarded to ``rf.sample``.
        cfg: Value forwarded to ``rf.sample`` as ``cfg`` (presumably the
            classifier-free guidance scale -- confirm against ``RF.sample``).

    Returns:
        A PIL image containing the decoded samples of the final sampling step.
    """
    # Slider values may arrive as floats; tensor sizes and step counts must be
    # integers, so normalise them up front.
    images = int(images)
    steps = int(steps)

    # Class index 0 is passed as the unconditional label below, so digit d is
    # encoded as d + 1 and random classes are drawn from 1..10 inclusive.
    if choice == "Random":
        cond = torch.randint(1, 11, (images,))
    else:
        cond = torch.full((images,), int(choice) + 1, dtype=torch.long)

    noise = torch.randn(images, 8, 2, 2)
    uncond = torch.full((images,), 0, dtype=torch.long)

    samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)

    # rf.sample yields a sequence of latent batches (assumed one batch of
    # `images` latents per entry); only the last entry is ever displayed, so
    # decode just that one instead of decoding every intermediate step.
    final = decoder(samples[-1]).cpu().view(images, 1, 32, 32)

    return to_pil_image(make_grid(final, nrow=10))
# --- Gradio UI ---------------------------------------------------------------
# One radio button selects the digit (or "Random"); the sliders under
# "additional inputs" map positionally onto generate()'s `images`, `steps`
# and `cfg` parameters.
demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=gr.Radio(
        label="Number",
        choices=list("0123456789") + ["Random"],
        value="Random",
    ),
    additional_inputs=[
        gr.Slider(
            label="Number of Images",
            minimum=10,
            maximum=100,
            step=10,
            value=100,
        ),
        gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            step=1,
            value=6,
        ),
        # Label spelling corrected ("Guidence" -> "Guidance").
        gr.Slider(
            label="Classifier Free Guidance",
            minimum=0,
            maximum=10,
            step=0.1,
            value=2,
        ),
    ],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)

demo.launch()