Spaces:
Runtime error
Runtime error
File size: 2,073 Bytes
f6a41bd 64174d5 f6a41bd 64174d5 f6a41bd 64174d5 f6a41bd 9f1fbd5 f6a41bd 64174d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
import gradio as gr
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
from huggingface_hub import hf_hub_download
from digitdreamer import Autoencoder, DiT
from digitdreamer.modules import RF
from PIL.Image import Image
# Download the fine-tuned decoder and the diffusion (DiT) weights from the Hub.
# hf_hub_download returns the local file path; use it directly instead of
# re-deriving "models/<name>.pth" so the load paths cannot drift out of sync
# with the local_dir/subfolder arguments.
decoder_path = hf_hub_download(
    "karanravindra/digitdreamer", "ft-decoder.pth", subfolder="models", local_dir="."
)
diffusion_path = hf_hub_download(
    "karanravindra/digitdreamer", "diffusion.pth", subfolder="models", local_dir="."
)

# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)

decoder = Autoencoder().decoder
dit = DiT()
decoder.load_state_dict(torch.load(decoder_path, weights_only=True, map_location="cpu"))
dit.load_state_dict(torch.load(diffusion_path, weights_only=True, map_location="cpu"))

# Freshly constructed modules are in training mode; switch to eval so any
# dropout / batch-norm layers (if the architectures contain them) behave as at
# inference time.
decoder.eval()
dit.eval()

# Sampler wrapping the DiT; used by generate() below.
rf = RF(dit)
def generate(choice: str, images: int, steps: int, cfg: float) -> Image:
    """Sample digit latents with the DiT sampler and decode them into an image grid.

    Args:
        choice: A digit "0"-"9", or "Random" to draw an independent random
            digit for each image.
        images: Number of images to generate.
        steps: Number of sampling steps passed to ``rf.sample``.
        cfg: Classifier-free guidance scale passed to ``rf.sample``.

    Returns:
        A PIL image of the final decoded samples, tiled 10 per row.
    """
    # Gradio sliders can deliver floats; torch shape arguments must be ints.
    images = int(images)
    steps = int(steps)

    # Labels are shifted by one (digits map to 1..10) because label 0 is passed
    # as the unconditional label for classifier-free guidance.
    if choice != "Random":
        cond = torch.full((images,), int(choice) + 1, dtype=torch.long)
    else:
        cond = torch.randint(1, 11, (images,))
    uncond = torch.full((images,), 0, dtype=torch.long)

    noise = torch.randn(images, 8, 2, 2)
    samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)

    # rf.sample returns a sequence of latent batches (assumed to be one
    # `images`-sized batch per sampling step). Only the final batch is shown,
    # so decode just that one instead of every intermediate step.
    imgs = decoder(samples[-1]).cpu().view(-1, 1, 32, 32)

    # to_pil_image converts floats via mul(255).byte(), where out-of-range
    # values would wrap. Clamp to [0, 1]; this assumes the decoder targets
    # that range, and is a no-op when outputs are already inside it.
    imgs = imgs.clamp(0, 1)

    return to_pil_image(make_grid(imgs, nrow=10))
# Gradio UI: pick a digit (or "Random"); advanced controls set batch size,
# sampling steps and classifier-free guidance scale, all forwarded to generate().
demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=gr.Radio(label="Number", choices=list("0123456789") + ["Random"], value="Random"),
    additional_inputs=[
        gr.Slider(label="Number of Images", minimum=10, maximum=100, step=10, value=100),
        gr.Slider(label="Number of Steps", minimum=1, maximum=100, step=1, value=6),
        # Label typo fixed: "Guidence" -> "Guidance".
        gr.Slider(label="Classifier Free Guidance", minimum=0, maximum=10, step=0.1, value=2),
    ],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)
demo.launch()
|