Spaces:
Runtime error
Runtime error
karanravindra
committed on
Commit
•
f6a41bd
1
Parent(s):
64174d5
make demo
Browse files- app.py +59 -3
- requirements.txt +114 -0
app.py
CHANGED
@@ -1,7 +1,63 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
demo.launch()
|
|
|
1 |
+
import torch
|
2 |
import gradio as gr
|
3 |
+
from torchvision.utils import make_grid
|
4 |
+
from torchvision.transforms.v2.functional import to_pil_image
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from digitdreamer import Autoencoder, DiT
|
7 |
+
from digitdreamer.modules import RF
|
8 |
+
from PIL.Image import Image
|
9 |
|
10 |
+
# Fetch both checkpoints from the Hub. hf_hub_download returns the local path
# of the downloaded file; that path is reused below so the download location
# and the load location cannot drift apart.
decoder_weights = hf_hub_download(
    "karanravindra/digitdreamer", "ft-decoder.pth", subfolder="models", local_dir="."
)
dit_weights = hf_hub_download(
    "karanravindra/digitdreamer", "diffusion.pth", subfolder="models", local_dir="."
)

# This app only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)
decoder = Autoencoder().decoder
dit = DiT()

# map_location="cpu" keeps loading from failing on a host without CUDA if the
# checkpoints were saved from GPU tensors; every tensor created in this app
# lives on the CPU anyway.
decoder.load_state_dict(
    torch.load(decoder_weights, map_location="cpu", weights_only=True)
)
dit.load_state_dict(
    torch.load(dit_weights, map_location="cpu", weights_only=True)
)

# Switch to inference mode so layers such as dropout or batch norm (if the
# models contain any) behave deterministically.
# NOTE(review): assumes both objects are torch.nn.Module instances - confirm.
decoder.eval()
dit.eval()

rf = RF(dit)
|
25 |
+
|
26 |
+
|
27 |
+
def generate(choice: str, images: int, steps: int, cfg: float) -> "Image":
    """Sample a batch of digit images and return them as one image grid.

    Args:
        choice: A digit string ("0" to "9") or "Random" for random labels.
        images: Number of images to sample.
        steps: Number of sampling steps, forwarded to ``rf.sample``.
        cfg: Guidance value, forwarded to ``rf.sample``.

    Returns:
        A PIL image holding the last group of decoded samples, laid out
        10 images per row.
    """
    # Slider values can arrive as floats; tensor sizes must be ints.
    images = int(images)
    steps = int(steps)

    # Label 0 is used as the unconditional label below, so digit d is
    # encoded as d + 1 and random labels are drawn from 1..10.
    if choice != "Random":
        cond = torch.full((images,), int(choice) + 1, dtype=torch.long)
    else:
        cond = torch.randint(1, 11, (images,))

    noise = torch.randn(images, 8, 2, 2)
    uncond = torch.full((images,), 0, dtype=torch.long)

    # rf.sample returns a sequence of tensors that is concatenated along the
    # batch axis. NOTE(review): presumably one entry per sampling step, with
    # the final result last - confirm against digitdreamer.
    samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)
    samples = torch.cat(samples, dim=0)

    imgs = decoder(samples).cpu()
    imgs = imgs.view(-1, images, 1, 32, 32)

    # Only the last group is shown, so only that one is rendered to a grid
    # instead of building a PIL image for every group and discarding the rest.
    return to_pil_image(make_grid(imgs[-1], nrow=10))
|
48 |
+
|
49 |
+
|
50 |
+
# Primary input: which digit to generate, or "Random" for mixed labels.
number_input = gr.Radio(
    label="Number",
    choices=list("0123456789") + ["Random"],
    value="Random",
)

# Secondary controls, passed to generate() in this order: images, steps, cfg.
image_count_slider = gr.Slider(
    label="Number of Images", minimum=10, maximum=100, step=10, value=100
)
step_count_slider = gr.Slider(
    label="Number of Steps", minimum=1, maximum=100, step=1, value=6
)
# Label spelling corrected ("Guidence" -> "Guidance").
guidance_slider = gr.Slider(
    label="Classifier Free Guidance", minimum=0, maximum=10, step=0.1, value=2
)

