karanravindra commited on
Commit
f6a41bd
1 Parent(s): 64174d5
Files changed (2) hide show
  1. app.py +59 -3
  2. requirements.txt +114 -0
app.py CHANGED
@@ -1,7 +1,63 @@
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
+ import torch
2
  import gradio as gr
3
+ from torchvision.utils import make_grid
4
+ from torchvision.transforms.v2.functional import to_pil_image
5
+ from huggingface_hub import hf_hub_download
6
+ from digitdreamer import Autoencoder, DiT
7
+ from digitdreamer.modules import RF
8
+ from PIL.Image import Image
9
 
10
+ hf_hub_download(
11
+ "karanravindra/digitdreamer", "ft-decoder.pth", subfolder="models", local_dir="."
12
+ )
13
+ hf_hub_download(
14
+ "karanravindra/digitdreamer", "diffusion.pth", subfolder="models", local_dir="."
15
+ )
16
 
17
+ torch.set_grad_enabled(False)
18
+ decoder = Autoencoder().decoder
19
+ dit = DiT()
20
+
21
+ decoder.load_state_dict(torch.load("models/ft-decoder.pth", weights_only=True))
22
+ dit.load_state_dict(torch.load("models/diffusion.pth", weights_only=True))
23
+
24
+ rf = RF(dit)
25
+
26
+
27
+ def generate(choice: str, images: int, steps: int, cfg: float):
28
+ if choice != "Random":
29
+ class_choice = int(choice) + 1
30
+ cond = torch.full((images,), class_choice, dtype=torch.long)
31
+ else:
32
+ class_choice = torch.randint(1, 11, (images,))
33
+ cond = class_choice
34
+
35
+ noise = torch.randn(images, 8, 2, 2)
36
+ uncond = torch.full((images,), 0, dtype=torch.long)
37
+
38
+ samples = rf.sample(noise, cond, uncond, sample_steps=steps, cfg=cfg)
39
+
40
+ samples = torch.cat(samples, dim=0)
41
+
42
+ imgs = decoder(samples).cpu()
43
+ imgs = imgs.view(-1, images, 1, 32, 32)
44
+
45
+ pil_imgs: list[Image] = [to_pil_image(make_grid(img, nrow=10)) for img in imgs]
46
+
47
+ return pil_imgs[-1]
48
+
49
+
50
+ demo = gr.Interface(
51
+ fn=generate,
52
+ submit_btn="Generate",
53
+ inputs=gr.Radio(label="Number", choices=list("0123456789")+["Random"], value="Random"),
54
+ additional_inputs=[
55
+ gr.Slider(label="Number of Images", minimum=10, maximum=100, step=10, value=100),
56
+ gr.Slider(label="Number of Steps", minimum=1, maximum=100, step=1, value=6),
57
+ gr.Slider(label="Classifier Free Guidence", minimum=0, maximum=10, step=0.1, value=2)
58
+ ],
59
+ outputs=gr.Image(),
60
+ title="DigitDreamer",
61
+ description="Generate images of a number using the DiT model",
62
+ )
63
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ digitdreamer @ git+https://github.com/karanravindra/digitdreamer@main
2
+ huggingface_hub
3
+
4
+ # This file was autogenerated via `uv export`.
5
+ aiofiles==23.2.1
6
+ annotated-types==0.7.0
7
+ anyio==4.6.2.post1
8
+ appnope==0.1.4 ; platform_system == 'Darwin'
9
+ asttokens==2.4.1
10
+ certifi==2024.8.30
11
+ cffi==1.17.1 ; implementation_name == 'pypy'
12
+ charset-normalizer==3.4.0
13
+ click==8.1.7 ; sys_platform != 'emscripten'
14
+ colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
15
+ comm==0.2.2
16
+ contourpy==1.3.0
17
+ cycler==0.12.1
18
+ debugpy==1.8.7
19
+ decorator==5.1.1
20
+ einops==0.8.0
21
+ exceptiongroup==1.2.2 ; python_full_version < '3.11'
22
+ executing==2.1.0
23
+ fastapi==0.115.4
24
+ ffmpy==0.4.0
25
+ filelock==3.16.1
26
+ fonttools==4.54.1
27
+ fsspec==2024.10.0
28
+ gradio==5.3.0
29
+ gradio-client==1.4.2
30
+ h11==0.14.0
31
+ httpcore==1.0.6
32
+ httpx==0.27.2
33
+ huggingface-hub==0.26.2
34
+ idna==3.10
35
+ ipykernel==6.29.5
36
+ ipython==8.29.0
37
+ ipywidgets==8.1.5
38
+ jedi==0.19.1
39
+ jinja2==3.1.4
40
+ jupyter-client==8.6.3
41
+ jupyter-core==5.7.2
42
+ jupyterlab-widgets==3.0.13
43
+ kiwisolver==1.4.7
44
+ markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
45
+ markupsafe==2.1.5
46
+ matplotlib==3.9.2
47
+ matplotlib-inline==0.1.7
48
+ mdurl==0.1.2 ; sys_platform != 'emscripten'
49
+ mpmath==1.3.0
50
+ nest-asyncio==1.6.0
51
+ networkx==3.4.2
52
+ numpy==2.1.2
53
+ nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
54
+ nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
55
+ nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
56
+ nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
57
+ nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
58
+ nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
59
+ nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
60
+ nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
61
+ nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
62
+ nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
63
+ nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
64
+ nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
65
+ orjson==3.10.10
66
+ packaging==24.1
67
+ pandas==2.2.3
68
+ parso==0.8.4
69
+ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
70
+ pillow==10.4.0
71
+ platformdirs==4.3.6
72
+ prompt-toolkit==3.0.48
73
+ psutil==6.1.0
74
+ ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
75
+ pure-eval==0.2.3
76
+ pycparser==2.22 ; implementation_name == 'pypy'
77
+ pydantic==2.9.2
78
+ pydantic-core==2.23.4
79
+ pydub==0.25.1
80
+ pygments==2.18.0
81
+ pyparsing==3.2.0
82
+ python-dateutil==2.9.0.post0
83
+ python-multipart==0.0.16
84
+ pytz==2024.2
85
+ pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
86
+ pyyaml==6.0.2
87
+ pyzmq==26.2.0
88
+ requests==2.32.3
89
+ rich==13.9.3 ; sys_platform != 'emscripten'
90
+ ruff==0.7.1 ; sys_platform != 'emscripten'
91
+ semantic-version==2.10.0
92
+ setuptools==75.2.0 ; python_full_version >= '3.12'
93
+ shellingham==1.5.4 ; sys_platform != 'emscripten'
94
+ six==1.16.0
95
+ sniffio==1.3.1
96
+ stack-data==0.6.3
97
+ starlette==0.41.2
98
+ sympy==1.13.1
99
+ tomlkit==0.12.0
100
+ torch==2.5.0
101
+ torchinfo==1.8.0
102
+ torchvision==0.20.0
103
+ tornado==6.4.1
104
+ tqdm==4.66.5
105
+ traitlets==5.14.3
106
+ triton==3.1.0 ; python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'
107
+ typer==0.12.5 ; sys_platform != 'emscripten'
108
+ typing-extensions==4.12.2
109
+ tzdata==2024.2
110
+ urllib3==2.2.3
111
+ uvicorn==0.32.0 ; sys_platform != 'emscripten'
112
+ wcwidth==0.2.13
113
+ websockets==12.0
114
+ widgetsnbextension==4.0.13