Spaces:
Runtime error
Commit · f977a8b
1 Parent(s): b38b153
Fuc complicated stuff
- .gitattributes +0 -35
- .gitignore +0 -4
- Dockerfile.fastapi +0 -14
- Dockerfile.gradio +0 -13
- LICENSE +0 -21
- README.md +2 -1
- docker-compose.yml +0 -17
- img_examples/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png +0 -0
- img_examples/a cute grey great owl_cfg_8_seed_11.png +0 -0
- img_examples/a lake in mountains in the fall at sunset_cfg_7_seed_11.png +0 -0
- img_examples/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png +0 -0
- img_examples/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png +0 -0
- img_examples/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png +0 -0
- img_examples/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png +0 -0
- img_examples/painting of a cyberpunk market_cfg_7_seed_11.png +0 -0
- img_examples/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png +0 -0
- main.py +17 -0
- og readme.md +0 -211
- old/main.py +0 -37
- pyproject.toml +0 -23
- requirements.txt +1 -10
- setup.py +0 -36
- start.sh +1 -1
- tests/__init__.py +0 -0
- tests/client.js +0 -15
- tests/test_api.py +0 -30
- tests/test_diffuser.py +0 -99
- tld/__init__.py +0 -0
- tld/app.py +0 -66
- tld/data.py +0 -243
- tld/denoiser.py +0 -123
- tld/diffusion.py +0 -198
- tld/gen_img.py +0 -41
- tld/gradio_app.py +0 -40
- tld/train.py +0 -208
- tld/transformer_blocks.py +0 -139
.gitattributes
DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
DELETED
@@ -1,4 +0,0 @@
-__pycache__
-node_modules
-poetry.lock
-cache
Dockerfile.fastapi
DELETED
@@ -1,14 +0,0 @@
-FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
-
-WORKDIR /app
-
-RUN apt-get update && apt-get install -y git
-
-COPY . /app
-
-RUN pip install --no-cache-dir -r requirements.txt
-RUN pip install --no-cache-dir uvicorn gunicorn fastapi pytest ruff pytest-asyncio httpx
-
-EXPOSE 80
-
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
Dockerfile.gradio
DELETED
@@ -1,13 +0,0 @@
-FROM python:3.10-slim
-
-WORKDIR /app
-
-RUN apt-get update && apt-get install -y git
-
-COPY . /app
-
-RUN pip install --no-cache-dir gradio Pillow
-
-EXPOSE 80
-
-CMD ["python", "tld/gradio_app.py"]
LICENSE
DELETED
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2023 Alexandru Papiu
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
README.md
CHANGED
@@ -8,4 +8,5 @@ pinned: false
 app_file: main.py
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[Other GH](https://github.com/apapiu/transformer_latent_diffusion)
docker-compose.yml
DELETED
@@ -1,17 +0,0 @@
-version: '3.8'
-services:
-  fastapi:
-    image: apapiu89/tld-app:latest
-    ports:
-      - "80:80"
-    environment:
-      - API_TOKEN=${API_TOKEN}
-
-  gradio:
-    image: apapiu89/gradio-app:latest
-    ports:
-      - "7860:7860"
-    environment:
-      - API_URL=http://fastapi:80
-    depends_on:
-      - fastapi
img_examples/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png
DELETED
Binary file (404 kB)
img_examples/a cute grey great owl_cfg_8_seed_11.png
DELETED
Binary file (477 kB)
img_examples/a lake in mountains in the fall at sunset_cfg_7_seed_11.png
DELETED
Binary file (419 kB)
img_examples/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png
DELETED
Binary file (429 kB)
img_examples/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png
DELETED
Binary file (502 kB)
img_examples/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png
DELETED
Binary file (481 kB)
img_examples/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png
DELETED
Binary file (462 kB)
img_examples/painting of a cyberpunk market_cfg_7_seed_11.png
DELETED
Binary file (479 kB)
img_examples/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png
DELETED
Binary file (361 kB)
main.py
ADDED
@@ -0,0 +1,17 @@
+from diffusers import StableDiffusionPipeline
+import matplotlib.pyplot as plt
+import torch
+
+# Find models in https://huggingface.co/models?pipeline_tag=text-to-image&library=diffusers&sort=trending
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "beautiful horse"
+
+image = pipe(prompt).images[0]
+
+print("[PROMPT]: ", prompt)
+plt.imshow(image)
+plt.axis('off')
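Note that the added main.py pairs an SDXL checkpoint with `StableDiffusionPipeline`. In current diffusers releases, SDXL checkpoints are normally loaded through `DiffusionPipeline` (or `StableDiffusionXLPipeline`), which resolves the pipeline class from the repo's metadata. A minimal sketch of that variant, reusing the model id and prompt from the committed file (this is not the committed code):

```python
# Sketch only: load the same SDXL checkpoint referenced in main.py, but let
# diffusers pick the matching pipeline class from the repo's model_index.json.
import torch
from diffusers import DiffusionPipeline

model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # same id as in main.py

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # assumes a CUDA GPU is available

image = pipe("beautiful horse").images[0]
image.save("horse.png")  # save to disk instead of plt.imshow, so matplotlib is not needed
```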
og readme.md
DELETED
@@ -1,211 +0,0 @@
-# Transformer Latent Diffusion
-Text to Image Latent Diffusion using a Transformer core in PyTorch.
-
-[Original Github](https://github.com/apapiu/transformer_latent_diffusion)
-
-**Try with own inputs**: [](https://colab.research.google.com/drive/1VaCe01YG9rnPwAfwVLBKdXEX7D_tk1U5?usp=sharing)
-
-Below are some random examples (at 256 resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on 1 A100):
-
-<img width="760" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/e01e3094-2487-4c04-bc0f-d9b03eeaed00">
-
-#### Clip interpolation Examples:
-
-a photo of a cat → an anime drawing of a super saiyan cat, artstation:
-
-<img width="1361" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/a079458b-9bd5-4557-aa7a-5a3e78f31b53">
-
-a cute great gray owl → starry night by van gogh:
-
-<img width="1399" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/8731d87a-89fa-43a2-847d-c7ff772de286">
-
-Note that the model has not converged yet and could use more training.
-
-#### High(er) Resolution:
-By upsampling the positional encoding the model can also generate 512 or 1024 px images with minimal fine-tuning. See below for some examples of model fine-tuned on 100k extra 512 px images and 30k 1024 px images for about 2 hours on an A100. The images do sometimes lack global coherence at 1024 px - more to come here:
-
-<img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/adba64f0-b43c-423e-9a7d-033a4afea207">
-<img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/5a94515b-313e-420d-89d4-6bdc376d9a00">
-
-
-
-### Intro:
-
-The main goal of this project is to build an accessible diffusion model in PyTorch that is:
-- fast (close to real time generation)
-- small (~100MM params)
-- reasonably good (of course not SOTA)
-- can be trained in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent).
-- simple self-contained codebase (model + train loop is about ~400 lines of PyTorch with little dependencies)
-- uses ~ 1 million images with a focus on data quality over quantity
-
-This is part II of a previous [project](https://github.com/apapiu/guided-diffusion-keras) I did where I trained a pixel level diffusion model in Keras. Even though this model outputs 4x higher resolution images (256px vs 64px), it's actually faster to both train and sample from, which shows the power of training in the latent space and speed of transformer architectures.
-
-## Table of Contents:
-- [Codebase](#codebase)
-- [Usage](#usage)
-- [Examples](#examples)
-- [Data Processing](#data-processing)
-- [Architecture](#architecture)
-- [TO-DOs](#todos)
-
-
-## Codebase:
-The code is written in pure PyTorch with as few dependencies as possible.
-
-- [transformer_blocks.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/transformer_blocks.py) - basic transformer building blocks relevant to the transformer denoiser
-- [denoiser.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) - the architecture of the denoiser transformer
-- [train.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/train.py). The train loop uses `accelerate` so its training can scale to multiple GPUs if needed.
-- [diffusion.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/diffusion.py). Class to generate image from noise using reverse diffusion. Short (~60 lines) and self-contained.
-- [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py). Data utils to download images/text and process necessary features for the diffusion model.
-
-### Usage:
-If you have your own dataset of URLs + captions, the process to train a model on the data consists of two steps:
-
-1. Use `train.download_and_process_data` to obtain the latent and text encodings as numpy files. See [](https://colab.research.google.com/drive/1BPDFDBdsP9SSKBNEFJysmlBjfoxKK13r?usp=sharing) for a notebook example downloading and processing 2000 images from this HuggingFace [dataset](https://huggingface.co/datasets/zzliang/GRIT).
-
-2. use the `train.main` function in an accelerate `notebook_launcher` - see [](https://colab.research.google.com/drive/1sKk0usxEF4bmdCDcNQJQNMt4l9qBOeAM?usp=sharing) for a colab notebook that trains a model on 100k images from scratch. Note that this downloads already pre-preprocessed latents and embeddings from [here](https://huggingface.co/apapiu/small_ldt/tree/main) but you could just use whatever `.npy` files you had saved from step 1.
-
-#### Fine-Tuning - TODO but it is the same as step 2 above except you train on a pre-trained model.
-
-```python
-!wandb login
-import os
-from train import main, DataConfig, ModelConfig
-from accelerate import notebook_launcher
-
-data_config = DataConfig(latent_path='path/to/image_latents.npy',
-                         text_emb_path='path/to/text_encodings.npy',
-                         val_path='path/to/val_encodings.npy')
-
-model_config = ModelConfig(embed_dim=512, n_layers=6) #see ModelConfig for more params
-
-#run the training process on 2 GPUs:
-notebook_launcher(main, (model_config, data_config), num_processes=2)
-```
-
-### Dependencies:
-- `PyTorch` `numpy` `einops` for model building
-- `wandb` `tqdm` for logging + progress bars
-- `accelerate` for train loop and multi-GPU support
-- `img2dataset` `webdataset` `torchvision` for data downloading and image processing
-- `diffusers` `clip` for pretrained VAE and CLIP text model
-
-### Codebases used for inspiration:
-- [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha)
-- [k-diffusion](https://github.com/crowsonkb/k-diffusion)
-- [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master)
-- [LocalViT](https://github.com/ofsoundof/LocalViT)
-
-#### Speed:
-
-I try to speed up training and inference as much as possible by:
-- using mixed precision for training + [sdpa]
-- precompute all latent and text embeddings
-- using float16 precision for inference
-- using [sdpa] for the attention natively + torch.compile() (compile doesn't always work).
-- use a highly performant sampler (DPM-Solver++(2M)) that gets good results in ~ 15 steps.
-- TODO: would distillation or something like LCM work here?
-- TODO: use flash-attention2?
-- TODO: use smaller vae?
-
-The time to generate a batch of 36 images (15 iterations) on a:
-- T4: ~ 3.5 seconds
-- A100: ~ 0.6 seconds
-In fact on an A100 the vae becomes the bottleneck even though it is only used once.
-
-
-## Examples:
-
-More examples generated with the 100MM model - click the photo to see the prompt and other params like cfg and seed:
-
-
-
-
-
-
-
-
-## Outpainting model:
-
-I also fine-tuned an outpaing model on top of the original 101MM model. I had to modify the original input conv2d patch to 8 channel and initialize the mask channels parameters to zero. The rest of the architecture remained the same.
-
-Below I apply the outpainting model repatedly to generate a somewhat consistent scenery based on the prompt "a cyberpunk marketplace":
-
-<img width="1440" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/4451719f-d45a-4a86-a7bb-06c021b34996">
-
-## Data Processing:
-
-In [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py), I have some helper functions to process images and captions. The flow is as follows:
-- Use `img2dataset` to download images from a dataframe containing URLs and captions.
-- Use `CLIP` to encode the prompts and the `VAE` to encode images to latents on a web2dataset data generator.
-- Save the latents and text embedding for future training.
-
-There are two advantages to this approach. One is that the VAE encoding is somewhat expensive, so doing it every epoch would affect training times. The other is that we can discard the images after processing. For `3*256*256` images, the latent dimension is `4*32*32`, so every latent is around 4KB (when quantized in uint8; see [here](https://pub.towardsai.net/stable-diffusion-based-image-compresssion-6f1f0a399202?gi=1f45c6522d3b)). This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have been 48x larger in size.
-
-## Architecture:
-
-See [here](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) for the denoiser class.
-
-The denoiser model is a Transformer-based model based on the archirtecture in [DiT](https://arxiv.org/abs/2203.02378) and [Pixart-Alpha](https://pixart-alpha.github.io/), albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser is different from most diffusion models in that most other models used a CNN-based U-NET as the denoising backbone. I decided to use a Transformer for a few reasons. One was I just wanted to experiment and learn how to build and train Transformers from the ground up. Secondly, Transformers are fast both to train and to do inference on, and they will benefit most from future advances (both in hardware and in software) in performance.
-
-Transformers are not natively built for spatial data and at first I found a lot of the outputs to be very "patchy". To remediy that I added a depth-wise convolution in the FFN layer of the transformer (this was introduced in the [Local ViT](https://arxiv.org/abs/2104.05707) paper. This allows the model to mix pixels that are close to each other with very little added compute cost.
-
-
-### Img+Text+Noise Encoding:
-
-The image latent inputs are `4*32*32` and we use a patch size of 2 to build 256 flattened `4*2*2=16` dimensional input "pixels". These are then projected into the embed dimensions are are fed through the transformer blocks.
-
-The text and noise conditioning is very simple - we concatenate a pooled CLIP text embedding (`ViT/L14` - 768-dimensional) and the sinusoidal noise embedding and feed it as input in the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
-
-### Training:
-The base model is 101MM parameters and has 12 layers and embedding dimension = 768. I train it with a batch size of 256 on a A100 and learning rate of `3e-4`. I used 1000 steps for warmup. Due to computational contraints I did not do any ablations for this configuration.
-
-
-## Train and Diffusion Setup:
-
-We train a denoising transformer that takes the following three inputs:
-- `noise_level` (sampled from 0 to 1 with more values concentrated close to 0 - I use a beta distribution)
-- Image latent (x) corrupted with a level of random noise
-  - For a given `noise_level` between 0 and 1, the corruption is as follows:
-  - `x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal(0, 1)`
-- CLIP embeddings of a text prompt
-  - You can think of this as a numerical representation of a text prompt.
-  - We use the pooled text embedding here (768 dimensional for `ViT/L14`)
-
-The output is a prediction of the denoised image latent - call it `f(x_noisy)`.
-
-The model is trained to minimize the mean squared error `|f(x_noisy) - x|` between the prediction and actual image
-(you can also use absolute error here). Note that I don't reparameterize the loss in terms of the noise here to keep things simple.
-
-Using this model, we then iteratively generate an image from random noise as follows:
-
-    for i in range(len(self.noise_levels) - 1):
-
-        curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
-
-        # Predict original denoised image:
-        x0_pred = predict_x_zero(new_img, label, curr_noise)
-
-        # New image at next_noise level is a weighted average of old image and predicted x0:
-        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
-
-The `predict_x_zero` method uses classifier free guidance by combining the conditional and unconditional
-prediction: `x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional`
-
-A bit of math: The approach above falls within the VDM parametrization see 3.1 in [Kingma et al.](https://arxiv.org/pdf/2107.00630.pdf):
-
-$$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$
-
-Where $z_t$ is the noisy version of $x$ at time $t$.
-
-Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here, I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? For one, it simplifies the updating equation quite a bit, and it's easier to understand what the noise to signal ratio will look like. I also found that the model produces sharper images faster - more validation here is needed. The updating equation above is the DDIM model for this parametrization, which simplifies to a simple weighted average. Note that the DDIM model deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to achieve decent image quality.
-
-## TODOS:
-- better config in the train file
-- how to speed up generation even more - LCMs or other sampling strategies?
-- add script to compute FID
-
-
-
-
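The deleted README above claims that, with $\alpha_t = 1-\sigma_t$, the DDIM update reduces to a simple weighted average. A short check of that claim, in the README's own notation ($\hat{x}_0$ is the model's prediction of the clean latent):

```latex
% DDIM step under z_t = (1 - \sigma_t) x + \sigma_t \epsilon:
% infer the noise from the current sample and the predicted clean latent,
% then re-noise to the next (smaller) noise level \sigma_{t-1}.
\hat{\epsilon} = \frac{z_t - (1 - \sigma_t)\,\hat{x}_0}{\sigma_t},
\qquad
z_{t-1} = (1 - \sigma_{t-1})\,\hat{x}_0 + \sigma_{t-1}\,\hat{\epsilon}
        = \frac{(\sigma_t - \sigma_{t-1})\,\hat{x}_0 + \sigma_{t-1}\, z_t}{\sigma_t}.
```

which is exactly `new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise` from the loop above.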
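A self-contained sketch of that sampling loop, for reference. `predict_x_zero` is a hypothetical stand-in for the denoiser call, and the guidance combination follows the README's formula rather than the repo's actual `DiffusionGenerator`:

```python
# Minimal sketch of the README's reverse-diffusion loop (alpha = 1 - sigma).
# `predict_x_zero` should return the model's estimate of the clean latent x0
# given a noisy latent, a conditioning label, and the current noise level.
import torch

def sample(predict_x_zero, label, noise_levels, shape, class_guidance=3.0):
    new_img = torch.randn(shape)  # start from pure Gaussian noise
    for curr_noise, next_noise in zip(noise_levels[:-1], noise_levels[1:]):
        # conditional and unconditional x0 predictions, combined via CFG:
        x0_cond = predict_x_zero(new_img, label, curr_noise)
        x0_uncond = predict_x_zero(new_img, torch.zeros_like(label), curr_noise)
        x0_pred = class_guidance * x0_cond + (1 - class_guidance) * x0_uncond
        # DDIM-style update: weighted average of predicted x0 and current image.
        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
    return new_img
```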
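The "Img+Text+Noise Encoding" section of the README states that a `4*32*32` latent with patch size 2 becomes 256 tokens of dimension 16. A quick illustrative sketch of that patchify step with `einops` (the deleted `denoiser.py` below does the same thing with a strided `nn.Conv2d` followed by a `Rearrange`):

```python
# Patchify a 4x32x32 latent with patch size 2: (32/2)*(32/2) = 256 tokens,
# each of dimension 4*2*2 = 16, matching the numbers quoted in the README.
import torch
from einops import rearrange

latent = torch.randn(1, 4, 32, 32)  # (batch, channels, height, width)
tokens = rearrange(latent, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=2, p2=2)
print(tokens.shape)  # torch.Size([1, 256, 16])
```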
old/main.py
DELETED
@@ -1,37 +0,0 @@
-import os
-from transformers import CLIPProcessor, CLIPModel
-from PIL import Image
-
-# Get the directory of the script
-script_directory = os.path.dirname(os.path.realpath(__file__))
-# Specify the directory where the cache will be stored (same folder as the script)
-cache_directory = os.path.join(script_directory, "cache")
-# Create the cache directory if it doesn't exist
-os.makedirs(cache_directory, exist_ok=True)
-
-# Load the CLIP processor and model
-clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
-
-# Text description to generate image
-text = "a cat sitting on a table"
-
-# Tokenize text and get features
-inputs = clip_processor(text, return_tensors="pt", padding=True)
-
-# Generate image from text
-generated_image = clip_model.generate(
-    inputs=inputs.input_ids,
-    attention_mask=inputs.attention_mask,
-    visual_input=None,  # We don't provide image inputvi
-    return_tensors="pt"  # Return PyTorch tensor
-)
-
-# Convert the generated image tensor to a NumPy array
-generated_image_np = generated_image[0].cpu().numpy()
-
-# Save the generated image
-output_image_path = "generated_image.png"
-Image.fromarray(generated_image_np).save(output_image_path)
-
-print("Image generated and saved as:", output_image_path)
pyproject.toml
DELETED
@@ -1,23 +0,0 @@
-[tool.poetry]
-name = "img-gen"
-version = "0.0.1"
-description = ""
-authors = ["CubeBeveled <[email protected]>"]
-readme = "README.md"
-
-[tool.poetry.dependencies]
-python = "^3.11"
-torch = "^2.2.1"
-numpy = "^1.26.4"
-einops = "^0.7.0"
-torchvision = "^0.17.1"
-tqdm = "^4.66.2"
-diffusers = "^0.27.2"
-accelerate = "^0.28.0"
-transformers = "^4.39.1"
-pillow = "^10.2.0"
-
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
requirements.txt
CHANGED
@@ -1,11 +1,2 @@
-torch
-numpy
-einops
-torchvision
-tqdm
 diffusers
-
-transformers
-Pillow
-poetry
-git+https://github.com/openai/CLIP.git
+transformers
setup.py
DELETED
@@ -1,36 +0,0 @@
-from setuptools import setup, find_packages
-
-
-def load_requirements(filename="requirements.txt"):
-    with open(filename, "r") as file:
-        lines = [line.strip() for line in file.readlines() if line.strip() and not line.startswith("#")]
-    return lines
-
-
-setup(
-    name="tld",
-    version="0.1.0",
-    author="Alexandru Papiu",
-    author_email="[email protected]",
-    description="Transformer Latent Diffusion",
-    url="https://github.com/apapiu/transformer_latent_diffusion",
-    packages=find_packages(exclude=["tests*"]),
-    classifiers=[
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: MIT License",
-        "Operating System :: OS Independent",
-    ],
-    python_requires=">=3.6",
-    install_requires=[
-        "torch",
-        "numpy",
-        "einops",
-        "torchvision",
-        "tqdm",
-        "diffusers",
-        "accelerate",
-        "transformers",
-        "Pillow",
-        "clip @ git+https://github.com/openai/CLIP.git",
-    ],
-)
start.sh
CHANGED
@@ -1 +1 @@
-python
+python main.py
tests/__init__.py
DELETED
File without changes
tests/client.js
DELETED
@@ -1,15 +0,0 @@
-const axios = require("axios");
-
-const apiUrl = `http://de-fsn-4.halex.gg:25287/api`;
-
-const postData = {
-  prompt: "Wassup my homie"
-};
-
-axios.post(apiUrl, postData)
-  .then(response => {
-    console.log("Response from API:", response.data);
-  })
-  .catch(error => {
-    console.error("Error:", error.message);
-  });
tests/test_api.py
DELETED
@@ -1,30 +0,0 @@
-from fastapi.testclient import TestClient
-from app import app
-import PIL
-from PIL import Image
-from io import BytesIO
-
-client = TestClient(app)
-
-def test_read_main():
-    response = client.get("/")
-    assert response.status_code == 200
-    assert response.json() == {"message": "Welcome to Image Generator"}
-
-
-def test_generate_image_unauthorized():
-    response = client.post("/generate-image/", json={})
-    assert response.status_code == 401
-    assert response.json() == {"detail": "Not authenticated"}
-
-
-def test_generate_image_authorized():
-    response = client.post(
-        "/generate-image/", json={"prompt": "a cute cat"}
-    )
-    assert response.status_code == 200
-
-    image = Image.open(BytesIO(response.content))
-    assert type(image) == PIL.JpegImagePlugin.JpegImageFile
-
-test_generate_image_authorized()
tests/test_diffuser.py
DELETED
@@ -1,99 +0,0 @@
-import os
-import sys
-
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-import time
-
-import numpy as np
-import torch
-import torchvision.transforms as transforms
-import torchvision.utils as vutils
-from diffusers import AutoencoderKL
-
-from denoiser import Denoiser
-from diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
-from PIL.Image import Image
-
-to_pil = transforms.ToPILImage()
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
-def test_outputs(num_imgs=4):
-    model = Denoiser(
-        image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=768, dropout=0.1, n_layers=12
-    )
-    x = torch.rand(num_imgs, 4, 32, 32)
-    noise_level = torch.rand(num_imgs, 1)
-    label = torch.rand(num_imgs, 768)
-
-    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
-
-    with torch.no_grad():
-        start_time = time.time()
-        output = model(x, noise_level, label)
-        end_time = time.time()
-
-    execution_time = end_time - start_time
-    print(f"Model execution took {execution_time:.4f} seconds.")
-
-    assert output.shape == torch.Size([num_imgs, 4, 32, 32])
-    print("Basic tests passed.")
-
-    # model = Denoiser(image_size=16, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=6)
-    # x = torch.rand(8, 4, 32, 32)
-    # noise_level = torch.rand(8, 1)
-    # label = torch.rand(8, 768)
-
-    # with torch.no_grad():
-    #     output = model(x, noise_level, label)
-
-    # assert output.shape == torch.Size([8, 4, 32, 32])
-    # print("Uspscale tests passed.")
-
-
-def test_diffusion_generator():
-    model_dtype = torch.float32  ##float 16 will not work on cpu
-    num_imgs = 1
-    nrow = int(np.sqrt(num_imgs))
-
-    denoiser = Denoiser(
-        image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=3
-    )
-    print(f"Model has {sum(p.numel() for p in denoiser.parameters())} parameters")
-
-    denoiser.to(model_dtype)
-
-    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=model_dtype).to(device)
-
-    labels = torch.rand(num_imgs, 768)
-
-    diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
-
-    out, _ = diffuser.generate(
-        labels=labels,
-        num_imgs=num_imgs,
-        class_guidance=3,
-        seed=1,
-        n_iter=5,
-        exponent=1,
-        scale_factor=8,
-        sharp_f=0,
-        bright_f=0,
-    )
-
-    out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
-    out.save("test.png")
-    print("Images generated at test.png")
-
-
-def test_full_generation_pipeline():
-    ltdconfig = LTDConfig()
-    diffusion_transformer = DiffusionTransformer(ltdconfig)
-
-    out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
-    print(out)
-    assert type(out) == Image
-
-
-# TODO: should add tests for train loop and data processing
tld/__init__.py
DELETED
File without changes
tld/app.py
DELETED
@@ -1,66 +0,0 @@
-import io
-import os
-from typing import Optional
-
-import torch
-import torchvision.transforms as transforms
-from fastapi import FastAPI, HTTPException, status
-from fastapi.responses import StreamingResponse
-from fastapi.security import OAuth2PasswordBearer
-from pydantic import BaseModel
-
-from diffusion import DiffusionTransformer, LTDConfig
-
-# Get the directory of the script
-script_directory = os.path.dirname(os.path.realpath(__file__))
-# Specify the directory where the cache will be stored (same folder as the script)
-cache_directory = os.path.join(script_directory, "cache")
-home_directory = os.path.join(script_directory, "home")
-# Create the cache directory if it doesn't exist
-os.makedirs(cache_directory, exist_ok=True)
-os.makedirs(home_directory, exist_ok=True)
-
-os.environ["TRANSFORMERS_CACHE"] = cache_directory
-os.environ["HF_HOME"] = home_directory
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-to_pil = transforms.ToPILImage()
-ltdconfig = LTDConfig()
-diffusion_transformer = DiffusionTransformer(ltdconfig) #Downloads model here
-app = FastAPI()
-
-class ImageRequest(BaseModel):
-    prompt: str
-    class_guidance: Optional[int] = 6
-    seed: Optional[int] = 11
-    num_imgs: Optional[int] = 1
-    img_size: Optional[int] = 32
-
-
-@app.get("/")
-def read_root():
-    return {"message": "Welcome to Image Generator"}
-
-
-@app.post("/generate-image/")
-async def generate_image(request: ImageRequest):
-    try:
-        img = diffusion_transformer.generate_image_from_text(
-            prompt=request.prompt,
-            class_guidance=request.class_guidance,
-            seed=request.seed,
-            num_imgs=request.num_imgs,
-            img_size=request.img_size,
-        )
-
-        # Convert PIL image to byte stream suitable for HTTP response
-        img_byte_arr = io.BytesIO()
-        img.save(img_byte_arr, format="JPEG")
-        img_byte_arr.seek(0)
-
-        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-
-# build job to test and deploy the API on a docker image (maybe in Azure?)
tld/data.py
DELETED
@@ -1,243 +0,0 @@
-####data util to get and preprocess data from a text and image pair to latents and text embeddings.
-### all that is required is a csv file with an image url and text caption:
-#!pip install datasets img2dataset accelerate diffusers
-#!pip install git+https://github.com/openai/CLIP.git
-
-import json
-import os
-from dataclasses import dataclass
-from typing import List, Union
-
-import clip
-import h5py
-import numpy as np
-import pandas as pd
-import torch
-import torchvision.transforms as transforms
-import webdataset as wds
-from diffusers import AutoencoderKL
-from img2dataset import download
-from torch import Tensor, nn
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-
-@torch.no_grad()
-def encode_text(label: Union[str, List[str]], model: nn.Module, device: str) -> Tensor:
-    text_tokens = clip.tokenize(label, truncate=True).to(device)
-    text_encoding = model.encode_text(text_tokens)
-    return text_encoding.cpu()
-
-
-@torch.no_grad()
-def encode_image(img: Tensor, vae: AutoencoderKL) -> Tensor:
-    x = img.to("cuda").to(torch.float16)
-
-    x = x * 2 - 1  # to make it between -1 and 1.
-    encoded = vae.encode(x, return_dict=False)[0].sample()
-    return encoded.cpu()
-
-
-@torch.no_grad()
-def decode_latents(out_latents: torch.FloatTensor, vae: AutoencoderKL) -> Tensor:
-    # expected to be in the unscaled latent space
-    out = vae.decode(out_latents.cuda())[0].cpu()
-
-    return ((out + 1) / 2).clip(0, 1)
-
-
-def quantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
-    """scale and quantize latents to unit8"""
-    lat_norm = lat.clip(-clip_val, clip_val) / clip_val
-    return (((lat_norm + 1) / 2) * 255).to(torch.uint8)
-
-
-def dequantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
-    lat_norm = (lat.to(torch.float16) / 255) * 2 - 1
-    return lat_norm * clip_val
-
-
-def append_to_dataset(dataset: h5py.File, new_data: Tensor) -> None:
-    """Appends new data to an HDF5 dataset."""
-    new_size = dataset.shape[0] + new_data.shape[0]
-    dataset.resize(new_size, axis=0)
-    dataset[-new_data.shape[0] :] = new_data
-
-
-def get_text_and_latent_embeddings_hdf5(
-    dataloader: DataLoader, vae: AutoencoderKL, model: nn.Module, drive_save_path: str
-) -> None:
-    """Process img/text inptus that outputs an latent and text embeddings and text_prompts, saving encodings as float16."""
-
-    img_latent_path = os.path.join(drive_save_path, "image_latents.hdf5")
-    text_embed_path = os.path.join(drive_save_path, "text_encodings.hdf5")
-    metadata_csv_path = os.path.join(drive_save_path, "metadata.csv")
-
-    with h5py.File(img_latent_path, "a") as img_file, h5py.File(text_embed_path, "a") as text_file:
-        if "image_latents" not in img_file:
-            img_ds = img_file.create_dataset(
-                "image_latents",
-                shape=(0, 4, 32, 32),
-                maxshape=(None, 4, 32, 32),
-                dtype="float16",
-                chunks=True,
-            )
-        else:
-            img_ds = img_file["image_latents"]
-
-        if "text_encodings" not in text_file:
-            text_ds = text_file.create_dataset(
-                "text_encodings", shape=(0, 768), maxshape=(None, 768), dtype="float16", chunks=True
-            )
-        else:
-            text_ds = text_file["text_encodings"]
-
-        for img, (label, url) in tqdm(dataloader):
-            text_encoding = encode_text(label, model).cpu().numpy().astype(np.float16)
-            img_encoding = encode_image(img, vae).cpu().numpy().astype(np.float16)
-
-            append_to_dataset(img_ds, img_encoding)
-            append_to_dataset(text_ds, text_encoding)
-
-            metadata_df = pd.DataFrame({"text": label, "url": url})
-            if os.path.exists(metadata_csv_path):
-                metadata_df.to_csv(metadata_csv_path, mode="a", header=False, index=False)
-            else:
-                metadata_df.to_csv(metadata_csv_path, mode="w", header=True, index=False)
-
-
-def download_and_process_data(
-    latent_save_path="latents",
-    raw_imgs_save_path="raw_imgs",
-    csv_path="imgs.csv",
-    image_size=256,
-    bs=64,
-    caption_col="captions",
-    url_col="url",
-    download_data=True,
-    number_sample_per_shard=10000,
-):
-    if not os.path.exists(raw_imgs_save_path):
-        os.mkdir(raw_imgs_save_path)
-
-    if not os.path.exists(latent_save_path):
-        os.mkdir(latent_save_path)
-
-    if download_data:
-        download(
-            processes_count=8,
-            thread_count=64,
-            url_list=csv_path,
-            image_size=image_size,
-            output_folder=raw_imgs_save_path,
-            output_format="webdataset",
-            input_format="csv",
-            url_col=url_col,
-            caption_col=caption_col,
-            enable_wandb=False,
-            number_sample_per_shard=number_sample_per_shard,
-            distributor="multiprocessing",
-            resize_mode="center_crop",
-        )
-
-    files = os.listdir(raw_imgs_save_path)
-    tar_files = [os.path.join(raw_imgs_save_path, file) for file in files if file.endswith(".tar")]
-    print(tar_files)
-    dataset = wds.WebDataset(tar_files)
-
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-
-    # output is (img_tensor, (caption , url_col)) per batch:
-    dataset = (
-        dataset.decode("pil")
-        .to_tuple("jpg;png", "json")
-        .map_tuple(transform, lambda x: (x["caption"], x[url_col]))
-    )
-
-    dataloader = DataLoader(dataset, batch_size=bs, shuffle=False)
-
-    model, _ = clip.load("ViT-L/14")
-
-    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
-    vae = vae.to("cuda")
-    model.to("cuda")
-
-    print("Starting to encode latents and text:")
-    get_text_and_latent_embeddings_hdf5(dataloader, vae, model, latent_save_path)
-    print("Finished encode latents and text:")
-
-
-@dataclass
-class DataConfiguration:
-    data_link: str
-    caption_col: str = "caption"
-    url_col: str = "url"
-    latent_save_path: str = "latents_folder"
-    raw_imgs_save_path: str = "raw_imgs_folder"
-    use_drive: bool = False
-    initial_csv_path: str = "imgs.csv"
-    number_sample_per_shard: int = 10000
-    image_size: int = 256
-    batch_size: int = 64
-    download_data: bool = True
-
-
-if __name__ == "__main__":
-    use_wandb = False
-
-    if use_wandb:
-        import wandb
-
-        os.environ["WANDB_API_KEY"] = "key"
-        #!wandb login
-
-    data_link = "https://huggingface.co/datasets/zzliang/GRIT/resolve/main/grit-20m/coyo_0_snappy.parquet?download=true"
-
-    data_config = DataConfiguration(
-        data_link=data_link,
-        latent_save_path="latent_folder",
-        raw_imgs_save_path="raw_imgs_folder",
-        download_data=False,
-        number_sample_per_shard=1000,
-    )
-
-    if use_wandb:
-        wandb.init(project="image_vae_processing", entity="apapiu", config=data_config)
-
-    if not os.path.exists(data_config.latent_save_path):
-        os.mkdir(data_config.latent_save_path)
-
-    config_file_path = os.path.join(data_config.latent_save_path, "config.json")
-    with open(config_file_path, "w") as f:
-        json.dump(data_config.__dict__, f)
-
-    print("Config saved to:", config_file_path)
-
-    df = pd.read_parquet(data_link)
-    ###add additional data cleaning here...should I
-    df = df.iloc[:3000]
-    df[["key", "url", "caption"]].to_csv("imgs.csv", index=None)
-
-    if data_config.use_drive:
-        from google.colab import drive
-
-        drive.mount("/content/drive")
-
-    download_and_process_data(
-        latent_save_path=data_config.latent_save_path,
-        raw_imgs_save_path=data_config.raw_imgs_save_path,
-        csv_path=data_config.initial_csv_path,
-        image_size=data_config.image_size,
-        bs=data_config.batch_size,
-        caption_col=data_config.caption_col,
-        url_col=data_config.url_col,
-        download_data=data_config.download_data,
-        number_sample_per_shard=data_config.number_sample_per_shard,
-    )
-
-    if use_wandb:
-        wandb.finish()
tld/denoiser.py
DELETED
@@ -1,123 +0,0 @@
-"""transformer based denoiser"""
-
-import torch
-from einops.layers.torch import Rearrange
-from torch import nn
-
-from transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
-
-
-class DenoiserTransBlock(nn.Module):
-    def __init__(
-        self,
-        patch_size: int,
-        img_size: int,
-        embed_dim: int,
-        dropout: float,
-        n_layers: int,
-        mlp_multiplier: int = 4,
-        n_channels: int = 4,
-    ):
-        super().__init__()
-
-        self.patch_size = patch_size
-        self.img_size = img_size
-        self.n_channels = n_channels
-        self.embed_dim = embed_dim
-        self.dropout = dropout
-        self.n_layers = n_layers
-        self.mlp_multiplier = mlp_multiplier
-
-        seq_len = int((self.img_size / self.patch_size) * (self.img_size / self.patch_size))
-        patch_dim = self.n_channels * self.patch_size * self.patch_size
-
-        self.patchify_and_embed = nn.Sequential(
-            nn.Conv2d(
-                self.n_channels,
-                patch_dim,
-                kernel_size=self.patch_size,
-                stride=self.patch_size,
-            ),
-            Rearrange("bs d h w -> bs (h w) d"),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, self.embed_dim),
-            nn.LayerNorm(self.embed_dim),
-        )
-
-        self.rearrange2 = Rearrange(
-            "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
-            h=int(self.img_size / self.patch_size),
-            p1=self.patch_size,
-            p2=self.patch_size,
-        )
-
-        self.pos_embed = nn.Embedding(seq_len, self.embed_dim)
-        self.register_buffer("precomputed_pos_enc", torch.arange(0, seq_len).long())
-
-        self.decoder_blocks = nn.ModuleList(
-            [
-                DecoderBlock(
-                    embed_dim=self.embed_dim,
-                    mlp_multiplier=self.mlp_multiplier,
-                    # note that this is a non-causal block since we are
-                    # denoising the entire image no need for masking
-                    is_causal=False,
-                    dropout_level=self.dropout,
-                    mlp_class=MLPSepConv,
-                )
-                for _ in range(self.n_layers)
-            ]
-        )
-
-        self.out_proj = nn.Sequential(nn.Linear(self.embed_dim, patch_dim), self.rearrange2)
-
-    def forward(self, x, cond):
-        x = self.patchify_and_embed(x)
-        pos_enc = self.precomputed_pos_enc[: x.size(1)].expand(x.size(0), -1)
-        x = x + self.pos_embed(pos_enc)
-
-        for block in self.decoder_blocks:
-            x = block(x, cond)
-
-        return self.out_proj(x)
-
-
-class Denoiser(nn.Module):
-    def __init__(
-        self,
-        image_size: int,
-        noise_embed_dims: int,
-        patch_size: int,
-        embed_dim: int,
-        dropout: float,
-        n_layers: int,
-        text_emb_size: int = 768,
-    ):
-        super().__init__()
-
-        self.image_size = image_size
-        self.noise_embed_dims = noise_embed_dims
-        self.embed_dim = embed_dim
-
-        self.fourier_feats = nn.Sequential(
-            SinusoidalEmbedding(embedding_dims=noise_embed_dims),
-            nn.Linear(noise_embed_dims, self.embed_dim),
-            nn.GELU(),
-            nn.Linear(self.embed_dim, self.embed_dim),
-        )
-
-        self.denoiser_trans_block = DenoiserTransBlock(patch_size, image_size, embed_dim, dropout, n_layers)
-        self.norm = nn.LayerNorm(self.embed_dim)
-        self.label_proj = nn.Linear(text_emb_size, self.embed_dim)
-
-    def forward(self, x, noise_level, label):
-        noise_level = self.fourier_feats(noise_level).unsqueeze(1)
-
-        label = self.label_proj(label).unsqueeze(1)
-
-        noise_label_emb = torch.cat([noise_level, label], dim=1)  # bs, 2, d
-        noise_label_emb = self.norm(noise_label_emb)
-
-        x = self.denoiser_trans_block(x, noise_label_emb)
-
-        return x
tld/diffusion.py
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
|
3 |
-
import clip
|
4 |
-
import numpy as np
|
5 |
-
import requests
|
6 |
-
import torch
|
7 |
-
import torchvision.transforms as transforms
|
8 |
-
import torchvision.utils as vutils
|
9 |
-
from diffusers import AutoencoderKL
|
10 |
-
from torch import Tensor
|
11 |
-
from tqdm import tqdm
|
12 |
-
|
13 |
-
from denoiser import Denoiser
|
14 |
-
|
15 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
16 |
-
to_pil = transforms.ToPILImage()
|
17 |
-
|
18 |
-
|
19 |
-
@dataclass
|
20 |
-
class DiffusionGenerator:
|
21 |
-
model: Denoiser
|
22 |
-
vae: AutoencoderKL
|
23 |
-
device: torch.device
|
24 |
-
model_dtype: torch.dtype = torch.float32
|
25 |
-
|
26 |
-
@torch.no_grad()
|
27 |
-
def generate(
|
28 |
-
self,
|
29 |
-
labels: Tensor, # embeddings to condition on
|
30 |
-
n_iter: int = 30,
|
31 |
-
num_imgs: int = 16,
|
32 |
-
class_guidance: float = 3,
|
33 |
-
seed: int = 10,
|
34 |
-
-        scale_factor: int = 8,  # latent scaling before decoding - should be ~ std of latent space
-        img_size: int = 32,  # height, width of latent
-        sharp_f: float = 0.1,
-        bright_f: float = 0.1,
-        exponent: float = 1,
-        seeds: Tensor | None = None,
-        noise_levels=None,
-        use_ddpm_plus: bool = True,
-    ):
-        """Generate images via reverse diffusion.
-        if use_ddpm_plus=True uses Algorithm 2 DPM-Solver++(2M) here: https://arxiv.org/pdf/2211.01095.pdf
-        else use ddim with alpha = 1-sigma
-        """
-        if noise_levels is None:
-            noise_levels = (1 - torch.pow(torch.arange(0, 1, 1 / n_iter), exponent)).tolist()
-            noise_levels[0] = 0.99
-
-        if use_ddpm_plus:
-            lambdas = [np.log((1 - sigma) / sigma) for sigma in noise_levels]  # log snr
-            hs = [lambdas[i] - lambdas[i - 1] for i in range(1, len(lambdas))]
-            rs = [hs[i - 1] / hs[i] for i in range(1, len(hs))]
-
-        x_t = self.initialize_image(seeds, num_imgs, img_size, seed)
-
-        labels = torch.cat([labels, torch.zeros_like(labels)])
-        self.model.eval()
-
-        x0_pred_prev = None
-
-        for i in tqdm(range(len(noise_levels) - 1)):
-            curr_noise, next_noise = noise_levels[i], noise_levels[i + 1]
-
-            x0_pred = self.pred_image(x_t, labels, curr_noise, class_guidance)
-
-            if x0_pred_prev is None:
-                x_t = ((curr_noise - next_noise) * x0_pred + next_noise * x_t) / curr_noise
-            else:
-                if use_ddpm_plus:
-                    # x0_pred is a combination of the two previous x0_pred:
-                    D = (1 + 1 / (2 * rs[i - 1])) * x0_pred - (1 / (2 * rs[i - 1])) * x0_pred_prev
-                else:
-                    # ddim:
-                    D = x0_pred
-
-                x_t = ((curr_noise - next_noise) * D + next_noise * x_t) / curr_noise
-
-            x0_pred_prev = x0_pred
-
-        x0_pred = self.pred_image(x_t, labels, next_noise, class_guidance)
-
-        # shifting latents works a bit like an image editor:
-        x0_pred[:, 3, :, :] += sharp_f
-        x0_pred[:, 0, :, :] += bright_f
-
-        x0_pred_img = self.vae.decode((x0_pred * scale_factor).to(self.model_dtype))[0].cpu()
-        return x0_pred_img, x0_pred
-
-    def pred_image(self, noisy_image, labels, noise_level, class_guidance):
-        num_imgs = noisy_image.size(0)
-        noises = torch.full((2 * num_imgs, 1), noise_level)
-        x0_pred = self.model(
-            torch.cat([noisy_image, noisy_image]),
-            noises.to(self.device, self.model_dtype),
-            labels.to(self.device, self.model_dtype),
-        )
-        x0_pred = self.apply_classifier_free_guidance(x0_pred, num_imgs, class_guidance)
-        return x0_pred
-
-    def initialize_image(self, seeds, num_imgs, img_size, seed):
-        """Initialize the seed tensor."""
-        if seeds is None:
-            generator = torch.Generator(device=self.device)
-            generator.manual_seed(seed)
-            return torch.randn(
-                num_imgs,
-                4,
-                img_size,
-                img_size,
-                dtype=self.model_dtype,
-                device=self.device,
-                generator=generator,
-            )
-        else:
-            return seeds.to(self.device, self.model_dtype)
-
-    def apply_classifier_free_guidance(self, x0_pred, num_imgs, class_guidance):
-        """Apply classifier-free guidance to the predictions."""
-        x0_pred_label, x0_pred_no_label = x0_pred[:num_imgs], x0_pred[num_imgs:]
-        return class_guidance * x0_pred_label + (1 - class_guidance) * x0_pred_no_label
-
-
-@dataclass
-class LTDConfig:
-    vae_scale_factor: float = 8
-    img_size: int = 32
-    model_dtype: torch.dtype = torch.float32
-    file_url: str = None  # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
-    local_filename: str = "state_dict_378000.pth"
-    vae_name: str = "ByteDance/SDXL-Lightning"
-    clip_model_name: str = "ViT-L/14"
-    denoiser: Denoiser = Denoiser(
-        image_size=32,
-        noise_embed_dims=256,
-        patch_size=2,
-        embed_dim=256,
-        dropout=0,
-        n_layers=4,
-    )
-
-
-def download_file(url, filename):
-    with requests.get(url, stream=True) as r:
-        r.raise_for_status()
-        with open(filename, "wb") as f:
-            for chunk in r.iter_content(chunk_size=8192):
-                f.write(chunk)
-
-
-@torch.no_grad()
-def encode_text(label, model):
-    text_tokens = clip.tokenize(label, truncate=True).to(device)
-    text_encoding = model.encode_text(text_tokens)
-    return text_encoding.cpu()
-
-
-class DiffusionTransformer:
-    def __init__(self, config: LTDConfig):
-        denoiser = config.denoiser.to(config.model_dtype)
-
-        if config.file_url is not None:
-            print(f"Downloading model from {config.file_url}")
-            download_file(config.file_url, config.local_filename)
-            state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
-            denoiser.load_state_dict(state_dict)
-
-        denoiser = denoiser.to(device)
-
-        vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
-
-        self.clip_model, preprocess = clip.load(config.clip_model_name)
-        self.clip_model = self.clip_model.to(device)
-
-        self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
-
-    def generate_image_from_text(
-        self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
-    ):
-        nrow = int(np.sqrt(num_imgs))
-
-        cur_prompts = [prompt] * num_imgs
-        labels = encode_text(cur_prompts, self.clip_model)
-        out, out_latent = self.diffuser.generate(
-            labels=labels,
-            num_imgs=num_imgs,
-            class_guidance=class_guidance,
-            seed=seed,
-            n_iter=n_iter,
-            exponent=1,
-            scale_factor=8,
-            sharp_f=0,
-            bright_f=0,
-        )
-
-        out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
-        return out
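For reference, the schedule quantities that the deleted generate method computes can be reproduced on their own. This is a minimal standalone sketch; the n_iter and exponent values are example inputs, not values from the commit.

# Standalone sketch of the schedule math inside the deleted DiffusionGenerator.generate.
import numpy as np
import torch

n_iter, exponent = 15, 1.0
noise_levels = (1 - torch.pow(torch.arange(0, 1, 1 / n_iter), exponent)).tolist()
noise_levels[0] = 0.99  # first step starts just below pure noise

# log-SNR terms used by the DPM-Solver++(2M) branch:
lambdas = [np.log((1 - s) / s) for s in noise_levels]
hs = [lambdas[i] - lambdas[i - 1] for i in range(1, len(lambdas))]
rs = [hs[i - 1] / hs[i] for i in range(1, len(hs))]

print([round(s, 3) for s in noise_levels[:4]])  # decreasing noise levels
print([round(r, 3) for r in rs[:3]])            # step-size ratios used in the second-order update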
tld/gen_img.py
DELETED
@@ -1,41 +0,0 @@
-import io
-import asyncio
-import os
-
-import torch
-import torchvision.transforms as transforms
-
-from diffusion import DiffusionTransformer, LTDConfig
-
-# Get the directory of the script
-script_directory = os.path.dirname(os.path.realpath(__file__))
-# Specify the directory where the cache will be stored (same folder as the script)
-cache_directory = os.path.join(script_directory, "cache")
-home_directory = os.path.join(script_directory, "home")
-# Create the cache directory if it doesn't exist
-os.makedirs(cache_directory, exist_ok=True)
-os.makedirs(home_directory, exist_ok=True)
-
-os.environ["TRANSFORMERS_CACHE"] = cache_directory
-os.environ["HF_HOME"] = home_directory
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-to_pil = transforms.ToPILImage()
-ltdconfig = LTDConfig()
-diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
-
-async def generate_image(prompt):
-    try:
-        img = diffusion_transformer.generate_image_from_text(
-            prompt=prompt,
-            class_guidance=6,
-            seed=11,
-            num_imgs=1,
-            img_size=32,
-        )
-
-        img.save("generated_img.png")
-    except Exception as e:
-        print(e)
-
-asyncio.run(generate_image("a cute cat"))
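Note that the async wrapper above awaits nothing, so a plain synchronous call behaves the same. A minimal sketch, assuming the pre-deletion package is importable under tld.diffusion (the import path, prompt, and output filename are assumptions/examples):

# Synchronous sketch of the same flow as the deleted gen_img.py.
from tld.diffusion import DiffusionTransformer, LTDConfig

dt = DiffusionTransformer(LTDConfig())  # loads denoiser, VAE, and CLIP
img = dt.generate_image_from_text(prompt="a cute cat", class_guidance=6, seed=11, num_imgs=1)
img.save("generated_img.png")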
tld/gradio_app.py
DELETED
@@ -1,40 +0,0 @@
-import os
-from io import BytesIO
-
-import gradio as gr
-import requests
-from PIL import Image
-
-# runpod_id = os.environ['RUNPOD_ID']
-# token_id = os.environ['AUTH_TOKEN']
-# url = f'https://{runpod_id}-8000.proxy.runpod.net/generate-image/'
-
-url = os.getenv("API_URL")
-token_id = os.getenv("API_TOKEN")
-
-
-def generate_image_from_text(prompt, class_guidance):
-    headers = {"Authorization": f"Bearer {token_id}"}
-
-    data = {"prompt": prompt, "class_guidance": class_guidance, "seed": 11, "num_imgs": 1, "img_size": 32}
-
-    response = requests.post(url, json=data, headers=headers)
-
-    if response.status_code == 200:
-        image = Image.open(BytesIO(response.content))
-    else:
-        print("Failed to fetch image:", response.status_code, response.text)
-
-    return image
-
-
-iface = gr.Interface(
-    fn=generate_image_from_text,
-    inputs=["text", "slider"],
-    outputs="image",
-    title="Text-to-Image Generator",
-    description="Enter a text prompt to generate an image.",
-)
-
-# Launch the app
-iface.launch()
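The deleted app uses the "slider" shorthand, which gives a slider with Gradio's default range. A sketch of the same interface with an explicit, labelled slider, reusing generate_image_from_text from the file above; the 1-12 range and default of 6 are assumptions, not values from the commit:

# Sketch: explicit input components instead of the "text"/"slider" shorthands.
import gradio as gr

iface = gr.Interface(
    fn=generate_image_from_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=12, value=6, label="Class guidance"),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image.",
)
iface.launch()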
tld/train.py
DELETED
@@ -1,208 +0,0 @@
-#!/usr/bin/env python3
-
-import copy
-from dataclasses import asdict, dataclass
-
-import numpy as np
-import torch
-import torchvision
-import torchvision.utils as vutils
-import wandb
-from accelerate import Accelerator
-from diffusers import AutoencoderKL
-from PIL.Image import Image
-from torch import Tensor, nn
-from torch.utils.data import DataLoader, TensorDataset
-from tqdm import tqdm
-
-from denoiser import Denoiser
-from diffusion import DiffusionGenerator
-
-
-def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image:
-    class_guidance = 4.5
-    seed = 10
-    out, _ = diffuser.generate(
-        labels=torch.repeat_interleave(labels, 8, dim=0),
-        num_imgs=64,
-        class_guidance=class_guidance,
-        seed=seed,
-        n_iter=40,
-        exponent=1,
-        sharp_f=0.1,
-    )
-
-    out = to_pil((vutils.make_grid((out + 1) / 2, nrow=8, padding=4)).float().clip(0, 1))
-    out.save(f"emb_val_cfg:{class_guidance}_seed:{seed}.png")
-
-    return out
-
-
-def count_parameters(model: nn.Module):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def count_parameters_per_layer(model: nn.Module):
-    for name, param in model.named_parameters():
-        print(f"{name}: {param.numel()} parameters")
-
-
-to_pil = torchvision.transforms.ToPILImage()
-
-
-def update_ema(ema_model: nn.Module, model: nn.Module, alpha: float = 0.999):
-    with torch.no_grad():
-        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
-            ema_param.data.mul_(alpha).add_(model_param.data, alpha=1 - alpha)
-
-
-@dataclass
-class ModelConfig:
-    embed_dim: int = 512
-    n_layers: int = 6
-    clip_embed_size: int = 768
-    scaling_factor: int = 8
-    patch_size: int = 2
-    image_size: int = 32
-    n_channels: int = 4
-    dropout: float = 0
-    mlp_multiplier: int = 4
-    batch_size: int = 128
-    class_guidance: int = 3
-    lr: float = 3e-4
-    n_epoch: int = 100
-    alpha: float = 0.999
-    noise_embed_dims: int = 128
-    diffusion_n_iter: int = 35
-    from_scratch: bool = True
-    run_id: str = ""
-    model_name: str = ""
-    beta_a: float = 0.75
-    beta_b: float = 0.75
-    save_and_eval_every_iters: int = 1000
-
-
-@dataclass
-class DataConfig:
-    latent_path: str  # path to a numpy file containing latents
-    text_emb_path: str
-    val_path: str
-
-
-def main(config: ModelConfig, dataconfig: DataConfig) -> None:
-    """main train loop to be used with accelerate"""
-
-    accelerator = Accelerator(mixed_precision="fp16", log_with="wandb")
-
-    accelerator.print("Loading Data:")
-    latent_train_data = torch.tensor(np.load(dataconfig.latent_path), dtype=torch.float32)
-    train_label_embeddings = torch.tensor(np.load(dataconfig.text_emb_path), dtype=torch.float32)
-    emb_val = torch.tensor(np.load(dataconfig.val_path), dtype=torch.float32)
-    dataset = TensorDataset(latent_train_data, train_label_embeddings)
-    train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
-
-    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
-
-    if accelerator.is_main_process:
-        vae = vae.to(accelerator.device)
-
-    model = Denoiser(
-        image_size=config.image_size,
-        noise_embed_dims=config.noise_embed_dims,
-        patch_size=config.patch_size,
-        embed_dim=config.embed_dim,
-        dropout=config.dropout,
-        n_layers=config.n_layers,
-    )
-
-    loss_fn = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
-
-    accelerator.print("Compiling model:")
-    model = torch.compile(model)
-
-    if not config.from_scratch:
-        accelerator.print("Loading Model:")
-        wandb.restore(
-            config.model_name, run_path=f"apapiu/cifar_diffusion/runs/{config.run_id}", replace=True
-        )
-        full_state_dict = torch.load(config.model_name)
-        model.load_state_dict(full_state_dict["model_ema"])
-        optimizer.load_state_dict(full_state_dict["opt_state"])
-        global_step = full_state_dict["global_step"]
-    else:
-        global_step = 0
-
-    if accelerator.is_local_main_process:
-        ema_model = copy.deepcopy(model).to(accelerator.device)
-        diffuser = DiffusionGenerator(ema_model, vae, accelerator.device, torch.float32)
-
-    accelerator.print("model prep")
-    model, train_loader, optimizer = accelerator.prepare(model, train_loader, optimizer)
-
-    accelerator.init_trackers(project_name="cifar_diffusion", config=asdict(config))
-
-    accelerator.print(count_parameters(model))
-    accelerator.print(count_parameters_per_layer(model))
-
-    ### Train:
-    for i in range(1, config.n_epoch + 1):
-        accelerator.print(f"epoch: {i}")
-
-        for x, y in tqdm(train_loader):
-            x = x / config.scaling_factor
-
-            noise_level = torch.tensor(
-                np.random.beta(config.beta_a, config.beta_b, len(x)), device=accelerator.device
-            )
-            signal_level = 1 - noise_level
-            noise = torch.randn_like(x)
-
-            x_noisy = noise_level.view(-1, 1, 1, 1) * noise + signal_level.view(-1, 1, 1, 1) * x
-
-            x_noisy = x_noisy.float()
-            noise_level = noise_level.float()
-            label = y
-
-            prob = 0.15
-            mask = torch.rand(y.size(0), device=accelerator.device) < prob
-            label[mask] = 0  # OR replacement_vector
-
-            if global_step % config.save_and_eval_every_iters == 0:
-                accelerator.wait_for_everyone()
-                if accelerator.is_main_process:
-                    ##eval and saving:
-                    out = eval_gen(diffuser=diffuser, labels=emb_val)
-                    out.save("img.jpg")
-                    accelerator.log({f"step: {global_step}": wandb.Image("img.jpg")})
-
-                    opt_unwrapped = accelerator.unwrap_model(optimizer)
-                    full_state_dict = {
-                        "model_ema": ema_model.state_dict(),
-                        "opt_state": opt_unwrapped.state_dict(),
-                        "global_step": global_step,
-                    }
-                    accelerator.save(full_state_dict, config.model_name)
-                    wandb.save(config.model_name)
-
-            model.train()
-
-            with accelerator.accumulate():
-                ###train loop:
-                optimizer.zero_grad()
-
-                pred = model(x_noisy, noise_level.view(-1, 1), label)
-                loss = loss_fn(pred, x)
-                accelerator.log({"train_loss": loss.item()}, step=global_step)
-                accelerator.backward(loss)
-                optimizer.step()
-
-                if accelerator.is_main_process:
-                    update_ema(ema_model, model, alpha=config.alpha)
-
-            global_step += 1
-    accelerator.end_training()
-
-
-# args = (config, data_path, val_path)
-# notebook_launcher(training_loop)
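The commented-out lines at the bottom of the deleted file hint at launching main with accelerate's notebook_launcher. A hedged sketch of such a launch, assuming the pre-deletion package imports as tld.train; the file paths and the reduced config values are placeholders, not values from the commit:

# Sketch: launching the deleted training loop with accelerate's notebook_launcher.
from accelerate import notebook_launcher

from tld.train import DataConfig, ModelConfig, main

config = ModelConfig(embed_dim=256, n_layers=4, n_epoch=2)  # small example values
dataconfig = DataConfig(
    latent_path="train_latents.npy",    # placeholder paths to precomputed latents
    text_emb_path="train_text_emb.npy",
    val_path="val_text_emb.npy",
)

notebook_launcher(main, args=(config, dataconfig), num_processes=1)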
tld/transformer_blocks.py
DELETED
@@ -1,139 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-
-
-class SinusoidalEmbedding(nn.Module):
-    def __init__(self, emb_min_freq=1.0, emb_max_freq=1000.0, embedding_dims=32):
-        super(SinusoidalEmbedding, self).__init__()
-
-        frequencies = torch.exp(
-            torch.linspace(np.log(emb_min_freq), np.log(emb_max_freq), embedding_dims // 2)
-        )
-
-        self.register_buffer("angular_speeds", 2.0 * torch.pi * frequencies)
-
-    def forward(self, x):
-        embeddings = torch.cat(
-            [torch.sin(self.angular_speeds * x), torch.cos(self.angular_speeds * x)], dim=-1
-        )
-        return embeddings
-
-
-class MHAttention(nn.Module):
-    def __init__(self, is_causal=False, dropout_level=0.0, n_heads=4):
-        super().__init__()
-        self.is_causal = is_causal
-        self.dropout_level = dropout_level
-        self.n_heads = n_heads
-
-    def forward(self, q, k, v, attn_mask=None):
-        assert q.size(-1) == k.size(-1)
-        assert k.size(-2) == v.size(-2)
-
-        q, k, v = [rearrange(x, "bs n (h d) -> bs h n d", h=self.n_heads) for x in [q, k, v]]
-
-        out = nn.functional.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask,
-            is_causal=self.is_causal,
-            dropout_p=self.dropout_level if self.training else 0,
-        )
-
-        out = rearrange(out, "bs h n d -> bs n (h d)", h=self.n_heads)
-
-        return out
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
-        super().__init__()
-        self.qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
-        self.mha = MHAttention(is_causal, dropout_level, n_heads)
-
-    def forward(self, x):
-        q, k, v = self.qkv_linear(x).chunk(3, dim=2)
-        return self.mha(q, k, v)
-
-
-class CrossAttention(nn.Module):
-    def __init__(self, embed_dim, is_causal=False, dropout_level=0, n_heads=4):
-        super().__init__()
-        self.kv_linear = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
-        self.q_linear = nn.Linear(embed_dim, embed_dim, bias=False)
-        self.mha = MHAttention(is_causal, dropout_level, n_heads)
-
-    def forward(self, x, y):
-        q = self.q_linear(x)
-        k, v = self.kv_linear(y).chunk(2, dim=2)
-        return self.mha(q, k, v)
-
-
-class MLP(nn.Module):
-    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(embed_dim, mlp_multiplier * embed_dim),
-            nn.GELU(),
-            nn.Linear(mlp_multiplier * embed_dim, embed_dim),
-            nn.Dropout(dropout_level),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-class MLPSepConv(nn.Module):
-    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
-        """see: https://github.com/ofsoundof/LocalViT"""
-        super().__init__()
-        self.mlp = nn.Sequential(
-            # this Conv with kernel size 1 is equivalent to the Linear layer in a "regular" transformer MLP
-            nn.Conv2d(embed_dim, mlp_multiplier * embed_dim, kernel_size=1, padding="same"),
-            nn.Conv2d(
-                mlp_multiplier * embed_dim,
-                mlp_multiplier * embed_dim,
-                kernel_size=3,
-                padding="same",
-                groups=mlp_multiplier * embed_dim,
-            ),  # <- depthwise conv
-            nn.GELU(),
-            nn.Conv2d(mlp_multiplier * embed_dim, embed_dim, kernel_size=1, padding="same"),
-            nn.Dropout(dropout_level),
-        )
-
-    def forward(self, x):
-        w = h = int(np.sqrt(x.size(1)))  # only square images for now
-        x = rearrange(x, "bs (h w) d -> bs d h w", h=h, w=w)
-        x = self.mlp(x)
-        x = rearrange(x, "bs d h w -> bs (h w) d")
-        return x
-
-
-class DecoderBlock(nn.Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        is_causal: bool,
-        mlp_multiplier: int,
-        dropout_level: float,
-        mlp_class: type[MLP] | type[MLPSepConv],
-    ):
-        super().__init__()
-        self.self_attention = SelfAttention(embed_dim, is_causal, dropout_level, n_heads=embed_dim // 64)
-        self.cross_attention = CrossAttention(
-            embed_dim, is_causal=False, dropout_level=0, n_heads=embed_dim // 64
-        )
-        self.mlp = mlp_class(embed_dim, mlp_multiplier, dropout_level)
-        self.norm1 = nn.LayerNorm(embed_dim)
-        self.norm2 = nn.LayerNorm(embed_dim)
-        self.norm3 = nn.LayerNorm(embed_dim)
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        x = self.self_attention(self.norm1(x)) + x
-        x = self.cross_attention(self.norm2(x), y) + x
-        x = self.mlp(self.norm3(x)) + x
-        return x
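A quick shape-check sketch for the deleted DecoderBlock; the import path assumes the pre-deletion tld package, and the batch, token, and context sizes are arbitrary example values:

# Shape-check: self-attention over 16x16 image tokens, cross-attention over a context sequence.
import torch

from tld.transformer_blocks import MLP, DecoderBlock

block = DecoderBlock(embed_dim=256, is_causal=False, mlp_multiplier=4, dropout_level=0.0, mlp_class=MLP)
x = torch.randn(2, 16 * 16, 256)  # (batch, image tokens, embed_dim)
y = torch.randn(2, 77, 256)       # cross-attention context, already projected to embed_dim
out = block(x, y)
print(out.shape)                  # torch.Size([2, 256, 256])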