demo = gr.Interface(
    fn=generate,
    submit_btn="Generate",
    inputs=number_input,
    additional_inputs=[image_count_slider, step_count_slider, guidance_slider],
    outputs=gr.Image(),
    title="DigitDreamer",
    description="Generate images of a number using the DiT model",
)
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
digitdreamer @ git+https://github.com/karanravindra/digitdreamer@main
|
2 |
+
huggingface_hub
|
3 |
+
|
4 |
+
# This file was autogenerated via `uv export`.
|
5 |
+
aiofiles==23.2.1
|
6 |
+
annotated-types==0.7.0
|
7 |
+
anyio==4.6.2.post1
|
8 |
+
appnope==0.1.4 ; platform_system == 'Darwin'
|
9 |
+
asttokens==2.4.1
|
10 |
+
certifi==2024.8.30
|
11 |
+
cffi==1.17.1 ; implementation_name == 'pypy'
|
12 |
+
charset-normalizer==3.4.0
|
13 |
+
click==8.1.7 ; sys_platform != 'emscripten'
|
14 |
+
colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
|
15 |
+
comm==0.2.2
|
16 |
+
contourpy==1.3.0
|
17 |
+
cycler==0.12.1
|
18 |
+
debugpy==1.8.7
|
19 |
+
decorator==5.1.1
|
20 |
+
einops==0.8.0
|
21 |
+
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
22 |
+
executing==2.1.0
|
23 |
+
fastapi==0.115.4
|
24 |
+
ffmpy==0.4.0
|
25 |
+
filelock==3.16.1
|
26 |
+
fonttools==4.54.1
|
27 |
+
fsspec==2024.10.0
|
28 |
+
gradio==5.3.0
|
29 |
+
gradio-client==1.4.2
|
30 |
+
h11==0.14.0
|
31 |
+
httpcore==1.0.6
|
32 |
+
httpx==0.27.2
|
33 |
+
huggingface-hub==0.26.2
|
34 |
+
idna==3.10
|
35 |
+
ipykernel==6.29.5
|
36 |
+
ipython==8.29.0
|
37 |
+
ipywidgets==8.1.5
|
38 |
+
jedi==0.19.1
|
39 |
+
jinja2==3.1.4
|
40 |
+
jupyter-client==8.6.3
|
41 |
+
jupyter-core==5.7.2
|
42 |
+
jupyterlab-widgets==3.0.13
|
43 |
+
kiwisolver==1.4.7
|
44 |
+
markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
|
45 |
+
markupsafe==2.1.5
|
46 |
+
matplotlib==3.9.2
|
47 |
+
matplotlib-inline==0.1.7
|
48 |
+
mdurl==0.1.2 ; sys_platform != 'emscripten'
|
49 |
+
mpmath==1.3.0
|
50 |
+
nest-asyncio==1.6.0
|
51 |
+
networkx==3.4.2
|
52 |
+
numpy==2.1.2
|
53 |
+
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
54 |
+
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
55 |
+
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
56 |
+
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
57 |
+
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
58 |
+
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
59 |
+
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
60 |
+
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
61 |
+
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
62 |
+
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
63 |
+
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
64 |
+
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
65 |
+
orjson==3.10.10
|
66 |
+
packaging==24.1
|
67 |
+
pandas==2.2.3
|
68 |
+
parso==0.8.4
|
69 |
+
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
70 |
+
pillow==10.4.0
|
71 |
+
platformdirs==4.3.6
|
72 |
+
prompt-toolkit==3.0.48
|
73 |
+
psutil==6.1.0
|
74 |
+
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
75 |
+
pure-eval==0.2.3
|
76 |
+
pycparser==2.22 ; implementation_name == 'pypy'
|
77 |
+
pydantic==2.9.2
|
78 |
+
pydantic-core==2.23.4
|
79 |
+
pydub==0.25.1
|
80 |
+
pygments==2.18.0
|
81 |
+
pyparsing==3.2.0
|
82 |
+
python-dateutil==2.9.0.post0
|
83 |
+
python-multipart==0.0.16
|
84 |
+
pytz==2024.2
|
85 |
+
pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
|
86 |
+
pyyaml==6.0.2
|
87 |
+
pyzmq==26.2.0
|
88 |
+
requests==2.32.3
|
89 |
+
rich==13.9.3 ; sys_platform != 'emscripten'
|
90 |
+
ruff==0.7.1 ; sys_platform != 'emscripten'
|
91 |
+
semantic-version==2.10.0
|
92 |
+
setuptools==75.2.0 ; python_full_version >= '3.12'
|
93 |
+
shellingham==1.5.4 ; sys_platform != 'emscripten'
|
94 |
+
six==1.16.0
|
95 |
+
sniffio==1.3.1
|
96 |
+
stack-data==0.6.3
|
97 |
+
starlette==0.41.2
|
98 |
+
sympy==1.13.1
|
99 |
+
tomlkit==0.12.0
|
100 |
+
torch==2.5.0
|
101 |
+
torchinfo==1.8.0
|
102 |
+
torchvision==0.20.0
|
103 |
+
tornado==6.4.1
|
104 |
+
tqdm==4.66.5
|
105 |
+
traitlets==5.14.3
|
106 |
+
triton==3.1.0 ; python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'
|
107 |
+
typer==0.12.5 ; sys_platform != 'emscripten'
|
108 |
+
typing-extensions==4.12.2
|
109 |
+
tzdata==2024.2
|
110 |
+
urllib3==2.2.3
|
111 |
+
uvicorn==0.32.0 ; sys_platform != 'emscripten'
|
112 |
+
wcwidth==0.2.13
|
113 |
+
websockets==12.0
|
114 |
+
widgetsnbextension==4.0.13
|