"""Gradio demo for DigitDreamer: class-conditional digit generation with a DiT.

Downloads the fine-tuned decoder and diffusion (DiT) weights from the Hugging
Face Hub, wraps the DiT in the project's ``RF`` sampler, and serves a small
Gradio UI that renders a grid of generated digits.
"""

import torch
import gradio as gr
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
from huggingface_hub import hf_hub_download
from digitdreamer import Autoencoder, DiT
from digitdreamer.modules import RF
from PIL.Image import Image

REPO_ID = "karanravindra/digitdreamer"
# Per-sample latent shape fed to the sampler; must match the trained DiT/decoder.
LATENT_SHAPE = (8, 2, 2)
# Label 0 is the "null" class used as the unconditional branch of classifier-free
# guidance; digit d is encoded as label d + 1.
UNCOND_LABEL = 0
# Number of images per row in the output grid.
GRID_COLUMNS = 10

# hf_hub_download returns the local path of the downloaded file
# (./models/<filename> given subfolder="models" and local_dir=".").
decoder_path = hf_hub_download(REPO_ID, "ft-decoder.pth", subfolder="models", local_dir=".")
dit_path = hf_hub_download(REPO_ID, "diffusion.pth", subfolder="models", local_dir=".")

# Inference only: no autograd graph is needed anywhere in this app.
torch.set_grad_enabled(False)

decoder = Autoencoder().decoder
dit = DiT()
# The models are never moved off the CPU, so load the tensors onto the CPU
# directly; this also lets checkpoints saved from a GPU load on CPU-only hosts.
decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
dit.load_state_dict(torch.load(dit_path, map_location="cpu", weights_only=True))
# Disabling grad does not change layer behaviour; switch any dropout/normalization
# layers to inference mode explicitly.
decoder.eval()
dit.eval()
rf = RF(dit)


def generate(choice: str, images: int, steps: int, cfg: float) -> Image:
    """Sample a grid of digits with classifier-free guidance.

    Args:
        choice: A digit ``"0"``..``"9"``, or ``"Random"`` to draw an independent
            random digit for every image.
        images: Number of images to generate (coerced to ``int``).
        steps: Number of sampling steps passed to ``rf.sample`` (coerced to ``int``).
        cfg: Classifier-free guidance scale passed to ``rf.sample``.

    Returns:
        A PIL image of the final decoded samples, ``GRID_COLUMNS`` per row.
    """
    # UI sliders may deliver floats; tensor sizes and step counts must be ints.
    images = int(images)
    steps = int(steps)

    if choice == "Random":
        # Labels 1..10 encode digits 0..9 (randint's upper bound is exclusive).
        cond = torch.randint(1, 11, (images,))
    else:
        cond = torch.full((images,), int(choice) + 1, dtype=torch.long)
    uncond = torch.full((images,), UNCOND_LABEL, dtype=torch.long)

    noise = torch.randn(images, *LATENT_SHAPE)
    # rf.sample returns a sequence of latent batches (presumably the sampling
    # trajectory). Only the last batch is displayed, so decode just that one
    # instead of running the decoder on every intermediate batch.
    trajectory = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)
    imgs = decoder(trajectory[-1]).cpu()
    return to_pil_image(make_grid(imgs, nrow=GRID_COLUMNS))


demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=gr.Radio(label="Number", choices=[*"0123456789", "Random"], value="Random"),
    additional_inputs=[
        gr.Slider(label="Number of Images", minimum=10, maximum=100, step=10, value=100),
        gr.Slider(label="Number of Steps", minimum=1, maximum=100, step=1, value=6),
        gr.Slider(label="Classifier-Free Guidance", minimum=0, maximum=10, step=0.1, value=2),
    ],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)

demo.launch()