# Hugging Face Spaces page header captured with this source (status: Runtime error).
import gradio as gr
import torch
from digitdreamer import Autoencoder, DiT
from digitdreamer.modules import RF
from huggingface_hub import hf_hub_download
from PIL.Image import Image
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import make_grid
# Fetch the fine-tuned decoder and the diffusion checkpoints from the Hub.
# hf_hub_download returns the local path of the downloaded file; loading from
# that path avoids duplicating the "models/..." layout as a hard-coded string.
decoder_path = hf_hub_download(
    "karanravindra/digitdreamer", "ft-decoder.pth", subfolder="models", local_dir="."
)
diffusion_path = hf_hub_download(
    "karanravindra/digitdreamer", "diffusion.pth", subfolder="models", local_dir="."
)

# The app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

decoder = Autoencoder().decoder
dit = DiT()

# map_location="cpu" keeps loading working on a CPU-only host in case the
# checkpoints were saved from a GPU device.
decoder.load_state_dict(
    torch.load(decoder_path, map_location="cpu", weights_only=True)
)
dit.load_state_dict(
    torch.load(diffusion_path, map_location="cpu", weights_only=True)
)

# Put both networks in evaluation mode so any train-only layers (dropout,
# batch-norm statistics updates) behave deterministically during sampling.
decoder.eval()
dit.eval()

# Sampler wrapper around the DiT network; used by generate().
rf = RF(dit)
def generate(choice: str, images: int, steps: int, cfg: float) -> Image:
    """Sample digit images and return them as a single grid image.

    Args:
        choice: A digit string ("0"-"9") to generate that digit for every
            image, or "Random" to draw a class independently per image.
        images: Number of images to generate.
        steps: Number of sampling steps, forwarded to ``rf.sample``.
        cfg: Classifier-free guidance scale, forwarded to ``rf.sample``.

    Returns:
        A PIL image with the final samples laid out ten per row.

    Raises:
        ValueError: If ``images`` is smaller than 1, or ``choice`` is neither
            "Random" nor parseable as an integer.
    """
    # UI sliders may deliver whole numbers as floats; tensor sizes must be ints.
    images = int(images)
    steps = int(steps)
    if images < 1:
        raise ValueError("images must be at least 1")

    # Class ids are shifted by one: id 0 is passed below as the unconditional
    # label, so digits 0-9 map to ids 1-10.
    if choice == "Random":
        cond = torch.randint(1, 11, (images,))
    else:
        cond = torch.full((images,), int(choice) + 1, dtype=torch.long)
    uncond = torch.zeros(images, dtype=torch.long)

    noise = torch.randn(images, 8, 2, 2)
    samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)

    # Only the final grid is shown, so decode just the last `images` latents
    # of the concatenated trajectory instead of decoding every step and
    # rendering a grid for each one.
    # NOTE(review): assumes the decoder processes each sample independently
    # (no batch-dependent layers in train mode) - confirm.
    final_latents = torch.cat(samples, dim=0)[-images:]
    imgs = decoder(final_latents).cpu().view(-1, 1, 32, 32)

    return to_pil_image(make_grid(imgs, nrow=10))
# --- Gradio UI -------------------------------------------------------------
# The primary input selects the digit; the sliders are exposed as additional
# inputs and are passed to generate() in the order (images, steps, cfg).
digit_choice = gr.Radio(
    label="Number",
    choices=list("0123456789") + ["Random"],
    value="Random",
)
image_count = gr.Slider(
    label="Number of Images", minimum=10, maximum=100, step=10, value=100
)
step_count = gr.Slider(
    label="Number of Steps", minimum=1, maximum=100, step=1, value=6
)
# Label spelling corrected ("Guidence" -> "Guidance").
guidance_scale = gr.Slider(
    label="Classifier Free Guidance", minimum=0, maximum=10, step=0.1, value=2
)

demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=digit_choice,
    additional_inputs=[image_count, step_count, guidance_scale],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)

# Start the web server when the script is executed.
demo.launch()