BeveledCube committed
Commit f977a8b · 1 Parent(s): b38b153

Fuc complicated stuff

.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,4 +0,0 @@
1
- __pycache__
2
- node_modules
3
- poetry.lock
4
- cache
Dockerfile.fastapi DELETED
@@ -1,14 +0,0 @@
1
- FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y git
6
-
7
- COPY . /app
8
-
9
- RUN pip install --no-cache-dir -r requirements.txt
10
- RUN pip install --no-cache-dir uvicorn gunicorn fastapi pytest ruff pytest-asyncio httpx
11
-
12
- EXPOSE 80
13
-
14
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
Dockerfile.gradio DELETED
@@ -1,13 +0,0 @@
1
- FROM python:3.10-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y git
6
-
7
- COPY . /app
8
-
9
- RUN pip install --no-cache-dir gradio Pillow
10
-
11
- EXPOSE 80
12
-
13
- CMD ["python", "tld/gradio_app.py"]
LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Alexandru Papiu
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
README.md CHANGED
@@ -8,4 +8,5 @@ pinned: false
 app_file: main.py
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[Other GH](https://github.com/apapiu/transformer_latent_diffusion)
docker-compose.yml DELETED
@@ -1,17 +0,0 @@
1
- version: '3.8'
2
- services:
3
- fastapi:
4
- image: apapiu89/tld-app:latest
5
- ports:
6
- - "80:80"
7
- environment:
8
- - API_TOKEN=${API_TOKEN}
9
-
10
- gradio:
11
- image: apapiu89/gradio-app:latest
12
- ports:
13
- - "7860:7860"
14
- environment:
15
- - API_URL=http://fastapi:80
16
- depends_on:
17
- - fastapi
img_examples/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png DELETED
Binary file (404 kB)
 
img_examples/a cute grey great owl_cfg_8_seed_11.png DELETED
Binary file (477 kB)
 
img_examples/a lake in mountains in the fall at sunset_cfg_7_seed_11.png DELETED
Binary file (419 kB)
 
img_examples/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png DELETED
Binary file (429 kB)
 
img_examples/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png DELETED
Binary file (502 kB)
 
img_examples/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png DELETED
Binary file (481 kB)
 
img_examples/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png DELETED
Binary file (462 kB)
 
img_examples/painting of a cyberpunk market_cfg_7_seed_11.png DELETED
Binary file (479 kB)
 
img_examples/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png DELETED
Binary file (361 kB)
 
main.py ADDED
@@ -0,0 +1,17 @@
1
+ from diffusers import StableDiffusionPipeline
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+
5
+ # Find models in https://huggingface.co/models?pipeline_tag=text-to-image&library=diffusers&sort=trending
6
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
7
+
8
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
9
+ pipe = pipe.to("cuda")
10
+
11
+ prompt = "beautiful horse"
12
+
13
+ image = pipe(prompt).images[0]
14
+
15
+ print("[PROMPT]: ", prompt)
16
+ plt.imshow(image)
17
+ plt.axis('off')
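The added main.py renders the result with matplotlib but never writes it to disk, and start.sh (below) now runs `python main.py` in what is presumably a headless container. A minimal sketch of a headless-friendly variant follows; it is not part of the commit, the output filename is illustrative, and the generic `DiffusionPipeline` loader is used here because it auto-selects the appropriate pipeline class for an SDXL checkpoint:

```python
# Hedged sketch, not part of this commit: a headless-friendly variant of main.py.
# Assumes a CUDA device is available; "generated.png" is an illustrative filename.
import torch
from diffusers import DiffusionPipeline  # generic loader that picks the right pipeline class

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "beautiful horse"
image = pipe(prompt).images[0]  # diffusers returns a PIL image

print("[PROMPT]:", prompt)
image.save("generated.png")     # persist the result instead of relying on plt.imshow
```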
og readme.md DELETED
@@ -1,211 +0,0 @@
1
- # Transformer Latent Diffusion
2
- Text to Image Latent Diffusion using a Transformer core in PyTorch.
3
-
4
- [Original Github](https://github.com/apapiu/transformer_latent_diffusion)
5
-
6
- **Try with own inputs**: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VaCe01YG9rnPwAfwVLBKdXEX7D_tk1U5?usp=sharing)
7
-
8
- Below are some random examples (at 256 resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on 1 A100):
9
-
10
- <img width="760" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/e01e3094-2487-4c04-bc0f-d9b03eeaed00">
11
-
12
- #### Clip interpolation Examples:
13
-
14
- a photo of a cat → an anime drawing of a super saiyan cat, artstation:
15
-
16
- <img width="1361" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/a079458b-9bd5-4557-aa7a-5a3e78f31b53">
17
-
18
- a cute great gray owl → starry night by van gogh:
19
-
20
- <img width="1399" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/8731d87a-89fa-43a2-847d-c7ff772de286">
21
-
22
- Note that the model has not converged yet and could use more training.
23
-
24
- #### High(er) Resolution:
25
- By upsampling the positional encoding the model can also generate 512 or 1024 px images with minimal fine-tuning. See below for some examples from the model fine-tuned on 100k extra 512 px images and 30k 1024 px images for about 2 hours on an A100. The images sometimes lack global coherence at 1024 px - more to come here:
26
-
27
- <img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/adba64f0-b43c-423e-9a7d-033a4afea207">
28
- <img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/5a94515b-313e-420d-89d4-6bdc376d9a00">
29
-
30
-
31
-
32
- ### Intro:
33
-
34
- The main goal of this project is to build an accessible diffusion model in PyTorch that is:
35
- - fast (close to real time generation)
36
- - small (~100MM params)
37
- - reasonably good (of course not SOTA)
38
- - can be trained in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent).
39
- simple self-contained codebase (model + train loop is about ~400 lines of PyTorch with few dependencies)
40
- - uses ~ 1 million images with a focus on data quality over quantity
41
-
42
- This is part II of a previous [project](https://github.com/apapiu/guided-diffusion-keras) I did where I trained a pixel level diffusion model in Keras. Even though this model outputs 4x higher resolution images (256px vs 64px), it's actually faster to both train and sample from, which shows the power of training in the latent space and speed of transformer architectures.
43
-
44
- ## Table of Contents:
45
- - [Codebase](#codebase)
46
- - [Usage](#usage)
47
- - [Examples](#examples)
48
- - [Data Processing](#data-processing)
49
- - [Architecture](#architecture)
50
- - [TO-DOs](#todos)
51
-
52
-
53
- ## Codebase:
54
- The code is written in pure PyTorch with as few dependencies as possible.
55
-
56
- - [transformer_blocks.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/transformer_blocks.py) - basic transformer building blocks relevant to the transformer denoiser
57
- - [denoiser.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) - the architecture of the denoiser transformer
58
- - [train.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/train.py). The train loop uses `accelerate` so its training can scale to multiple GPUs if needed.
59
- - [diffusion.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/diffusion.py). Class to generate image from noise using reverse diffusion. Short (~60 lines) and self-contained.
60
- - [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py). Data utils to download images/text and process necessary features for the diffusion model.
61
-
62
- ### Usage:
63
- If you have your own dataset of URLs + captions, the process to train a model on the data consists of two steps:
64
-
65
- 1. Use `train.download_and_process_data` to obtain the latent and text encodings as numpy files. See [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BPDFDBdsP9SSKBNEFJysmlBjfoxKK13r?usp=sharing) for a notebook example downloading and processing 2000 images from this HuggingFace [dataset](https://huggingface.co/datasets/zzliang/GRIT).
66
-
67
- 2. use the `train.main` function in an accelerate `notebook_launcher` - see [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sKk0usxEF4bmdCDcNQJQNMt4l9qBOeAM?usp=sharing) for a colab notebook that trains a model on 100k images from scratch. Note that this downloads already pre-preprocessed latents and embeddings from [here](https://huggingface.co/apapiu/small_ldt/tree/main) but you could just use whatever `.npy` files you had saved from step 1.
68
-
69
- #### Fine-Tuning - TODO, but it is the same as step 2 above except that you start from a pre-trained model.
70
-
71
- ```python
72
- !wandb login
73
- import os
74
- from train import main, DataConfig, ModelConfig
75
- from accelerate import notebook_launcher
76
-
77
- data_config = DataConfig(latent_path='path/to/image_latents.npy',
78
- text_emb_path='path/to/text_encodings.npy',
79
- val_path='path/to/val_encodings.npy')
80
-
81
- model_config = ModelConfig(embed_dim=512, n_layers=6) #see ModelConfig for more params
82
-
83
- #run the training process on 2 GPUs:
84
- notebook_launcher(main, (model_config, data_config), num_processes=2)
85
- ```
86
-
87
- ### Dependencies:
88
- - `PyTorch` `numpy` `einops` for model building
89
- - `wandb` `tqdm` for logging + progress bars
90
- - `accelerate` for train loop and multi-GPU support
91
- - `img2dataset` `webdataset` `torchvision` for data downloading and image processing
92
- - `diffusers` `clip` for pretrained VAE and CLIP text model
93
-
94
- ### Codebases used for inspiration:
95
- - [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha)
96
- - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
97
- - [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master)
98
- - [LocalViT](https://github.com/ofsoundof/LocalViT)
99
-
100
- #### Speed:
101
-
102
- I try to speed up training and inference as much as possible by:
103
- - using mixed precision for training + [sdpa]
104
- - precompute all latent and text embeddings
105
- - using float16 precision for inference
106
- - using [sdpa] for the attention natively + torch.compile() (compile doesn't always work).
107
- - use a highly performant sampler (DPM-Solver++(2M)) that gets good results in ~ 15 steps.
108
- - TODO: would distillation or something like LCM work here?
109
- - TODO: use flash-attention2?
110
- - TODO: use smaller vae?
111
-
112
- The time to generate a batch of 36 images (15 iterations) on a:
113
- - T4: ~ 3.5 seconds
114
- - A100: ~ 0.6 seconds
115
- In fact on an A100 the vae becomes the bottleneck even though it is only used once.
116
-
117
-
118
- ## Examples:
119
-
120
- More examples generated with the 100MM model - click the photo to see the prompt and other params like cfg and seed:
121
- ![image](tld/img_examples/a%20cute%20grey%20great%20owl_cfg_8_seed_11.png)
122
- ![image](tld/img_examples/watercolor%20of%20a%20cute%20cat%20riding%20a%20motorcycle_cfg_7_seed_11.png)
123
- ![image](tld/img_examples/painting%20of%20a%20cyberpunk%20market_cfg_7_seed_11.png)
124
- ![image](tld/img_examples/isometric%20view%20of%20small%20japanese%20village%20with%20blooming%20trees_cfg_7_seed_11.png)
125
- ![image](tld/img_examples/a%20beautiful%20woman%20with%20blonde%20hair%20in%20her%2050s_cfg_7_seed_11.png)
126
- ![image](tld/img_examples/painting%20of%20a%20cute%20fox%20in%20a%20suit%20in%20a%20field%20of%20poppies_cfg_8_seed_11.png)
127
- ![image](tld/img_examples/an%20aerial%20view%20of%20manhattan%2C%20isometric%20view%2C%20as%20pantinted%20by%20mondrian_cfg_7_seed_11.png)
128
-
129
- ## Outpainting model:
130
-
131
- I also fine-tuned an outpainting model on top of the original 101MM model. I had to modify the original input conv2d patch to 8 channels and initialize the mask channel parameters to zero. The rest of the architecture remained the same.
132
-
133
- Below I apply the outpainting model repeatedly to generate a somewhat consistent scene based on the prompt "a cyberpunk marketplace":
134
-
135
- <img width="1440" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/4451719f-d45a-4a86-a7bb-06c021b34996">
136
-
137
- ## Data Processing:
138
-
139
- In [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py), I have some helper functions to process images and captions. The flow is as follows:
140
- - Use `img2dataset` to download images from a dataframe containing URLs and captions.
141
- Use `CLIP` to encode the prompts and the `VAE` to encode images to latents on a `webdataset` data generator.
142
- - Save the latents and text embedding for future training.
143
-
144
- There are two advantages to this approach. One is that the VAE encoding is somewhat expensive, so doing it every epoch would affect training times. The other is that we can discard the images after processing. For `3*256*256` images, the latent dimension is `4*32*32`, so every latent is around 4KB (when quantized in uint8; see [here](https://pub.towardsai.net/stable-diffusion-based-image-compresssion-6f1f0a399202?gi=1f45c6522d3b)). This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have been 48x larger in size.
145
-
146
- ## Architecture:
147
-
148
- See [here](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) for the denoiser class.
149
-
150
- The denoiser model is a Transformer based on the architecture in [DiT](https://arxiv.org/abs/2203.02378) and [Pixart-Alpha](https://pixart-alpha.github.io/), albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser is different from most diffusion models, which use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. One was that I just wanted to experiment and learn how to build and train Transformers from the ground up. Secondly, Transformers are fast both to train and to do inference on, and they will benefit most from future advances (both in hardware and in software) in performance.
151
-
152
- Transformers are not natively built for spatial data, and at first I found a lot of the outputs to be very "patchy". To remedy that I added a depth-wise convolution in the FFN layer of the transformer (this was introduced in the [Local ViT](https://arxiv.org/abs/2104.05707) paper). This allows the model to mix pixels that are close to each other with very little added compute cost.
153
-
154
-
155
- ### Img+Text+Noise Encoding:
156
-
157
- The image latent inputs are `4*32*32` and we use a patch size of 2 to build 256 flattened `4*2*2=16` dimensional input "pixels". These are then projected into the embedding dimension and fed through the transformer blocks.
158
-
159
- The text and noise conditioning is very simple - we concatenate a pooled CLIP text embedding (`ViT/L14` - 768-dimensional) and the sinusoidal noise embedding and feed it as input in the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
160
-
161
- ### Training:
162
- The base model has 101MM parameters, 12 layers, and embedding dimension = 768. I train it with a batch size of 256 on an A100 and a learning rate of `3e-4`. I used 1000 steps for warmup. Due to computational constraints I did not do any ablations for this configuration.
163
-
164
-
165
- ## Train and Diffusion Setup:
166
-
167
- We train a denoising transformer that takes the following three inputs:
168
- - `noise_level` (sampled from 0 to 1 with more values concentrated close to 0 - I use a beta distribution)
169
- - Image latent (x) corrupted with a level of random noise
170
- - For a given `noise_level` between 0 and 1, the corruption is as follows:
171
- - `x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal(0, 1)`
172
- - CLIP embeddings of a text prompt
173
- - You can think of this as a numerical representation of a text prompt.
174
- - We use the pooled text embedding here (768 dimensional for `ViT/L14`)
175
-
176
- The output is a prediction of the denoised image latent - call it `f(x_noisy)`.
177
-
178
- The model is trained to minimize the mean squared error `|f(x_noisy) - x|` between the prediction and actual image
179
- (you can also use absolute error here). Note that I don't reparameterize the loss in terms of the noise here to keep things simple.
180
-
181
- Using this model, we then iteratively generate an image from random noise as follows:
182
-
183
- for i in range(len(self.noise_levels) - 1):
184
-
185
- curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
186
-
187
- # Predict original denoised image:
188
- x0_pred = predict_x_zero(new_img, label, curr_noise)
189
-
190
- # New image at next_noise level is a weighted average of old image and predicted x0:
191
- new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
192
-
193
- The `predict_x_zero` method uses classifier free guidance by combining the conditional and unconditional
194
- prediction: `x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional`
195
-
196
- A bit of math: the approach above falls within the VDM parametrization; see Section 3.1 in [Kingma et al.](https://arxiv.org/pdf/2107.00630.pdf):
197
-
198
- $$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$
199
-
200
- Where $z_t$ is the noisy version of $x$ at time $t$.
201
-
202
- Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here, I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? For one, it simplifies the updating equation quite a bit, and it's easier to understand what the noise to signal ratio will look like. I also found that the model produces sharper images faster - more validation here is needed. The updating equation above is the DDIM model for this parametrization, which simplifies to a simple weighted average. Note that the DDIM model deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to achieve decent image quality.
203
-
204
- ## TODOS:
205
- - better config in the train file
206
- - how to speed up generation even more - LCMs or other sampling strategies?
207
- - add script to compute FID
208
-
209
-
210
-
211
-
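The deleted readme above specifies the corruption rule `x_noisy = x*(1-noise_level) + eps*noise_level` and a DDIM-style update that is a weighted average of the current latent and the predicted x0. As a quick illustration only (not part of the repository; the noise levels and the stand-in `x0_pred` are made up), the two formulas look like this in PyTorch:

```python
# Illustrative sketch of the corruption and update rule described in the deleted readme.
import torch

x = torch.randn(1, 4, 32, 32)      # clean image latent
eps = torch.randn_like(x)          # standard normal noise

noise_level = 0.7
x_noisy = x * (1 - noise_level) + eps * noise_level   # alpha_t = 1 - sigma_t parametrization

# One reverse-diffusion step from curr_noise to next_noise, given a denoiser prediction x0_pred:
curr_noise, next_noise = 0.7, 0.5
x0_pred = x                        # stand-in for the model output f(x_noisy)
x_next = ((curr_noise - next_noise) * x0_pred + next_noise * x_noisy) / curr_noise
```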
old/main.py DELETED
@@ -1,37 +0,0 @@
1
- import os
2
- from transformers import CLIPProcessor, CLIPModel
3
- from PIL import Image
4
-
5
- # Get the directory of the script
6
- script_directory = os.path.dirname(os.path.realpath(__file__))
7
- # Specify the directory where the cache will be stored (same folder as the script)
8
- cache_directory = os.path.join(script_directory, "cache")
9
- # Create the cache directory if it doesn't exist
10
- os.makedirs(cache_directory, exist_ok=True)
11
-
12
- # Load the CLIP processor and model
13
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
14
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
15
-
16
- # Text description to generate image
17
- text = "a cat sitting on a table"
18
-
19
- # Tokenize text and get features
20
- inputs = clip_processor(text, return_tensors="pt", padding=True)
21
-
22
- # Generate image from text
23
- generated_image = clip_model.generate(
24
- inputs=inputs.input_ids,
25
- attention_mask=inputs.attention_mask,
26
- visual_input=None, # We don't provide image inputvi
27
- return_tensors="pt" # Return PyTorch tensor
28
- )
29
-
30
- # Convert the generated image tensor to a NumPy array
31
- generated_image_np = generated_image[0].cpu().numpy()
32
-
33
- # Save the generated image
34
- output_image_path = "generated_image.png"
35
- Image.fromarray(generated_image_np).save(output_image_path)
36
-
37
- print("Image generated and saved as:", output_image_path)
pyproject.toml DELETED
@@ -1,23 +0,0 @@
1
- [tool.poetry]
2
- name = "img-gen"
3
- version = "0.0.1"
4
- description = ""
5
- authors = ["CubeBeveled <[email protected]>"]
6
- readme = "README.md"
7
-
8
- [tool.poetry.dependencies]
9
- python = "^3.11"
10
- torch = "^2.2.1"
11
- numpy = "^1.26.4"
12
- einops = "^0.7.0"
13
- torchvision = "^0.17.1"
14
- tqdm = "^4.66.2"
15
- diffusers = "^0.27.2"
16
- accelerate = "^0.28.0"
17
- transformers = "^4.39.1"
18
- pillow = "^10.2.0"
19
-
20
-
21
- [build-system]
22
- requires = ["poetry-core"]
23
- build-backend = "poetry.core.masonry.api"
requirements.txt CHANGED
@@ -1,11 +1,2 @@
-torch
-numpy
-einops
-torchvision
-tqdm
 diffusers
-accelerate
-transformers
-Pillow
-poetry
-git+https://github.com/openai/CLIP.git
+transformers
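After this change requirements.txt lists only diffusers and transformers, while the new main.py also imports torch and matplotlib, so it relies on the base image bundling those packages. A small hedged check (the package list simply mirrors main.py's imports) could be run inside the container to confirm:

```python
# Hedged sketch: verify the packages main.py relies on are actually installed.
import importlib.util

for pkg in ("torch", "matplotlib", "diffusers", "transformers"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'found' if found else 'MISSING'}")
```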
setup.py DELETED
@@ -1,36 +0,0 @@
1
- from setuptools import setup, find_packages
2
-
3
-
4
- def load_requirements(filename="requirements.txt"):
5
- with open(filename, "r") as file:
6
- lines = [line.strip() for line in file.readlines() if line.strip() and not line.startswith("#")]
7
- return lines
8
-
9
-
10
- setup(
11
- name="tld",
12
- version="0.1.0",
13
- author="Alexandru Papiu",
14
- author_email="[email protected]",
15
- description="Transformer Latent Diffusion",
16
- url="https://github.com/apapiu/transformer_latent_diffusion",
17
- packages=find_packages(exclude=["tests*"]),
18
- classifiers=[
19
- "Programming Language :: Python :: 3",
20
- "License :: OSI Approved :: MIT License",
21
- "Operating System :: OS Independent",
22
- ],
23
- python_requires=">=3.6",
24
- install_requires=[
25
- "torch",
26
- "numpy",
27
- "einops",
28
- "torchvision",
29
- "tqdm",
30
- "diffusers",
31
- "accelerate",
32
- "transformers",
33
- "Pillow",
34
- "clip @ git+https://github.com/openai/CLIP.git",
35
- ],
36
- )
start.sh CHANGED
@@ -1 +1 @@
-python tld/gen_img.py
+python main.py
tests/__init__.py DELETED
File without changes
tests/client.js DELETED
@@ -1,15 +0,0 @@
1
- const axios = require("axios");
2
-
3
- const apiUrl = `http://de-fsn-4.halex.gg:25287/api`;
4
-
5
- const postData = {
6
- prompt: "Wassup my homie"
7
- };
8
-
9
- axios.post(apiUrl, postData)
10
- .then(response => {
11
- console.log("Response from API:", response.data);
12
- })
13
- .catch(error => {
14
- console.error("Error:", error.message);
15
- });
tests/test_api.py DELETED
@@ -1,30 +0,0 @@
1
- from fastapi.testclient import TestClient
2
- from app import app
3
- import PIL
4
- from PIL import Image
5
- from io import BytesIO
6
-
7
- client = TestClient(app)
8
-
9
- def test_read_main():
10
- response = client.get("/")
11
- assert response.status_code == 200
12
- assert response.json() == {"message": "Welcome to Image Generator"}
13
-
14
-
15
- def test_generate_image_unauthorized():
16
- response = client.post("/generate-image/", json={})
17
- assert response.status_code == 401
18
- assert response.json() == {"detail": "Not authenticated"}
19
-
20
-
21
- def test_generate_image_authorized():
22
- response = client.post(
23
- "/generate-image/", json={"prompt": "a cute cat"}
24
- )
25
- assert response.status_code == 200
26
-
27
- image = Image.open(BytesIO(response.content))
28
- assert type(image) == PIL.JpegImagePlugin.JpegImageFile
29
-
30
- test_generate_image_authorized()
tests/test_diffuser.py DELETED
@@ -1,99 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
- import time
6
-
7
- import numpy as np
8
- import torch
9
- import torchvision.transforms as transforms
10
- import torchvision.utils as vutils
11
- from diffusers import AutoencoderKL
12
-
13
- from denoiser import Denoiser
14
- from diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
15
- from PIL.Image import Image
16
-
17
- to_pil = transforms.ToPILImage()
18
-
19
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
-
21
-
22
- def test_outputs(num_imgs=4):
23
- model = Denoiser(
24
- image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=768, dropout=0.1, n_layers=12
25
- )
26
- x = torch.rand(num_imgs, 4, 32, 32)
27
- noise_level = torch.rand(num_imgs, 1)
28
- label = torch.rand(num_imgs, 768)
29
-
30
- print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
31
-
32
- with torch.no_grad():
33
- start_time = time.time()
34
- output = model(x, noise_level, label)
35
- end_time = time.time()
36
-
37
- execution_time = end_time - start_time
38
- print(f"Model execution took {execution_time:.4f} seconds.")
39
-
40
- assert output.shape == torch.Size([num_imgs, 4, 32, 32])
41
- print("Basic tests passed.")
42
-
43
- # model = Denoiser(image_size=16, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=6)
44
- # x = torch.rand(8, 4, 32, 32)
45
- # noise_level = torch.rand(8, 1)
46
- # label = torch.rand(8, 768)
47
-
48
- # with torch.no_grad():
49
- # output = model(x, noise_level, label)
50
-
51
- # assert output.shape == torch.Size([8, 4, 32, 32])
52
- # print("Uspscale tests passed.")
53
-
54
-
55
- def test_diffusion_generator():
56
- model_dtype = torch.float32 ##float 16 will not work on cpu
57
- num_imgs = 1
58
- nrow = int(np.sqrt(num_imgs))
59
-
60
- denoiser = Denoiser(
61
- image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=3
62
- )
63
- print(f"Model has {sum(p.numel() for p in denoiser.parameters())} parameters")
64
-
65
- denoiser.to(model_dtype)
66
-
67
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=model_dtype).to(device)
68
-
69
- labels = torch.rand(num_imgs, 768)
70
-
71
- diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
72
-
73
- out, _ = diffuser.generate(
74
- labels=labels,
75
- num_imgs=num_imgs,
76
- class_guidance=3,
77
- seed=1,
78
- n_iter=5,
79
- exponent=1,
80
- scale_factor=8,
81
- sharp_f=0,
82
- bright_f=0,
83
- )
84
-
85
- out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
86
- out.save("test.png")
87
- print("Images generated at test.png")
88
-
89
-
90
- def test_full_generation_pipeline():
91
- ltdconfig = LTDConfig()
92
- diffusion_transformer = DiffusionTransformer(ltdconfig)
93
-
94
- out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
95
- print(out)
96
- assert type(out) == Image
97
-
98
-
99
- # TODO: should add tests for train loop and data processing
tld/__init__.py DELETED
File without changes
tld/app.py DELETED
@@ -1,66 +0,0 @@
1
- import io
2
- import os
3
- from typing import Optional
4
-
5
- import torch
6
- import torchvision.transforms as transforms
7
- from fastapi import FastAPI, HTTPException, status
8
- from fastapi.responses import StreamingResponse
9
- from fastapi.security import OAuth2PasswordBearer
10
- from pydantic import BaseModel
11
-
12
- from diffusion import DiffusionTransformer, LTDConfig
13
-
14
- # Get the directory of the script
15
- script_directory = os.path.dirname(os.path.realpath(__file__))
16
- # Specify the directory where the cache will be stored (same folder as the script)
17
- cache_directory = os.path.join(script_directory, "cache")
18
- home_directory = os.path.join(script_directory, "home")
19
- # Create the cache directory if it doesn't exist
20
- os.makedirs(cache_directory, exist_ok=True)
21
- os.makedirs(home_directory, exist_ok=True)
22
-
23
- os.environ["TRANSFORMERS_CACHE"] = cache_directory
24
- os.environ["HF_HOME"] = home_directory
25
-
26
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
27
- to_pil = transforms.ToPILImage()
28
- ltdconfig = LTDConfig()
29
- diffusion_transformer = DiffusionTransformer(ltdconfig) #Downloads model here
30
- app = FastAPI()
31
-
32
- class ImageRequest(BaseModel):
33
- prompt: str
34
- class_guidance: Optional[int] = 6
35
- seed: Optional[int] = 11
36
- num_imgs: Optional[int] = 1
37
- img_size: Optional[int] = 32
38
-
39
-
40
- @app.get("/")
41
- def read_root():
42
- return {"message": "Welcome to Image Generator"}
43
-
44
-
45
- @app.post("/generate-image/")
46
- async def generate_image(request: ImageRequest):
47
- try:
48
- img = diffusion_transformer.generate_image_from_text(
49
- prompt=request.prompt,
50
- class_guidance=request.class_guidance,
51
- seed=request.seed,
52
- num_imgs=request.num_imgs,
53
- img_size=request.img_size,
54
- )
55
-
56
- # Convert PIL image to byte stream suitable for HTTP response
57
- img_byte_arr = io.BytesIO()
58
- img.save(img_byte_arr, format="JPEG")
59
- img_byte_arr.seek(0)
60
-
61
- return StreamingResponse(img_byte_arr, media_type="image/jpeg")
62
- except Exception as e:
63
- raise HTTPException(status_code=500, detail=str(e))
64
-
65
-
66
- # build job to test and deploy the API on a docker image (maybe in Azure?)
tld/data.py DELETED
@@ -1,243 +0,0 @@
1
- ####data util to get and preprocess data from a text and image pair to latents and text embeddings.
2
- ### all that is required is a csv file with an image url and text caption:
3
- #!pip install datasets img2dataset accelerate diffusers
4
- #!pip install git+https://github.com/openai/CLIP.git
5
-
6
- import json
7
- import os
8
- from dataclasses import dataclass
9
- from typing import List, Union
10
-
11
- import clip
12
- import h5py
13
- import numpy as np
14
- import pandas as pd
15
- import torch
16
- import torchvision.transforms as transforms
17
- import webdataset as wds
18
- from diffusers import AutoencoderKL
19
- from img2dataset import download
20
- from torch import Tensor, nn
21
- from torch.utils.data import DataLoader
22
- from tqdm import tqdm
23
-
24
-
25
- @torch.no_grad()
26
- def encode_text(label: Union[str, List[str]], model: nn.Module, device: str) -> Tensor:
27
- text_tokens = clip.tokenize(label, truncate=True).to(device)
28
- text_encoding = model.encode_text(text_tokens)
29
- return text_encoding.cpu()
30
-
31
-
32
- @torch.no_grad()
33
- def encode_image(img: Tensor, vae: AutoencoderKL) -> Tensor:
34
- x = img.to("cuda").to(torch.float16)
35
-
36
- x = x * 2 - 1 # to make it between -1 and 1.
37
- encoded = vae.encode(x, return_dict=False)[0].sample()
38
- return encoded.cpu()
39
-
40
-
41
- @torch.no_grad()
42
- def decode_latents(out_latents: torch.FloatTensor, vae: AutoencoderKL) -> Tensor:
43
- # expected to be in the unscaled latent space
44
- out = vae.decode(out_latents.cuda())[0].cpu()
45
-
46
- return ((out + 1) / 2).clip(0, 1)
47
-
48
-
49
- def quantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
50
- """scale and quantize latents to unit8"""
51
- lat_norm = lat.clip(-clip_val, clip_val) / clip_val
52
- return (((lat_norm + 1) / 2) * 255).to(torch.uint8)
53
-
54
-
55
- def dequantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
56
- lat_norm = (lat.to(torch.float16) / 255) * 2 - 1
57
- return lat_norm * clip_val
58
-
59
-
60
- def append_to_dataset(dataset: h5py.File, new_data: Tensor) -> None:
61
- """Appends new data to an HDF5 dataset."""
62
- new_size = dataset.shape[0] + new_data.shape[0]
63
- dataset.resize(new_size, axis=0)
64
- dataset[-new_data.shape[0] :] = new_data
65
-
66
-
67
- def get_text_and_latent_embeddings_hdf5(
68
- dataloader: DataLoader, vae: AutoencoderKL, model: nn.Module, drive_save_path: str
69
- ) -> None:
70
- """Process img/text inptus that outputs an latent and text embeddings and text_prompts, saving encodings as float16."""
71
-
72
- img_latent_path = os.path.join(drive_save_path, "image_latents.hdf5")
73
- text_embed_path = os.path.join(drive_save_path, "text_encodings.hdf5")
74
- metadata_csv_path = os.path.join(drive_save_path, "metadata.csv")
75
-
76
- with h5py.File(img_latent_path, "a") as img_file, h5py.File(text_embed_path, "a") as text_file:
77
- if "image_latents" not in img_file:
78
- img_ds = img_file.create_dataset(
79
- "image_latents",
80
- shape=(0, 4, 32, 32),
81
- maxshape=(None, 4, 32, 32),
82
- dtype="float16",
83
- chunks=True,
84
- )
85
- else:
86
- img_ds = img_file["image_latents"]
87
-
88
- if "text_encodings" not in text_file:
89
- text_ds = text_file.create_dataset(
90
- "text_encodings", shape=(0, 768), maxshape=(None, 768), dtype="float16", chunks=True
91
- )
92
- else:
93
- text_ds = text_file["text_encodings"]
94
-
95
- for img, (label, url) in tqdm(dataloader):
96
- text_encoding = encode_text(label, model).cpu().numpy().astype(np.float16)
97
- img_encoding = encode_image(img, vae).cpu().numpy().astype(np.float16)
98
-
99
- append_to_dataset(img_ds, img_encoding)
100
- append_to_dataset(text_ds, text_encoding)
101
-
102
- metadata_df = pd.DataFrame({"text": label, "url": url})
103
- if os.path.exists(metadata_csv_path):
104
- metadata_df.to_csv(metadata_csv_path, mode="a", header=False, index=False)
105
- else:
106
- metadata_df.to_csv(metadata_csv_path, mode="w", header=True, index=False)
107
-
108
-
109
- def download_and_process_data(
110
- latent_save_path="latents",
111
- raw_imgs_save_path="raw_imgs",
112
- csv_path="imgs.csv",
113
- image_size=256,
114
- bs=64,
115
- caption_col="captions",
116
- url_col="url",
117
- download_data=True,
118
- number_sample_per_shard=10000,
119
- ):
120
- if not os.path.exists(raw_imgs_save_path):
121
- os.mkdir(raw_imgs_save_path)
122
-
123
- if not os.path.exists(latent_save_path):
124
- os.mkdir(latent_save_path)
125
-
126
- if download_data:
127
- download(
128
- processes_count=8,
129
- thread_count=64,
130
- url_list=csv_path,
131
- image_size=image_size,
132
- output_folder=raw_imgs_save_path,
133
- output_format="webdataset",
134
- input_format="csv",
135
- url_col=url_col,
136
- caption_col=caption_col,
137
- enable_wandb=False,
138
- number_sample_per_shard=number_sample_per_shard,
139
- distributor="multiprocessing",
140
- resize_mode="center_crop",
141
- )
142
-
143
- files = os.listdir(raw_imgs_save_path)
144
- tar_files = [os.path.join(raw_imgs_save_path, file) for file in files if file.endswith(".tar")]
145
- print(tar_files)
146
- dataset = wds.WebDataset(tar_files)
147
-
148
- transform = transforms.Compose(
149
- [
150
- transforms.ToTensor(),
151
- ]
152
- )
153
-
154
- # output is (img_tensor, (caption , url_col)) per batch:
155
- dataset = (
156
- dataset.decode("pil")
157
- .to_tuple("jpg;png", "json")
158
- .map_tuple(transform, lambda x: (x["caption"], x[url_col]))
159
- )
160
-
161
- dataloader = DataLoader(dataset, batch_size=bs, shuffle=False)
162
-
163
- model, _ = clip.load("ViT-L/14")
164
-
165
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
166
- vae = vae.to("cuda")
167
- model.to("cuda")
168
-
169
- print("Starting to encode latents and text:")
170
- get_text_and_latent_embeddings_hdf5(dataloader, vae, model, latent_save_path)
171
- print("Finished encode latents and text:")
172
-
173
-
174
- @dataclass
175
- class DataConfiguration:
176
- data_link: str
177
- caption_col: str = "caption"
178
- url_col: str = "url"
179
- latent_save_path: str = "latents_folder"
180
- raw_imgs_save_path: str = "raw_imgs_folder"
181
- use_drive: bool = False
182
- initial_csv_path: str = "imgs.csv"
183
- number_sample_per_shard: int = 10000
184
- image_size: int = 256
185
- batch_size: int = 64
186
- download_data: bool = True
187
-
188
-
189
- if __name__ == "__main__":
190
- use_wandb = False
191
-
192
- if use_wandb:
193
- import wandb
194
-
195
- os.environ["WANDB_API_KEY"] = "key"
196
- #!wandb login
197
-
198
- data_link = "https://huggingface.co/datasets/zzliang/GRIT/resolve/main/grit-20m/coyo_0_snappy.parquet?download=true"
199
-
200
- data_config = DataConfiguration(
201
- data_link=data_link,
202
- latent_save_path="latent_folder",
203
- raw_imgs_save_path="raw_imgs_folder",
204
- download_data=False,
205
- number_sample_per_shard=1000,
206
- )
207
-
208
- if use_wandb:
209
- wandb.init(project="image_vae_processing", entity="apapiu", config=data_config)
210
-
211
- if not os.path.exists(data_config.latent_save_path):
212
- os.mkdir(data_config.latent_save_path)
213
-
214
- config_file_path = os.path.join(data_config.latent_save_path, "config.json")
215
- with open(config_file_path, "w") as f:
216
- json.dump(data_config.__dict__, f)
217
-
218
- print("Config saved to:", config_file_path)
219
-
220
- df = pd.read_parquet(data_link)
221
- ###add additional data cleaning here...should I
222
- df = df.iloc[:3000]
223
- df[["key", "url", "caption"]].to_csv("imgs.csv", index=None)
224
-
225
- if data_config.use_drive:
226
- from google.colab import drive
227
-
228
- drive.mount("/content/drive")
229
-
230
- download_and_process_data(
231
- latent_save_path=data_config.latent_save_path,
232
- raw_imgs_save_path=data_config.raw_imgs_save_path,
233
- csv_path=data_config.initial_csv_path,
234
- image_size=data_config.image_size,
235
- bs=data_config.batch_size,
236
- caption_col=data_config.caption_col,
237
- url_col=data_config.url_col,
238
- download_data=data_config.download_data,
239
- number_sample_per_shard=data_config.number_sample_per_shard,
240
- )
241
-
242
- if use_wandb:
243
- wandb.finish()
tld/denoiser.py DELETED
@@ -1,123 +0,0 @@
1
- """transformer based denoiser"""
2
-
3
- import torch
4
- from einops.layers.torch import Rearrange
5
- from torch import nn
6
-
7
- from transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
8
-
9
-
10
- class DenoiserTransBlock(nn.Module):
11
- def __init__(
12
- self,
13
- patch_size: int,
14
- img_size: int,
15
- embed_dim: int,
16
- dropout: float,
17
- n_layers: int,
18
- mlp_multiplier: int = 4,
19
- n_channels: int = 4,
20
- ):
21
- super().__init__()
22
-
23
- self.patch_size = patch_size
24
- self.img_size = img_size
25
- self.n_channels = n_channels
26
- self.embed_dim = embed_dim
27
- self.dropout = dropout
28
- self.n_layers = n_layers
29
- self.mlp_multiplier = mlp_multiplier
30
-
31
- seq_len = int((self.img_size / self.patch_size) * (self.img_size / self.patch_size))
32
- patch_dim = self.n_channels * self.patch_size * self.patch_size
33
-
34
- self.patchify_and_embed = nn.Sequential(
35
- nn.Conv2d(
36
- self.n_channels,
37
- patch_dim,
38
- kernel_size=self.patch_size,
39
- stride=self.patch_size,
40
- ),
41
- Rearrange("bs d h w -> bs (h w) d"),
42
- nn.LayerNorm(patch_dim),
43
- nn.Linear(patch_dim, self.embed_dim),
44
- nn.LayerNorm(self.embed_dim),
45
- )
46
-
47
- self.rearrange2 = Rearrange(
48
- "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
49
- h=int(self.img_size / self.patch_size),
50
- p1=self.patch_size,
51
- p2=self.patch_size,
52
- )
53
-
54
- self.pos_embed = nn.Embedding(seq_len, self.embed_dim)
55
- self.register_buffer("precomputed_pos_enc", torch.arange(0, seq_len).long())
56
-
57
- self.decoder_blocks = nn.ModuleList(
58
- [
59
- DecoderBlock(
60
- embed_dim=self.embed_dim,
61
- mlp_multiplier=self.mlp_multiplier,
62
- # note that this is a non-causal block since we are
63
- # denoising the entire image no need for masking
64
- is_causal=False,
65
- dropout_level=self.dropout,
66
- mlp_class=MLPSepConv,
67
- )
68
- for _ in range(self.n_layers)
69
- ]
70
- )
71
-
72
- self.out_proj = nn.Sequential(nn.Linear(self.embed_dim, patch_dim), self.rearrange2)
73
-
74
- def forward(self, x, cond):
75
- x = self.patchify_and_embed(x)
76
- pos_enc = self.precomputed_pos_enc[: x.size(1)].expand(x.size(0), -1)
77
- x = x + self.pos_embed(pos_enc)
78
-
79
- for block in self.decoder_blocks:
80
- x = block(x, cond)
81
-
82
- return self.out_proj(x)
83
-
84
-
85
- class Denoiser(nn.Module):
86
- def __init__(
87
- self,
88
- image_size: int,
89
- noise_embed_dims: int,
90
- patch_size: int,
91
- embed_dim: int,
92
- dropout: float,
93
- n_layers: int,
94
- text_emb_size: int = 768,
95
- ):
96
- super().__init__()
97
-
98
- self.image_size = image_size
99
- self.noise_embed_dims = noise_embed_dims
100
- self.embed_dim = embed_dim
101
-
102
- self.fourier_feats = nn.Sequential(
103
- SinusoidalEmbedding(embedding_dims=noise_embed_dims),
104
- nn.Linear(noise_embed_dims, self.embed_dim),
105
- nn.GELU(),
106
- nn.Linear(self.embed_dim, self.embed_dim),
107
- )
108
-
109
- self.denoiser_trans_block = DenoiserTransBlock(patch_size, image_size, embed_dim, dropout, n_layers)
110
- self.norm = nn.LayerNorm(self.embed_dim)
111
- self.label_proj = nn.Linear(text_emb_size, self.embed_dim)
112
-
113
- def forward(self, x, noise_level, label):
114
- noise_level = self.fourier_feats(noise_level).unsqueeze(1)
115
-
116
- label = self.label_proj(label).unsqueeze(1)
117
-
118
- noise_label_emb = torch.cat([noise_level, label], dim=1) # bs, 2, d
119
- noise_label_emb = self.norm(noise_label_emb)
120
-
121
- x = self.denoiser_trans_block(x, noise_label_emb)
122
-
123
- return x
tld/diffusion.py DELETED
@@ -1,198 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import clip
4
- import numpy as np
5
- import requests
6
- import torch
7
- import torchvision.transforms as transforms
8
- import torchvision.utils as vutils
9
- from diffusers import AutoencoderKL
10
- from torch import Tensor
11
- from tqdm import tqdm
12
-
13
- from denoiser import Denoiser
14
-
15
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
- to_pil = transforms.ToPILImage()
17
-
18
-
19
- @dataclass
20
- class DiffusionGenerator:
21
- model: Denoiser
22
- vae: AutoencoderKL
23
- device: torch.device
24
- model_dtype: torch.dtype = torch.float32
25
-
26
- @torch.no_grad()
27
- def generate(
28
- self,
29
- labels: Tensor, # embeddings to condition on
30
- n_iter: int = 30,
31
- num_imgs: int = 16,
32
- class_guidance: float = 3,
33
- seed: int = 10,
34
- scale_factor: int = 8, # latent scaling before decoding - should be ~ std of latent space
35
- img_size: int = 32, # height, width of latent
36
- sharp_f: float = 0.1,
37
- bright_f: float = 0.1,
38
- exponent: float = 1,
39
- seeds: Tensor | None = None,
40
- noise_levels=None,
41
- use_ddpm_plus: bool = True,
42
- ):
43
- """Generate images via reverse diffusion.
44
- if use_ddpm_plus=True uses Algorithm 2 DPM-Solver++(2M) here: https://arxiv.org/pdf/2211.01095.pdf
45
- else use ddim with alpha = 1-sigma
46
- """
47
- if noise_levels is None:
48
- noise_levels = (1 - torch.pow(torch.arange(0, 1, 1 / n_iter), exponent)).tolist()
49
- noise_levels[0] = 0.99
50
-
51
- if use_ddpm_plus:
52
- lambdas = [np.log((1 - sigma) / sigma) for sigma in noise_levels] # log snr
53
- hs = [lambdas[i] - lambdas[i - 1] for i in range(1, len(lambdas))]
54
- rs = [hs[i - 1] / hs[i] for i in range(1, len(hs))]
55
-
56
- x_t = self.initialize_image(seeds, num_imgs, img_size, seed)
57
-
58
- labels = torch.cat([labels, torch.zeros_like(labels)])
59
- self.model.eval()
60
-
61
- x0_pred_prev = None
62
-
63
- for i in tqdm(range(len(noise_levels) - 1)):
64
- curr_noise, next_noise = noise_levels[i], noise_levels[i + 1]
65
-
66
- x0_pred = self.pred_image(x_t, labels, curr_noise, class_guidance)
67
-
68
- if x0_pred_prev is None:
69
- x_t = ((curr_noise - next_noise) * x0_pred + next_noise * x_t) / curr_noise
70
- else:
71
- if use_ddpm_plus:
72
- # x0_pred is a combination of the two previous x0_pred:
73
- D = (1 + 1 / (2 * rs[i - 1])) * x0_pred - (1 / (2 * rs[i - 1])) * x0_pred_prev
74
- else:
75
- # ddim:
76
- D = x0_pred
77
-
78
- x_t = ((curr_noise - next_noise) * D + next_noise * x_t) / curr_noise
79
-
80
- x0_pred_prev = x0_pred
81
-
82
- x0_pred = self.pred_image(x_t, labels, next_noise, class_guidance)
83
-
84
- # shifting latents works a bit like an image editor:
85
- x0_pred[:, 3, :, :] += sharp_f
86
- x0_pred[:, 0, :, :] += bright_f
87
-
88
- x0_pred_img = self.vae.decode((x0_pred * scale_factor).to(self.model_dtype))[0].cpu()
89
- return x0_pred_img, x0_pred
90
-
91
- def pred_image(self, noisy_image, labels, noise_level, class_guidance):
92
- num_imgs = noisy_image.size(0)
93
- noises = torch.full((2 * num_imgs, 1), noise_level)
94
- x0_pred = self.model(
95
- torch.cat([noisy_image, noisy_image]),
96
- noises.to(self.device, self.model_dtype),
97
- labels.to(self.device, self.model_dtype),
98
- )
99
- x0_pred = self.apply_classifier_free_guidance(x0_pred, num_imgs, class_guidance)
100
- return x0_pred
101
-
102
- def initialize_image(self, seeds, num_imgs, img_size, seed):
103
- """Initialize the seed tensor."""
104
- if seeds is None:
105
- generator = torch.Generator(device=self.device)
106
- generator.manual_seed(seed)
107
- return torch.randn(
108
- num_imgs,
109
- 4,
110
- img_size,
111
- img_size,
112
- dtype=self.model_dtype,
113
- device=self.device,
114
- generator=generator,
115
- )
116
- else:
117
- return seeds.to(self.device, self.model_dtype)
118
-
119
- def apply_classifier_free_guidance(self, x0_pred, num_imgs, class_guidance):
120
- """Apply classifier-free guidance to the predictions."""
121
- x0_pred_label, x0_pred_no_label = x0_pred[:num_imgs], x0_pred[num_imgs:]
122
- return class_guidance * x0_pred_label + (1 - class_guidance) * x0_pred_no_label
123
-
124
-
125
- @dataclass
126
- class LTDConfig:
127
- vae_scale_factor: float = 8
128
- img_size: int = 32
129
- model_dtype: torch.dtype = torch.float32
130
- file_url: str = None # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
131
- local_filename: str = "state_dict_378000.pth"
132
- vae_name: str = "ByteDance/SDXL-Lightning"
133
- clip_model_name: str = "ViT-L/14"
134
- denoiser: Denoiser = Denoiser(
135
- image_size=32,
136
- noise_embed_dims=256,
137
- patch_size=2,
138
- embed_dim=256,
139
- dropout=0,
140
- n_layers=4,
141
- )
142
-
143
-
144
- def download_file(url, filename):
145
- with requests.get(url, stream=True) as r:
146
- r.raise_for_status()
147
- with open(filename, "wb") as f:
148
- for chunk in r.iter_content(chunk_size=8192):
149
- f.write(chunk)
150
-
151
-
152
- @torch.no_grad()
153
- def encode_text(label, model):
154
- text_tokens = clip.tokenize(label, truncate=True).to(device)
155
- text_encoding = model.encode_text(text_tokens)
156
- return text_encoding.cpu()
157
-
158
-
159
- class DiffusionTransformer:
160
- def __init__(self, config: LTDConfig):
161
- denoiser = config.denoiser.to(config.model_dtype)
162
-
163
- if config.file_url is not None:
164
- print(f"Downloading model from {config.file_url}")
165
- download_file(config.file_url, config.local_filename)
166
- state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
167
- denoiser.load_state_dict(state_dict)
168
-
169
- denoiser = denoiser.to(device)
170
-
171
- vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
172
-
173
- self.clip_model, preprocess = clip.load(config.clip_model_name)
174
- self.clip_model = self.clip_model.to(device)
175
-
176
- self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
177
-
178
- def generate_image_from_text(
179
- self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
180
- ):
181
- nrow = int(np.sqrt(num_imgs))
182
-
183
- cur_prompts = [prompt] * num_imgs
184
- labels = encode_text(cur_prompts, self.clip_model)
185
- out, out_latent = self.diffuser.generate(
186
- labels=labels,
187
- num_imgs=num_imgs,
188
- class_guidance=class_guidance,
189
- seed=seed,
190
- n_iter=n_iter,
191
- exponent=1,
192
- scale_factor=8,
193
- sharp_f=0,
194
- bright_f=0,
195
- )
196
-
197
- out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
198
- return out
tld/gen_img.py DELETED
@@ -1,41 +0,0 @@
1
- import io
2
- import asyncio
3
- import os
4
-
5
- import torch
6
- import torchvision.transforms as transforms
7
-
8
- from diffusion import DiffusionTransformer, LTDConfig
9
-
10
- # Get the directory of the script
11
- script_directory = os.path.dirname(os.path.realpath(__file__))
12
- # Specify the directory where the cache will be stored (same folder as the script)
13
- cache_directory = os.path.join(script_directory, "cache")
14
- home_directory = os.path.join(script_directory, "home")
15
- # Create the cache directory if it doesn't exist
16
- os.makedirs(cache_directory, exist_ok=True)
17
- os.makedirs(home_directory, exist_ok=True)
18
-
19
- os.environ["TRANSFORMERS_CACHE"] = cache_directory
20
- os.environ["HF_HOME"] = home_directory
21
-
22
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- to_pil = transforms.ToPILImage()
24
- ltdconfig = LTDConfig()
25
- diffusion_transformer = DiffusionTransformer(ltdconfig) #Downloads model here
26
-
27
- async def generate_image(prompt):
28
- try:
29
- img = diffusion_transformer.generate_image_from_text(
30
- prompt=prompt,
31
- class_guidance=6,
32
- seed=11,
33
- num_imgs=1,
34
- img_size=32,
35
- )
36
-
37
- img.save("generated_img.png")
38
- except Exception as e:
39
- print(e)
40
-
41
- asyncio.run(generate_image("a cute cat"))
tld/gradio_app.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- from io import BytesIO
3
-
4
- import gradio as gr
5
- import requests
6
- from PIL import Image
7
-
8
- # runpod_id = os.environ['RUNPOD_ID']
9
- # token_id = os.environ['AUTH_TOKEN']
10
- # url = f'https://{runpod_id}-8000.proxy.runpod.net/generate-image/'
11
-
12
- url = os.getenv("API_URL")
13
- token_id = os.getenv("API_TOKEN")
14
-
15
-
16
- def generate_image_from_text(prompt, class_guidance):
17
- headers = {"Authorization": f"Bearer {token_id}"}
18
-
19
- data = {"prompt": prompt, "class_guidance": class_guidance, "seed": 11, "num_imgs": 1, "img_size": 32}
20
-
21
- response = requests.post(url, json=data, headers=headers)
22
-
23
- if response.status_code == 200:
24
- image = Image.open(BytesIO(response.content))
25
- else:
26
- print("Failed to fetch image:", response.status_code, response.text)
27
-
28
- return image
29
-
30
-
31
- iface = gr.Interface(
32
- fn=generate_image_from_text,
33
- inputs=["text", "slider"],
34
- outputs="image",
35
- title="Text-to-Image Generator",
36
- description="Enter a text prompt to generate an image.",
37
- )
38
-
39
- # Launch the app
40
- iface.launch()
tld/train.py DELETED
@@ -1,208 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- import copy
4
- from dataclasses import asdict, dataclass
5
-
6
- import numpy as np
7
- import torch
8
- import torchvision
9
- import torchvision.utils as vutils
10
- import wandb
11
- from accelerate import Accelerator
12
- from diffusers import AutoencoderKL
13
- from PIL.Image import Image
14
- from torch import Tensor, nn
15
- from torch.utils.data import DataLoader, TensorDataset
16
- from tqdm import tqdm
17
-
18
- from denoiser import Denoiser
19
- from diffusion import DiffusionGenerator
20
-
21
-
22
- def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image:
23
- class_guidance = 4.5
24
- seed = 10
25
- out, _ = diffuser.generate(
26
- labels=torch.repeat_interleave(labels, 8, dim=0),
27
- num_imgs=64,
28
- class_guidance=class_guidance,
29
- seed=seed,
30
- n_iter=40,
31
- exponent=1,
32
- sharp_f=0.1,
33
- )
34
-
35
- out = to_pil((vutils.make_grid((out + 1) / 2, nrow=8, padding=4)).float().clip(0, 1))
36
- out.save(f"emb_val_cfg:{class_guidance}_seed:{seed}.png")
37
-
38
- return out
39
-
40
-
41
- def count_parameters(model: nn.Module):
42
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
43
-
44
-
45
- def count_parameters_per_layer(model: nn.Module):
46
- for name, param in model.named_parameters():
47
- print(f"{name}: {param.numel()} parameters")
48
-
49
-
50
- to_pil = torchvision.transforms.ToPILImage()
51
-
52
-
53
- def update_ema(ema_model: nn.Module, model: nn.Module, alpha: float = 0.999):
54
- with torch.no_grad():
55
- for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
56
- ema_param.data.mul_(alpha).add_(model_param.data, alpha=1 - alpha)
57
-
58
-
59
- @dataclass
60
- class ModelConfig:
61
- embed_dim: int = 512
62
- n_layers: int = 6
63
- clip_embed_size: int = 768
64
- scaling_factor: int = 8
65
- patch_size: int = 2
66
- image_size: int = 32
67
- n_channels: int = 4
68
- dropout: float = 0
69
- mlp_multiplier: int = 4
70
- batch_size: int = 128
71
- class_guidance: int = 3
72
- lr: float = 3e-4
73
- n_epoch: int = 100
74
- alpha: float = 0.999
75
- noise_embed_dims: int = 128
76
- diffusion_n_iter: int = 35
77
- from_scratch: bool = True
78
- run_id: str = ""
79
- model_name: str = ""
80
- beta_a: float = 0.75
81
- beta_b: float = 0.75
82
- save_and_eval_every_iters: int = 1000
83
-
84
-
85
- @dataclass
86
- class DataConfig:
87
- latent_path: str # path to a numpy file containing latents
88
- text_emb_path: str
89
- val_path: str
90
-
91
-
92
- def main(config: ModelConfig, dataconfig: DataConfig) -> None:
93
- """main train loop to be used with accelerate"""
94
-
95
- accelerator = Accelerator(mixed_precision="fp16", log_with="wandb")
96
-
97
- accelerator.print("Loading Data:")
98
- latent_train_data = torch.tensor(np.load(dataconfig.latent_path), dtype=torch.float32)
99
- train_label_embeddings = torch.tensor(np.load(dataconfig.text_emb_path), dtype=torch.float32)
100
- emb_val = torch.tensor(np.load(dataconfig.val_path), dtype=torch.float32)
101
- dataset = TensorDataset(latent_train_data, train_label_embeddings)
102
- train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
103
-
104
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
105
-
106
- if accelerator.is_main_process:
107
- vae = vae.to(accelerator.device)
108
-
109
- model = Denoiser(
110
- image_size=config.image_size,
111
- noise_embed_dims=config.noise_embed_dims,
112
- patch_size=config.patch_size,
113
- embed_dim=config.embed_dim,
114
- dropout=config.dropout,
115
- n_layers=config.n_layers,
116
- )
117
-
118
- loss_fn = nn.MSELoss()
119
- optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
120
-
121
- accelerator.print("Compiling model:")
122
- model = torch.compile(model)
123
-
124
- if not config.from_scratch:
125
- accelerator.print("Loading Model:")
126
- wandb.restore(
127
- config.model_name, run_path=f"apapiu/cifar_diffusion/runs/{config.run_id}", replace=True
128
- )
129
- full_state_dict = torch.load(config.model_name)
130
- model.load_state_dict(full_state_dict["model_ema"])
131
- optimizer.load_state_dict(full_state_dict["opt_state"])
132
- global_step = full_state_dict["global_step"]
133
- else:
134
- global_step = 0
135
-
136
- if accelerator.is_local_main_process:
137
- ema_model = copy.deepcopy(model).to(accelerator.device)
138
- diffuser = DiffusionGenerator(ema_model, vae, accelerator.device, torch.float32)
139
-
140
- accelerator.print("model prep")
141
- model, train_loader, optimizer = accelerator.prepare(model, train_loader, optimizer)
142
-
143
- accelerator.init_trackers(project_name="cifar_diffusion", config=asdict(config))
144
-
145
- accelerator.print(count_parameters(model))
146
- count_parameters_per_layer(model)  # prints per-layer counts itself (returns None, so there is nothing to accelerator.print)
147
-
148
- ### Train:
149
- for i in range(1, config.n_epoch + 1):
150
- accelerator.print(f"epoch: {i}")
151
-
152
- for x, y in tqdm(train_loader):
153
- x = x / config.scaling_factor
154
-
155
- noise_level = torch.tensor(
156
- np.random.beta(config.beta_a, config.beta_b, len(x)), device=accelerator.device
157
- )
158
- signal_level = 1 - noise_level
159
- noise = torch.randn_like(x)
160
-
161
- x_noisy = noise_level.view(-1, 1, 1, 1) * noise + signal_level.view(-1, 1, 1, 1) * x
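- # forward-diffusion step: interpolate between the clean latents and Gaussian noise, weighted by the sampled noise level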
162
-
163
- x_noisy = x_noisy.float()
164
- noise_level = noise_level.float()
165
- label = y
166
-
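- # classifier-free guidance training: randomly zero out the text conditioning for ~15% of the batch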
167
- prob = 0.15
168
- mask = torch.rand(y.size(0), device=accelerator.device) < prob
169
- label[mask] = 0 # OR replacement_vector
170
-
171
- if global_step % config.save_and_eval_every_iters == 0:
172
- accelerator.wait_for_everyone()
173
- if accelerator.is_main_process:
174
- ##eval and saving:
175
- out = eval_gen(diffuser=diffuser, labels=emb_val)
176
- out.save("img.jpg")
177
- accelerator.log({f"step: {global_step}": wandb.Image("img.jpg")})
178
-
179
- opt_unwrapped = accelerator.unwrap_model(optimizer)
180
- full_state_dict = {
181
- "model_ema": ema_model.state_dict(),
182
- "opt_state": opt_unwrapped.state_dict(),
183
- "global_step": global_step,
184
- }
185
- accelerator.save(full_state_dict, config.model_name)
186
- wandb.save(config.model_name)
187
-
188
- model.train()
189
-
190
- with accelerator.accumulate(model):  # pass the prepared model so Accelerate knows what to sync during accumulation
191
- ###train loop:
192
- optimizer.zero_grad()
193
-
194
- pred = model(x_noisy, noise_level.view(-1, 1), label)
195
- loss = loss_fn(pred, x)
196
- accelerator.log({"train_loss": loss.item()}, step=global_step)
197
- accelerator.backward(loss)
198
- optimizer.step()
199
-
200
- if accelerator.is_main_process:
201
- update_ema(ema_model, model, alpha=config.alpha)
202
-
203
- global_step += 1
204
- accelerator.end_training()
205
-
206
-
207
- # args = (config, data_path, val_path)
208
- # notebook_launcher(training_loop)
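The commented-out lines above hint at how main was meant to be launched. A sketch of wiring the two configs together; the .npy paths are placeholders, and num_processes=1 is an assumption for a single-GPU run (under accelerate launch, a plain if __name__ == "__main__": main(config, data_config) guard is enough):

from accelerate import notebook_launcher

if __name__ == "__main__":
    config = ModelConfig(model_name="model.pt", run_id="", from_scratch=True)
    data_config = DataConfig(
        latent_path="train_latents.npy",     # precomputed VAE latents (placeholder path)
        text_emb_path="train_text_emb.npy",  # CLIP text embeddings for the captions (placeholder path)
        val_path="val_text_emb.npy",         # embeddings passed to eval_gen (placeholder path)
    )
    # launches main(config, data_config) on num_processes workers, e.g. from a notebook
    notebook_launcher(main, args=(config, data_config), num_processes=1)
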
tld/transformer_blocks.py DELETED
@@ -1,139 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- from einops import rearrange
5
-
6
-
7
- class SinusoidalEmbedding(nn.Module):
8
- def __init__(self, emb_min_freq=1.0, emb_max_freq=1000.0, embedding_dims=32):
9
- super(SinusoidalEmbedding, self).__init__()
10
-
11
- frequencies = torch.exp(
12
- torch.linspace(np.log(emb_min_freq), np.log(emb_max_freq), embedding_dims // 2)
13
- )
14
-
15
- self.register_buffer("angular_speeds", 2.0 * torch.pi * frequencies)
16
-
17
- def forward(self, x):
18
- embeddings = torch.cat(
19
- [torch.sin(self.angular_speeds * x), torch.cos(self.angular_speeds * x)], dim=-1
20
- )
21
- return embeddings
22
-
23
-
24
- class MHAttention(nn.Module):
25
- def __init__(self, is_causal=False, dropout_level=0.0, n_heads=4):
26
- super().__init__()
27
- self.is_causal = is_causal
28
- self.dropout_level = dropout_level
29
- self.n_heads = n_heads
30
-
31
- def forward(self, q, k, v, attn_mask=None):
32
- assert q.size(-1) == k.size(-1)
33
- assert k.size(-2) == v.size(-2)
34
-
35
- q, k, v = [rearrange(x, "bs n (h d) -> bs h n d", h=self.n_heads) for x in [q, k, v]]
36
-
37
- out = nn.functional.scaled_dot_product_attention(
38
- q,
39
- k,
40
- v,
41
- attn_mask=attn_mask,
42
- is_causal=self.is_causal,
43
- dropout_p=self.dropout_level if self.training else 0,
44
- )
45
-
46
- out = rearrange(out, "bs h n d -> bs n (h d)", h=self.n_heads)
47
-
48
- return out
49
-
50
-
51
- class SelfAttention(nn.Module):
52
- def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
53
- super().__init__()
54
- self.qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
55
- self.mha = MHAttention(is_causal, dropout_level, n_heads)
56
-
57
- def forward(self, x):
58
- q, k, v = self.qkv_linear(x).chunk(3, dim=2)
59
- return self.mha(q, k, v)
60
-
61
-
62
- class CrossAttention(nn.Module):
63
- def __init__(self, embed_dim, is_causal=False, dropout_level=0, n_heads=4):
64
- super().__init__()
65
- self.kv_linear = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
66
- self.q_linear = nn.Linear(embed_dim, embed_dim, bias=False)
67
- self.mha = MHAttention(is_causal, dropout_level, n_heads)
68
-
69
- def forward(self, x, y):
70
- q = self.q_linear(x)
71
- k, v = self.kv_linear(y).chunk(2, dim=2)
72
- return self.mha(q, k, v)
73
-
74
-
75
- class MLP(nn.Module):
76
- def __init__(self, embed_dim, mlp_multiplier, dropout_level):
77
- super().__init__()
78
- self.mlp = nn.Sequential(
79
- nn.Linear(embed_dim, mlp_multiplier * embed_dim),
80
- nn.GELU(),
81
- nn.Linear(mlp_multiplier * embed_dim, embed_dim),
82
- nn.Dropout(dropout_level),
83
- )
84
-
85
- def forward(self, x):
86
- return self.mlp(x)
87
-
88
-
89
- class MLPSepConv(nn.Module):
90
- def __init__(self, embed_dim, mlp_multiplier, dropout_level):
91
- """see: https://github.com/ofsoundof/LocalViT"""
92
- super().__init__()
93
- self.mlp = nn.Sequential(
94
- # this Conv with kernel size 1 is equivalent to the Linear layer in a "regular" transformer MLP
95
- nn.Conv2d(embed_dim, mlp_multiplier * embed_dim, kernel_size=1, padding="same"),
96
- nn.Conv2d(
97
- mlp_multiplier * embed_dim,
98
- mlp_multiplier * embed_dim,
99
- kernel_size=3,
100
- padding="same",
101
- groups=mlp_multiplier * embed_dim,
102
- ), # <- depthwise conv
103
- nn.GELU(),
104
- nn.Conv2d(mlp_multiplier * embed_dim, embed_dim, kernel_size=1, padding="same"),
105
- nn.Dropout(dropout_level),
106
- )
107
-
108
- def forward(self, x):
109
- w = h = int(np.sqrt(x.size(1))) # only square images for now
110
- x = rearrange(x, "bs (h w) d -> bs d h w", h=h, w=w)
111
- x = self.mlp(x)
112
- x = rearrange(x, "bs d h w -> bs (h w) d")
113
- return x
114
-
115
-
116
- class DecoderBlock(nn.Module):
117
- def __init__(
118
- self,
119
- embed_dim: int,
120
- is_causal: bool,
121
- mlp_multiplier: int,
122
- dropout_level: float,
123
- mlp_class: type[MLP] | type[MLPSepConv],
124
- ):
125
- super().__init__()
126
- self.self_attention = SelfAttention(embed_dim, is_causal, dropout_level, n_heads=embed_dim // 64)
127
- self.cross_attention = CrossAttention(
128
- embed_dim, is_causal=False, dropout_level=0, n_heads=embed_dim // 64
129
- )
130
- self.mlp = mlp_class(embed_dim, mlp_multiplier, dropout_level)
131
- self.norm1 = nn.LayerNorm(embed_dim)
132
- self.norm2 = nn.LayerNorm(embed_dim)
133
- self.norm3 = nn.LayerNorm(embed_dim)
134
-
135
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
136
- x = self.self_attention(self.norm1(x)) + x
137
- x = self.cross_attention(self.norm2(x), y) + x
138
- x = self.mlp(self.norm3(x)) + x
139
- return x
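A quick shape check for the blocks above (illustrative only): token sequences are [batch, n_tokens, embed_dim], conditioning is [batch, n_cond_tokens, embed_dim], and embed_dim must be divisible by 64 because DecoderBlock derives its head count as embed_dim // 64; MLPSepConv additionally assumes n_tokens forms a square grid. The sizes below are arbitrary example values:

import torch

block = DecoderBlock(
    embed_dim=256,
    is_causal=False,
    mlp_multiplier=4,
    dropout_level=0.0,
    mlp_class=MLPSepConv,
)

x = torch.randn(2, 16 * 16, 256)  # 16x16 patch grid flattened to 256 tokens
y = torch.randn(2, 77, 256)       # e.g. projected text-conditioning tokens
out = block(x, y)
print(out.shape)                  # torch.Size([2, 256, 256]), same shape as x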