File size: 2,073 Bytes
f6a41bd
64174d5
f6a41bd
 
 
 
 
 
64174d5
f6a41bd
 
 
 
 
 
64174d5
f6a41bd
 
 
 
9f1fbd5
 
f6a41bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64174d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import gradio as gr
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
from huggingface_hub import hf_hub_download
from digitdreamer import Autoencoder, DiT
from digitdreamer.modules import RF
from PIL.Image import Image

# Fetch the fine-tuned decoder and diffusion weights into ./models/.
# hf_hub_download returns the local path of the downloaded file, so use it
# instead of re-deriving the path by hand.
decoder_path = hf_hub_download(
    "karanravindra/digitdreamer", "ft-decoder.pth", subfolder="models", local_dir="."
)
diffusion_path = hf_hub_download(
    "karanravindra/digitdreamer", "diffusion.pth", subfolder="models", local_dir="."
)

# Inference only. NOTE: grad mode is thread-local in PyTorch, so this only
# affects the thread that imports this module; code executed on other
# threads (e.g. web-server worker threads) must disable grad itself.
torch.set_grad_enabled(False)
decoder = Autoencoder().decoder
dit = DiT()

decoder.load_state_dict(torch.load(decoder_path, weights_only=True, map_location="cpu"))
dit.load_state_dict(torch.load(diffusion_path, weights_only=True, map_location="cpu"))

# Disabling grad does not change train/eval mode; switch both networks to eval
# so any dropout / batch-norm layers they contain use inference behaviour.
decoder.eval()
dit.eval()

# Sampler wrapping the diffusion transformer; used by the request handler.
rf = RF(dit)


def generate(choice: str, images: int, steps: int, cfg: float):
    """Sample a batch of digit images and return them tiled as one picture.

    Args:
        choice: A digit "0".."9" to condition every image on, or "Random" to
            draw an independent random digit class per image.
        images: Number of images to generate in the batch.
        steps: Number of sampling steps forwarded to ``rf.sample``.
        cfg: Classifier-free guidance scale forwarded to ``rf.sample``.

    Returns:
        A PIL image of the decoded final samples arranged 10 per row.
    """
    # UI sliders may hand back floats; tensor sizes and step counts need ints.
    images = int(images)
    steps = int(steps)

    # Grad mode is thread-local and the UI framework may call this handler on a
    # worker thread, so the module-level set_grad_enabled(False) is not enough.
    with torch.inference_mode():
        if choice == "Random":
            cond = torch.randint(1, 11, (images,))
        else:
            # Class ids are the digit shifted by one; id 0 is used below as the
            # unconditional label.
            cond = torch.full((images,), int(choice) + 1, dtype=torch.long)
        uncond = torch.zeros(images, dtype=torch.long)

        noise = torch.randn(images, 8, 2, 2)
        samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)

        # Only the final sampling step is shown, so decode just that instead
        # of every intermediate latent.
        # NOTE(review): assumes the decoder processes samples independently
        # (no batch-statistic layers active), which the original batching relied on too.
        imgs = decoder(samples[-1]).cpu()

    return to_pil_image(make_grid(imgs, nrow=10))


# Web UI: the radio button feeds `choice`; the additional sliders feed
# `images`, `steps` and `cfg` of `generate`, in that order.
demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=gr.Radio(label="Number", choices=list("0123456789") + ["Random"], value="Random"),
    additional_inputs=[
        gr.Slider(label="Number of Images", minimum=10, maximum=100, step=10, value=100),
        gr.Slider(label="Number of Steps", minimum=1, maximum=100, step=1, value=6),
        gr.Slider(label="Classifier-Free Guidance", minimum=0, maximum=10, step=0.1, value=2),
    ],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)

# Only start the (blocking) server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()