richiesh lopho committed on
Commit 66df18e · 0 Parent(s):

Duplicate from TempoFunk/makeavid-sd-jax


Co-authored-by: lopho <[email protected]>

Files changed (41)
  1. .gitattributes +37 -0
  2. README.md +22 -0
  3. app.py +364 -0
  4. example.gif +3 -0
  5. examples/example_01_barbarian/input.png +3 -0
  6. examples/example_01_barbarian/output.gif +3 -0
  7. examples/example_01_barbarian/params.json +14 -0
  8. examples/example_02_zombies/output.gif +3 -0
  9. examples/example_02_zombies/params.json +14 -0
  10. examples/example_03_astronaut/output.gif +3 -0
  11. examples/example_03_astronaut/params.json +14 -0
  12. examples/example_04_furry_moster/output.gif +3 -0
  13. examples/example_04_furry_moster/params.json +14 -0
  14. examples/example_05_people/input.png +3 -0
  15. examples/example_05_people/output.gif +3 -0
  16. examples/example_05_people/params.json +14 -0
  17. examples/example_06_sophie/output.gif +3 -0
  18. examples/example_06_sophie/params.json +14 -0
  19. makeavid_sd/LICENSE +661 -0
  20. makeavid_sd/__init__.py +1 -0
  21. makeavid_sd/flax_impl/__init__.py +0 -0
  22. makeavid_sd/flax_impl/dataset.py +159 -0
  23. makeavid_sd/flax_impl/flax_attention_pseudo3d.py +212 -0
  24. makeavid_sd/flax_impl/flax_embeddings.py +62 -0
  25. makeavid_sd/flax_impl/flax_resnet_pseudo3d.py +175 -0
  26. makeavid_sd/flax_impl/flax_trainer.py +608 -0
  27. makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py +254 -0
  28. makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py +251 -0
  29. makeavid_sd/flax_impl/train.py +143 -0
  30. makeavid_sd/flax_impl/train.sh +34 -0
  31. makeavid_sd/inference.py +534 -0
  32. makeavid_sd/torch_impl/__init__.py +0 -0
  33. makeavid_sd/torch_impl/torch_attention_pseudo3d.py +294 -0
  34. makeavid_sd/torch_impl/torch_cross_attention.py +171 -0
  35. makeavid_sd/torch_impl/torch_embeddings.py +92 -0
  36. makeavid_sd/torch_impl/torch_resnet_pseudo3d.py +295 -0
  37. makeavid_sd/torch_impl/torch_unet_pseudo3d_blocks.py +493 -0
  38. makeavid_sd/torch_impl/torch_unet_pseudo3d_condition.py +235 -0
  39. packages.txt +0 -0
  40. pre-requirements.txt +5 -0
  41. requirements.txt +10 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,22 @@
+ ---
+ title: Make-A-Video Stable Diffusion Jax
+ emoji: 💀
+ colorFrom: green
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.28.0
+ app_file: app.py
+ pinned: true
+ license: agpl-3.0
+ library_name: diffusers
+ pipeline_tag: text-to-video
+ datasets:
+ - TempoFunk/tempofunk-sdance
+ - TempoFunk/small
+ models:
+ - TempoFunk/makeavid-sd-jax
+ - runwayml/stable-diffusion-v1-5
+ tags:
+ - jax-diffusers-event
+ duplicated_from: TempoFunk/makeavid-sd-jax
+ ---
app.py ADDED
@@ -0,0 +1,364 @@
+
+ import os
+ import json
+ from io import BytesIO
+ import base64
+ from functools import partial
+
+ from PIL import Image, ImageOps
+ import gradio as gr
+
+ from makeavid_sd.inference import (
+ InferenceUNetPseudo3D,
+ jnp,
+ SCHEDULERS
+ )
+
+ print(os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE', 'NotSet'))
+ print(os.environ.get('XLA_PYTHON_CLIENT_ALLOCATOR', 'NotSet'))
+
+ _seen_compilations = set()
+
+ _model = InferenceUNetPseudo3D(
+ model_path = 'TempoFunk/makeavid-sd-jax',
+ dtype = jnp.float16,
+ hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
+ )
+
+ if _model.failed != False:
+ trace = f'```{_model.failed}```'
+ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
+ exception = gr.Markdown(trace)
+ demo.launch()
+
+ _examples = []
+ _expath = 'examples'
+ for x in sorted(os.listdir(_expath)):
+ with open(os.path.join(_expath, x, 'params.json'), 'r') as f:
+ ex = json.load(f)
+ ex['image_input'] = None
+ if os.path.isfile(os.path.join(_expath, x, 'input.png')):
+ ex['image_input'] = os.path.join(_expath, x, 'input.png')
+ ex['image_output'] = os.path.join(_expath, x, 'output.gif')
+ _examples.append(ex)
+
+
+ _output_formats = (
+ 'webp', 'gif'
+ )
+
+ # gradio is illiterate. type hints make it go poopoo in pantsu.
+ def generate(
+ prompt = 'An elderly man having a great time in the park.',
+ neg_prompt = '',
+ hint_image = None,
+ inference_steps = 20,
+ cfg = 15.0,
+ cfg_image = 9.0,
+ seed = 0,
+ fps = 12,
+ num_frames = 24,
+ height = 512,
+ width = 512,
+ scheduler_type = 'dpm',
+ output_format = 'gif'
+ ) -> str:
+ num_frames = min(24, max(2, int(num_frames)))
+ inference_steps = min(60, max(2, int(inference_steps)))
+ height = min(576, max(256, int(height)))
+ width = min(576, max(256, int(width)))
+ height = (height // 64) * 64
+ width = (width // 64) * 64
+ cfg = max(cfg, 1.0)
+ cfg_image = max(cfg_image, 1.0)
+ fps = min(1000, max(1, int(fps)))
+ seed = min(2**32-2, int(seed))
+ if seed < 0:
+ seed = -seed
+ if hint_image is not None:
+ if hint_image.mode != 'RGB':
+ hint_image = hint_image.convert('RGB')
+ if hint_image.size != (width, height):
+ hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS)
+ scheduler_type = scheduler_type.lower()
+ if scheduler_type not in SCHEDULERS:
+ scheduler_type = 'dpm'
+ output_format = output_format.lower()
+ if output_format not in _output_formats:
+ output_format = 'gif'
+ mask_image = None
+ images = _model.generate(
+ prompt = [prompt] * _model.device_count,
+ neg_prompt = neg_prompt,
+ hint_image = hint_image,
+ mask_image = mask_image,
+ inference_steps = inference_steps,
+ cfg = cfg,
+ cfg_image = cfg_image,
+ height = height,
+ width = width,
+ num_frames = num_frames,
+ seed = seed,
+ scheduler_type = scheduler_type
+ )
+ _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))
+ with BytesIO() as buffer:
+ images[1].save(
+ buffer,
+ format = output_format,
+ save_all = True,
+ append_images = images[2:],
+ loop = 0,
+ duration = round(1000 / fps),
+ allow_mixed = True,
+ optimize = True
+ )
+ data = f'data:image/{output_format};base64,' + base64.b64encode(buffer.getvalue()).decode()
+ with BytesIO() as buffer:
+ images[-1].save(buffer, format = 'png', optimize = True)
+ last_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
+ with BytesIO() as buffer:
+ images[0].save(buffer, format = 'png', optimize = True)
+ first_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
+ return data, last_data, first_data
+
+ def check_if_compiled(hint_image, inference_steps, height, width, num_frames, scheduler_type, message):
+ height = int(height)
+ width = int(width)
+ inference_steps = int(inference_steps)
+ height = (height // 64) * 64
+ width = (width // 64) * 64
+ if (hint_image is None, inference_steps, height, width, num_frames, scheduler_type) in _seen_compilations:
+ return ''
+ else:
+ return message
+
+ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
+ variant = 'panel'
+ with gr.Row():
+ with gr.Column():
+ intro1 = gr.Markdown("""
+ # Make-A-Video Stable Diffusion JAX
+
+ We have extended a pretrained latent-diffusion inpainting image generation model with **temporal convolutions and attention**.
+ We guide the video generation with a hint image by taking advantage of the extra 5 input channels of the inpainting model.
+ In this demo the hint image can be given by the user, otherwise it is generated by a generative image model.
+
+ The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to [JAX](https://github.com/google/jax) utilizing [FLAX](https://github.com/google/flax).
+ The convolution is pseudo 3D and separately convolves across the spatial dimension in 2D and over the temporal dimension in 1D.
+ Temporal attention is purely self-attention and also separately attends to time.
+
+ Only the new temporal layers have been fine-tuned on a dataset of videos themed around dance.
+ The model has been trained for 80 epochs on a dataset of 18,000 videos with 120 frames each, randomly selecting a 24-frame range from each sample.
+
+ See model and dataset links in the metadata.
+
+ Model implementation and training code can be found at <https://github.com/lopho/makeavid-sd-tpu>
+ """)
+ with gr.Column():
+ intro3 = gr.Markdown("""
+ **Please be patient. The model might have to compile with current parameters.**
+
+ This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
+ The compilation will be cached and later runs with the same parameters
+ will be much faster.
+
+ Changes to the following parameters require the model to recompile:
+ - Number of frames
+ - Width & Height
+ - Inference steps
+ - Input image vs. no input image
+ - Noise scheduler type
+
+ If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions) (or DM [@lopho](https://twitter.com/lopho))
+
+ <small>Leave a ❤️ like if you like. Consider it a dopamine donation at no cost.</small>
+ """)
+
+ with gr.Row(variant = variant):
+ with gr.Column():
+ with gr.Row():
+ #cancel_button = gr.Button(value = 'Cancel')
+ submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
+ prompt_input = gr.Textbox(
+ label = 'Prompt',
+ value = 'They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.',
+ interactive = True
+ )
+ neg_prompt_input = gr.Textbox(
+ label = 'Negative prompt (optional)',
+ value = 'monochrome, saturated',
+ interactive = True
+ )
+ cfg_input = gr.Slider(
+ label = 'Guidance scale video',
+ minimum = 1.0,
+ maximum = 20.0,
+ step = 0.1,
+ value = 15.0,
+ interactive = True
+ )
+ cfg_image_input = gr.Slider(
+ label = 'Guidance scale hint (no effect with input image)',
+ minimum = 1.0,
+ maximum = 20.0,
+ step = 0.1,
+ value = 15.0,
+ interactive = True
+ )
+ seed_input = gr.Number(
+ label = 'Random seed',
+ value = 0,
+ interactive = True,
+ precision = 0
+ )
+ image_input = gr.Image(
+ label = 'Hint image (optional)',
+ interactive = True,
+ image_mode = 'RGB',
+ type = 'pil',
+ optional = True,
+ source = 'upload'
+ )
+ inference_steps_input = gr.Slider(
+ label = 'Steps',
+ minimum = 2,
+ maximum = 60,
+ value = 20,
+ step = 1,
+ interactive = True
+ )
+ num_frames_input = gr.Slider(
+ label = 'Number of frames to generate',
+ minimum = 2,
+ maximum = 24,
+ step = 1,
+ value = 24,
+ interactive = True
+ )
+ width_input = gr.Slider(
+ label = 'Width',
+ minimum = 256,
+ maximum = 576,
+ step = 64,
+ value = 512,
+ interactive = True
+ )
+ height_input = gr.Slider(
+ label = 'Height',
+ minimum = 256,
+ maximum = 576,
+ step = 64,
+ value = 512,
+ interactive = True
+ )
+ scheduler_input = gr.Dropdown(
+ label = 'Noise scheduler',
+ choices = list(SCHEDULERS.keys()),
+ value = 'dpm',
+ interactive = True
+ )
+ with gr.Row():
+ fps_input = gr.Slider(
+ label = 'Output FPS',
+ minimum = 1,
+ maximum = 1000,
+ step = 1,
+ value = 12,
+ interactive = True
+ )
+ output_format = gr.Dropdown(
+ label = 'Output format',
+ choices = _output_formats,
+ value = 'gif',
+ interactive = True
+ )
+ with gr.Column():
+ #will_trigger = gr.Markdown('')
+ patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
+ image_output = gr.Image(
+ label = 'Output',
+ value = 'example.gif',
+ interactive = False
+ )
+ tips = gr.Markdown('🤫 *Secret tip*: try using the last frame as input for the next generation.')
+ with gr.Row():
+ last_frame_output = gr.Image(
+ label = 'Last frame',
+ interactive = False
+ )
+ first_frame_output = gr.Image(
+ label = 'Initial frame',
+ interactive = False
+ )
+ examples_lst = []
+ for x in _examples:
+ examples_lst.append([
+ x['image_output'],
+ x['prompt'],
+ x['neg_prompt'],
+ x['image_input'],
+ x['cfg'],
+ x['cfg_image'],
+ x['seed'],
+ x['fps'],
+ x['steps'],
+ x['scheduler'],
+ x['num_frames'],
+ x['height'],
+ x['width'],
+ x['format']
+ ])
+ examples = gr.Examples(
+ examples = examples_lst,
+ inputs = [
+ image_output,
+ prompt_input,
+ neg_prompt_input,
+ image_input,
+ cfg_input,
+ cfg_image_input,
+ seed_input,
+ fps_input,
+ inference_steps_input,
+ scheduler_input,
+ num_frames_input,
+ height_input,
+ width_input,
+ output_format
+ ],
+ postprocess = False
+ )
+ #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input, scheduler_input ]
+ #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters need compilation.')
+ #height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ #width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ #num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ #image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ #inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ #scheduler_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+ submit_button.click(
+ fn = generate,
+ inputs = [
+ prompt_input,
+ neg_prompt_input,
+ image_input,
+ inference_steps_input,
+ cfg_input,
+ cfg_image_input,
+ seed_input,
+ fps_input,
+ num_frames_input,
+ height_input,
+ width_input,
+ scheduler_input,
+ output_format
+ ],
+ outputs = [ image_output, last_frame_output, first_frame_output ],
+ postprocess = False
+ )
+ #cancel_button.click(fn = lambda: None, cancels = ev)
+
+ demo.queue(concurrency_count = 1, max_size = 8)
+ demo.launch()
+
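
The intro text in app.py above describes the temporal extension: a pseudo-3D convolution that applies a 2D convolution over space per frame followed by a 1D convolution over time per spatial position, plus temporal self-attention. The actual layers live in makeavid_sd/flax_impl/flax_resnet_pseudo3d.py and flax_attention_pseudo3d.py (not shown in this view); the snippet below is only a minimal Flax sketch of the factorized convolution idea, with hypothetical names, not the repository implementation.

# Minimal sketch (not the repository's flax_resnet_pseudo3d.py) of a factorized
# "pseudo 3D" convolution: 2D over space per frame, then 1D over time per pixel.
import jax.numpy as jnp
import flax.linen as nn

class PseudoConv3dSketch(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # x: (batch, frames, height, width, channels)
        b, f, h, w, c = x.shape
        # spatial 2D convolution, frames folded into the batch dimension
        x = x.reshape(b * f, h, w, c)
        x = nn.Conv(self.features, kernel_size = (3, 3), padding = 'SAME')(x)
        x = x.reshape(b, f, h, w, self.features)
        # temporal 1D convolution, spatial positions folded into the batch
        x = x.transpose(0, 2, 3, 1, 4).reshape(b * h * w, f, self.features)
        x = nn.Conv(self.features, kernel_size = (3,), padding = 'SAME')(x)
        x = x.reshape(b, h, w, f, self.features).transpose(0, 3, 1, 2, 4)
        return x

Folding frames into the batch for the spatial pass and spatial positions into the batch for the temporal pass keeps both convolutions cheap compared to a full 3D kernel, which is the point of the factorization described above.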
example.gif ADDED

Git LFS Details

  • SHA256: d5cd05f2a45e4b0b3fa5465d8a8203fad029246071163787cb602e8d630aa70d
  • Pointer size: 132 Bytes
  • Size of remote file: 4.33 MB
examples/example_01_barbarian/input.png ADDED

Git LFS Details

  • SHA256: 87c4d10eb4e1bfa8f09657ec0d85de66052e34c1801b7b21e1cfd4123504b42b
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
examples/example_01_barbarian/output.gif ADDED

Git LFS Details

  • SHA256: b9d3fb7269244fe2e60bfe5a1104e4b5ec6c9e322b371573d8adacba017b594d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.33 MB
examples/example_01_barbarian/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "he is dancing as minotaur dancer wearing a fur armor water in a dark cave, john cena, fantasy, barbarian",
+ "neg_prompt": "",
+ "cfg": 15,
+ "cfg_image": 9,
+ "seed": 1,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 20,
+ "format": "gif",
+ "num_frames": 24
+ }
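
Each example directory pairs a params.json like the one above with an output.gif (and optionally an input.png hint). app.py loads these files at startup and feeds them to gr.Examples; the keys map directly onto generate()'s arguments (steps → inference_steps, scheduler → scheduler_type, format → output_format). Below is a minimal sketch of replaying this example outside the UI, assuming app.py's module-level setup has run so that generate() is available.

# Sketch: replay examples/example_01_barbarian/params.json through the
# generate() function defined in app.py above.
import json
from PIL import Image

with open('examples/example_01_barbarian/params.json') as f:
    p = json.load(f)

hint = None
# example_01 also ships an input.png that is used as the hint image
try:
    hint = Image.open('examples/example_01_barbarian/input.png')
except FileNotFoundError:
    pass

video_b64, last_frame_b64, first_frame_b64 = generate(
    prompt = p['prompt'],
    neg_prompt = p['neg_prompt'],
    hint_image = hint,
    inference_steps = p['steps'],
    cfg = p['cfg'],
    cfg_image = p['cfg_image'],
    seed = p['seed'],
    fps = p['fps'],
    num_frames = p['num_frames'],
    height = p['height'],
    width = p['width'],
    scheduler_type = p['scheduler'],
    output_format = p['format']
)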
examples/example_02_zombies/output.gif ADDED

Git LFS Details

  • SHA256: f31690b537f0f45dda16c16281a7bfb730e2f431825fe0b5f5db6f0b9626b388
  • Pointer size: 132 Bytes
  • Size of remote file: 4.39 MB
examples/example_02_zombies/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "Group of scary zombies dancing. Halloween concept.",
+ "neg_prompt": "monochrome",
+ "cfg": 15,
+ "cfg_image": 15,
+ "seed": 0,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 20,
+ "format": "gif",
+ "num_frames": 24
+ }
examples/example_03_astronaut/output.gif ADDED

Git LFS Details

  • SHA256: d849e387f15d15ba192eba4ffd11613fb70c19a23a2df4df47cee2d1cd049695
  • Pointer size: 132 Bytes
  • Size of remote file: 5.36 MB
examples/example_03_astronaut/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "Astronaut performing shuffle dance moves on a Moon surface. Stanley Kubrick.",
+ "neg_prompt": "",
+ "cfg": 15,
+ "cfg_image": 15,
+ "seed": 0,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 20,
+ "format": "gif",
+ "num_frames": 24
+ }
examples/example_04_furry_moster/output.gif ADDED

Git LFS Details

  • SHA256: d5cd05f2a45e4b0b3fa5465d8a8203fad029246071163787cb602e8d630aa70d
  • Pointer size: 132 Bytes
  • Size of remote file: 4.33 MB
examples/example_04_furry_moster/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.",
+ "neg_prompt": "monochrome, saturated",
+ "cfg": 15,
+ "cfg_image": 15,
+ "seed": 0,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 12,
+ "format": "gif",
+ "num_frames": 24
+ }
examples/example_05_people/input.png ADDED

Git LFS Details

  • SHA256: 62ab7da78435c4284915c836986d3d4c72610a28b5ca5d971bccb9a639686b43
  • Pointer size: 131 Bytes
  • Size of remote file: 408 kB
examples/example_05_people/output.gif ADDED

Git LFS Details

  • SHA256: 0be83e354696c3c113d2053ec00b1ca4a7a5d797ffa58310e58973084b025a57
  • Pointer size: 132 Bytes
  • Size of remote file: 4.48 MB
examples/example_05_people/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "Front view close up of group of people dancing at a concert in nightclub.",
+ "neg_prompt": "",
+ "cfg": 15,
+ "cfg_image": 9,
+ "seed": 3,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 20,
+ "format": "gif",
+ "num_frames": 24
+ }
examples/example_06_sophie/output.gif ADDED

Git LFS Details

  • SHA256: 464c468839bdc51e36f8f3c61f1d8d5f823414d207d024059ee9dbcebceda044
  • Pointer size: 132 Bytes
  • Size of remote file: 4.63 MB
examples/example_06_sophie/params.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "prompt": "A girl is dancing by a beautiful lake by sophie anderson and greg rutkowski and alphonse mucha.",
+ "neg_prompt": "",
+ "cfg": 15,
+ "cfg_image": 15,
+ "seed": 1,
+ "steps": 20,
+ "width": 512,
+ "height": 512,
+ "scheduler": "dpm",
+ "fps": 20,
+ "format": "gif",
+ "num_frames": 24
+ }
makeavid_sd/LICENSE ADDED
@@ -0,0 +1,661 @@
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
makeavid_sd/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
makeavid_sd/flax_impl/__init__.py ADDED
File without changes
makeavid_sd/flax_impl/dataset.py ADDED
@@ -0,0 +1,159 @@
1
+
2
+ from typing import List, Dict, Any, Union, Optional
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader, ConcatDataset
6
+ import datasets
7
+ from diffusers import DDPMScheduler
8
+ from functools import partial
9
+ import random
10
+
11
+ import numpy as np
12
+
13
+
14
+ @torch.no_grad()
15
+ def collate_fn(
16
+ batch: List[Dict[str, Any]],
17
+ noise_scheduler: DDPMScheduler,
18
+ num_frames: int,
19
+ hint_spacing: Optional[int] = None,
20
+ as_numpy: bool = True
21
+ ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
22
+ if hint_spacing is None or hint_spacing < 1:
23
+ hint_spacing = num_frames
24
+ if as_numpy:
25
+ dtype = np.float32
26
+ else:
27
+ dtype = torch.float32
28
+ prompts = []
29
+ videos = []
30
+ for s in batch:
31
+ # prompt
32
+ prompts.append(torch.tensor(s['prompt']).to(dtype = torch.float32))
33
+ # frames
34
+ frames = torch.tensor(s['video']).to(dtype = torch.float32)
35
+ max_frames = len(frames)
36
+ assert max_frames >= num_frames
37
+ video_slice = random.randint(0, max_frames - num_frames)
38
+ frames = frames[video_slice:video_slice + num_frames]
39
+ frames = frames.permute(1, 0, 2, 3) # f, c, h, w -> c, f, h, w
40
+ videos.append(frames)
41
+
42
+ encoder_hidden_states = torch.cat(prompts) # b, 77, 768
43
+
44
+ latents = torch.stack(videos) # b, c, f, h, w
45
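+ # 0.18215 is the Stable Diffusion VAE latent scaling factor; the dataset is assumed to already hold raw (unscaled) VAE latents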
+ latents = latents * 0.18215
46
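+ # keep every hint_spacing-th frame as a conditioning hint and repeat it along the frame axis so every frame is paired with a hint latent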
+ hint_latents = latents[:, :, ::hint_spacing, :, :]
47
+ hint_latents = hint_latents.repeat_interleave(hint_spacing, 2)
48
+ #hint_latents = hint_latents[:, :, :num_frames-1, :, :]
49
+ #input_latents = latents[:, :, 1:, :, :]
50
+ input_latents = latents
51
+ noise = torch.randn_like(input_latents)
52
+ bsz = input_latents.shape[0]
53
+ timesteps = torch.randint(
54
+ 0,
55
+ noise_scheduler.config.num_train_timesteps,
56
+ (bsz,),
57
+ dtype = torch.int64
58
+ )
59
+ noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)
60
+ mask = torch.zeros([
61
+ noisy_latents.shape[0],
62
+ 1,
63
+ noisy_latents.shape[2],
64
+ noisy_latents.shape[3],
65
+ noisy_latents.shape[4]
66
+ ])
67
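+ # UNet input: noisy latents, a zero mask channel and the hint latents stacked on the channel axis (4 + 1 + 4 = 9 channels, assuming standard 4-channel SD latents)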
+ latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim = 1)
68
+
69
+ latent_model_input = latent_model_input.to(memory_format = torch.contiguous_format)
70
+ encoder_hidden_states = encoder_hidden_states.to(memory_format = torch.contiguous_format)
71
+ timesteps = timesteps.to(memory_format = torch.contiguous_format)
72
+ noise = noise.to(memory_format = torch.contiguous_format)
73
+
74
+ if as_numpy:
75
+ latent_model_input = latent_model_input.numpy().astype(dtype)
76
+ encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
77
+ timesteps = timesteps.numpy().astype(np.int32)
78
+ noise = noise.numpy().astype(dtype)
79
+ else:
80
+ latent_model_input = latent_model_input.to(dtype = dtype)
81
+ encoder_hidden_states = encoder_hidden_states.to(dtype = dtype)
82
+ noise = noise.to(dtype = dtype)
83
+
84
+ return {
85
+ 'latent_model_input': latent_model_input,
86
+ 'encoder_hidden_states': encoder_hidden_states,
87
+ 'timesteps': timesteps,
88
+ 'noise': noise
89
+ }
90
+
91
+ def worker_init_fn(worker_id: int):
92
+ wseed = torch.initial_seed() % 4294967294 # keep the worker seed below 2**32 so numpy and random accept it
93
+ random.seed(wseed)
94
+ np.random.seed(wseed)
95
+
96
+
97
+ def load_dataset(
98
+ dataset_path: str,
99
+ model_path: str,
100
+ cache_dir: Optional[str] = None,
101
+ batch_size: int = 1,
102
+ num_frames: int = 24,
103
+ hint_spacing: Optional[int] = None,
104
+ num_workers: int = 0,
105
+ shuffle: bool = False,
106
+ as_numpy: bool = True,
107
+ pin_memory: bool = False,
108
+ pin_memory_device: str = ''
109
+ ) -> DataLoader:
110
+ noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
111
+ model_path,
112
+ subfolder = 'scheduler'
113
+ )
114
+ dataset = datasets.load_dataset(
115
+ dataset_path,
116
+ streaming = False,
117
+ cache_dir = cache_dir
118
+ )
119
+ merged_dataset = ConcatDataset([ dataset[s] for s in dataset ])
120
+ dataloader = DataLoader(
121
+ merged_dataset,
122
+ batch_size = batch_size,
123
+ num_workers = num_workers,
124
+ persistent_workers = num_workers > 0,
125
+ drop_last = True,
126
+ shuffle = shuffle,
127
+ worker_init_fn = worker_init_fn,
128
+ collate_fn = partial(collate_fn,
129
+ noise_scheduler = noise_scheduler,
130
+ num_frames = num_frames,
131
+ hint_spacing = hint_spacing,
132
+ as_numpy = as_numpy
133
+ ),
134
+ pin_memory = pin_memory,
135
+ pin_memory_device = pin_memory_device
136
+ )
137
+ return dataloader
138
+
139
+
140
+ def validate_dataset(
141
+ dataset_path: str
142
+ ) -> List[str]:
143
+ import os
144
+ import json
145
+ data_path = os.path.join(dataset_path, 'data')
146
+ meta = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'metadata')))
147
+ prompts = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'prompts')))
148
+ videos = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'videos')))
149
+ ok = meta.intersection(prompts).intersection(videos)
150
+ all_of_em = meta.union(prompts).union(videos)
151
+ not_ok = []
152
+ for a in all_of_em:
153
+ if a not in ok:
154
+ not_ok.append(a)
155
+ ok = list(ok)
156
+ ok.sort()
157
+ with open(os.path.join(data_path, 'id_list.json'), 'w') as f:
158
+ json.dump(ok, f)
+ return not_ok
159
+
makeavid_sd/flax_impl/flax_attention_pseudo3d.py ADDED
@@ -0,0 +1,212 @@
1
+
2
+ from typing import Optional
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ import einops
9
+
10
+ #from flax_memory_efficient_attention import jax_memory_efficient_attention
11
+ #from flax_attention import FlaxAttention
12
+ from diffusers.models.attention_flax import FlaxAttention
13
+
14
+
15
+ class TransformerPseudo3DModel(nn.Module):
16
+ in_channels: int
17
+ num_attention_heads: int
18
+ attention_head_dim: int
19
+ num_layers: int = 1
20
+ use_memory_efficient_attention: bool = False
21
+ dtype: jnp.dtype = jnp.float32
22
+
23
+ def setup(self) -> None:
24
+ inner_dim = self.num_attention_heads * self.attention_head_dim
25
+ self.norm = nn.GroupNorm(
26
+ num_groups = 32,
27
+ epsilon = 1e-5
28
+ )
29
+ self.proj_in = nn.Conv(
30
+ inner_dim,
31
+ kernel_size = (1, 1),
32
+ strides = (1, 1),
33
+ padding = 'VALID',
34
+ dtype = self.dtype
35
+ )
36
+ transformer_blocks = []
37
+ #CheckpointTransformerBlock = nn.checkpoint(
38
+ # BasicTransformerBlockPseudo3D,
39
+ # static_argnums = (2,3,4)
40
+ # #prevent_cse = False
41
+ #)
42
+ CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
43
+ for _ in range(self.num_layers):
44
+ transformer_blocks.append(CheckpointTransformerBlock(
45
+ dim = inner_dim,
46
+ num_attention_heads = self.num_attention_heads,
47
+ attention_head_dim = self.attention_head_dim,
48
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
49
+ dtype = self.dtype
50
+ ))
51
+ self.transformer_blocks = transformer_blocks
52
+ self.proj_out = nn.Conv(
53
+ inner_dim,
54
+ kernel_size = (1, 1),
55
+ strides = (1, 1),
56
+ padding = 'VALID',
57
+ dtype = self.dtype
58
+ )
59
+
60
+ def __call__(self,
61
+ hidden_states: jax.Array,
62
+ encoder_hidden_states: Optional[jax.Array] = None
63
+ ) -> jax.Array:
64
+ is_video = hidden_states.ndim == 5
65
+ f: Optional[int] = None
66
+ if is_video:
67
+ # jax is channels last
68
+ # b,c,f,h,w WRONG
69
+ # b,f,h,w,c CORRECT
70
+ # b, c, f, h, w = hidden_states.shape
71
+ #hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
72
+ b, f, h, w, c = hidden_states.shape
73
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
74
+
75
+ batch, height, width, channels = hidden_states.shape
76
+ residual = hidden_states
77
+ hidden_states = self.norm(hidden_states)
78
+ hidden_states = self.proj_in(hidden_states)
79
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
80
+ for block in self.transformer_blocks:
81
+ hidden_states = block(
82
+ hidden_states,
83
+ encoder_hidden_states,
84
+ f,
85
+ height,
86
+ width
87
+ )
88
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
89
+ hidden_states = self.proj_out(hidden_states)
90
+ hidden_states = hidden_states + residual
91
+ if is_video:
92
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
93
+ return hidden_states
94
+
95
+
96
+ class BasicTransformerBlockPseudo3D(nn.Module):
97
+ dim: int
98
+ num_attention_heads: int
99
+ attention_head_dim: int
100
+ use_memory_efficient_attention: bool = False
101
+ dtype: jnp.dtype = jnp.float32
102
+
103
+ def setup(self) -> None:
104
+ self.attn1 = FlaxAttention(
105
+ query_dim = self.dim,
106
+ heads = self.num_attention_heads,
107
+ dim_head = self.attention_head_dim,
108
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
109
+ dtype = self.dtype
110
+ )
111
+ self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
112
+ self.attn2 = FlaxAttention(
113
+ query_dim = self.dim,
114
+ heads = self.num_attention_heads,
115
+ dim_head = self.attention_head_dim,
116
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
117
+ dtype = self.dtype
118
+ )
119
+ self.attn_temporal = FlaxAttention(
120
+ query_dim = self.dim,
121
+ heads = self.num_attention_heads,
122
+ dim_head = self.attention_head_dim,
123
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
124
+ dtype = self.dtype
125
+ )
126
+ self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
127
+ self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
128
+ self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
129
+ self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
130
+
131
+ def __call__(self,
132
+ hidden_states: jax.Array,
133
+ context: Optional[jax.Array] = None,
134
+ frames_length: Optional[int] = None,
135
+ height: Optional[int] = None,
136
+ width: Optional[int] = None
137
+ ) -> jax.Array:
138
+ if context is not None and frames_length is not None:
139
+ context = context.repeat(frames_length, axis = 0)
140
+ # self attention
141
+ norm_hidden_states = self.norm1(hidden_states)
142
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
143
+ # cross attention
144
+ norm_hidden_states = self.norm2(hidden_states)
145
+ hidden_states = self.attn2(
146
+ norm_hidden_states,
147
+ context = context
148
+ ) + hidden_states
149
+ # temporal attention
150
+ if frames_length is not None:
151
+ #bf, hw, c = hidden_states.shape
152
+ # (b f) (h w) c -> b f (h w) c
153
+ #hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
154
+ #b, f, hw, c = hidden_states.shape
155
+ # b f (h w) c -> b (h w) f c
156
+ #hidden_states = hidden_states.transpose(0, 2, 1, 3)
157
+ # b (h w) f c -> (b h w) f c
158
+ #hidden_states = hidden_states.reshape(b * hw, frames_length, c)
159
+ hidden_states = einops.rearrange(
160
+ hidden_states,
161
+ '(b f) (h w) c -> (b h w) f c',
162
+ f = frames_length,
163
+ h = height,
164
+ w = width
165
+ )
166
+ norm_hidden_states = self.norm_temporal(hidden_states)
167
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
168
+ # (b h w) f c -> b (h w) f c
169
+ #hidden_states = hidden_states.reshape(b, hw, f, c)
170
+ # b (h w) f c -> b f (h w) c
171
+ #hidden_states = hidden_states.transpose(0, 2, 1, 3)
172
+ # b f h w c -> (b f) (h w) c
173
+ #hidden_states = hidden_states.reshape(bf, hw, c)
174
+ hidden_states = einops.rearrange(
175
+ hidden_states,
176
+ '(b h w) f c -> (b f) (h w) c',
177
+ f = frames_length,
178
+ h = height,
179
+ w = width
180
+ )
181
+ norm_hidden_states = self.norm3(hidden_states)
182
+ hidden_states = self.ff(norm_hidden_states) + hidden_states
183
+ return hidden_states
184
+
185
+
186
+ class FeedForward(nn.Module):
187
+ dim: int
188
+ dtype: jnp.dtype = jnp.float32
189
+
190
+ def setup(self) -> None:
191
+ self.net_0 = GEGLU(self.dim, self.dtype)
192
+ self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
193
+
194
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
195
+ hidden_states = self.net_0(hidden_states)
196
+ hidden_states = self.net_2(hidden_states)
197
+ return hidden_states
198
+
199
+
200
+ class GEGLU(nn.Module):
201
+ dim: int
202
+ dtype: jnp.dtype = jnp.float32
203
+
204
+ def setup(self) -> None:
205
+ inner_dim = self.dim * 4
206
+ self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
207
+
208
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
209
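+ # gated GELU: project to twice the inner width, split in half, and gate the linear half with gelu of the other half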
+ hidden_states = self.proj(hidden_states)
210
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
211
+ return hidden_linear * nn.gelu(hidden_gelu)
212
+
makeavid_sd/flax_impl/flax_embeddings.py ADDED
@@ -0,0 +1,62 @@
1
+
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import flax.linen as nn
5
+
6
+
7
+ def get_sinusoidal_embeddings(
8
+ timesteps: jax.Array,
9
+ embedding_dim: int,
10
+ freq_shift: float = 1,
11
+ min_timescale: float = 1,
12
+ max_timescale: float = 1.0e4,
13
+ flip_sin_to_cos: bool = False,
14
+ scale: float = 1.0,
15
+ dtype: jnp.dtype = jnp.float32
16
+ ) -> jax.Array:
17
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
18
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
19
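+ # standard sinusoidal timestep embedding: half of the dimensions carry sin, the other half cos, at geometrically spaced frequencies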
+ num_timescales = float(embedding_dim // 2)
20
+ log_timescale_increment = jnp.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
21
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype = dtype) * -log_timescale_increment)
22
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
23
+
24
+ # scale embeddings
25
+ scaled_time = scale * emb
26
+
27
+ if flip_sin_to_cos:
28
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis = 1)
29
+ else:
30
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = 1)
31
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
32
+ return signal
33
+
34
+
35
+ class TimestepEmbedding(nn.Module):
36
+ time_embed_dim: int = 32
37
+ dtype: jnp.dtype = jnp.float32
38
+
39
+ @nn.compact
40
+ def __call__(self, temb: jax.Array) -> jax.Array:
41
+ temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_1")(temb)
42
+ temb = nn.silu(temb)
43
+ temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_2")(temb)
44
+ return temb
45
+
46
+
47
+ class Timesteps(nn.Module):
48
+ dim: int = 32
49
+ flip_sin_to_cos: bool = False
50
+ freq_shift: float = 1
51
+ dtype: jnp.dtype = jnp.float32
52
+
53
+ @nn.compact
54
+ def __call__(self, timesteps: jax.Array) -> jax.Array:
55
+ return get_sinusoidal_embeddings(
56
+ timesteps = timesteps,
57
+ embedding_dim = self.dim,
58
+ flip_sin_to_cos = self.flip_sin_to_cos,
59
+ freq_shift = self.freq_shift,
60
+ dtype = self.dtype
61
+ )
62
+
makeavid_sd/flax_impl/flax_resnet_pseudo3d.py ADDED
@@ -0,0 +1,175 @@
1
+
2
+ from typing import Optional, Union, Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ import einops
9
+
10
+
11
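+ # factorised "pseudo 3D" convolution: a 2D convolution applied per frame followed by a 1D convolution along the frame axis (Make-A-Video style spatio-temporal factorisation)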
+ class ConvPseudo3D(nn.Module):
12
+ features: int
13
+ kernel_size: Sequence[int]
14
+ strides: Union[None, int, Sequence[int]] = 1
15
+ padding: nn.linear.PaddingLike = 'SAME'
16
+ dtype: jnp.dtype = jnp.float32
17
+
18
+ def setup(self) -> None:
19
+ self.spatial_conv = nn.Conv(
20
+ features = self.features,
21
+ kernel_size = self.kernel_size,
22
+ strides = self.strides,
23
+ padding = self.padding,
24
+ dtype = self.dtype
25
+ )
26
+ self.temporal_conv = nn.Conv(
27
+ features = self.features,
28
+ kernel_size = (3,),
29
+ padding = 'SAME',
30
+ dtype = self.dtype,
31
+ bias_init = nn.initializers.zeros_init()
32
+ # TODO dirac delta (identity) initialization impl
33
+ # kernel_init = torch.nn.init.dirac_ <-> jax/lax
34
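+ # a possible identity init (untested sketch): put a 1 at the centre tap for each matching channel pair, e.g.
+ # kernel_init = lambda key, shape, dtype = jnp.float32: jnp.zeros(shape, dtype).at[shape[0] // 2].set(jnp.eye(shape[1], shape[2], dtype = dtype))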
+ )
35
+
36
+ def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
37
+ is_video = x.ndim == 5
38
+ convolve_across_time = convolve_across_time and is_video
39
+ if is_video:
40
+ b, f, h, w, c = x.shape
41
+ x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
42
+ x = self.spatial_conv(x)
43
+ if is_video:
44
+ x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = b)
45
+ b, f, h, w, c = x.shape
46
+ if not convolve_across_time:
47
+ return x
48
+ if is_video:
49
+ x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
50
+ x = self.temporal_conv(x)
51
+ x = einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
52
+ return x
53
+
54
+
55
+ class UpsamplePseudo3D(nn.Module):
56
+ out_channels: int
57
+ dtype: jnp.dtype = jnp.float32
58
+
59
+ def setup(self) -> None:
60
+ self.conv = ConvPseudo3D(
61
+ features = self.out_channels,
62
+ kernel_size = (3, 3),
63
+ strides = (1, 1),
64
+ padding = ((1, 1), (1, 1)),
65
+ dtype = self.dtype
66
+ )
67
+
68
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
69
+ is_video = hidden_states.ndim == 5
70
+ if is_video:
71
+ b, *_ = hidden_states.shape
72
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
73
+ batch, h, w, c = hidden_states.shape
74
+ hidden_states = jax.image.resize(
75
+ image = hidden_states,
76
+ shape = (batch, h * 2, w * 2, c),
77
+ method = 'nearest'
78
+ )
79
+ if is_video:
80
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
81
+ hidden_states = self.conv(hidden_states)
82
+ return hidden_states
83
+
84
+
85
+ class DownsamplePseudo3D(nn.Module):
86
+ out_channels: int
87
+ dtype: jnp.dtype = jnp.float32
88
+
89
+ def setup(self) -> None:
90
+ self.conv = ConvPseudo3D(
91
+ features = self.out_channels,
92
+ kernel_size = (3, 3),
93
+ strides = (2, 2),
94
+ padding = ((1, 1), (1, 1)),
95
+ dtype = self.dtype
96
+ )
97
+
98
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
99
+ hidden_states = self.conv(hidden_states)
100
+ return hidden_states
101
+
102
+
103
+ class ResnetBlockPseudo3D(nn.Module):
104
+ in_channels: int
105
+ out_channels: Optional[int] = None
106
+ use_nin_shortcut: Optional[bool] = None
107
+ dtype: jnp.dtype = jnp.float32
108
+
109
+ def setup(self) -> None:
110
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
111
+ self.norm1 = nn.GroupNorm(
112
+ num_groups = 32,
113
+ epsilon = 1e-5
114
+ )
115
+ self.conv1 = ConvPseudo3D(
116
+ features = out_channels,
117
+ kernel_size = (3, 3),
118
+ strides = (1, 1),
119
+ padding = ((1, 1), (1, 1)),
120
+ dtype = self.dtype
121
+ )
122
+ self.time_emb_proj = nn.Dense(
123
+ out_channels,
124
+ dtype = self.dtype
125
+ )
126
+ self.norm2 = nn.GroupNorm(
127
+ num_groups = 32,
128
+ epsilon = 1e-5
129
+ )
130
+ self.conv2 = ConvPseudo3D(
131
+ features = out_channels,
132
+ kernel_size = (3, 3),
133
+ strides = (1, 1),
134
+ padding = ((1, 1), (1, 1)),
135
+ dtype = self.dtype
136
+ )
137
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
138
+ self.conv_shortcut = None
139
+ if use_nin_shortcut:
140
+ self.conv_shortcut = ConvPseudo3D(
141
+ features = out_channels, # use the resolved channel count, since self.out_channels may be None
142
+ kernel_size = (1, 1),
143
+ strides = (1, 1),
144
+ padding = 'VALID',
145
+ dtype = self.dtype
146
+ )
147
+
148
+ def __call__(self,
149
+ hidden_states: jax.Array,
150
+ temb: jax.Array
151
+ ) -> jax.Array:
152
+ is_video = hidden_states.ndim == 5
153
+ residual = hidden_states
154
+ hidden_states = self.norm1(hidden_states)
155
+ hidden_states = nn.silu(hidden_states)
156
+ hidden_states = self.conv1(hidden_states)
157
+ temb = nn.silu(temb)
158
+ temb = self.time_emb_proj(temb)
159
+ temb = jnp.expand_dims(temb, 1)
160
+ temb = jnp.expand_dims(temb, 1)
161
+ if is_video:
162
+ b, f, *_ = hidden_states.shape
163
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
164
+ hidden_states = hidden_states + temb.repeat(f, 0)
165
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
166
+ else:
167
+ hidden_states = hidden_states + temb
168
+ hidden_states = self.norm2(hidden_states)
169
+ hidden_states = nn.silu(hidden_states)
170
+ hidden_states = self.conv2(hidden_states)
171
+ if self.conv_shortcut is not None:
172
+ residual = self.conv_shortcut(residual)
173
+ hidden_states = hidden_states + residual
174
+ return hidden_states
175
+
makeavid_sd/flax_impl/flax_trainer.py ADDED
@@ -0,0 +1,608 @@
1
+
2
+ from typing import Any, Optional, Union, Tuple, Dict, List
3
+
4
+ import os
5
+ import random
6
+ import math
7
+ import time
8
+ import numpy as np
9
+ from tqdm.auto import tqdm, trange
10
+
11
+ import torch
12
+ from torch.utils.data import DataLoader
13
+
14
+ import jax
15
+ import jax.numpy as jnp
16
+ import optax
17
+ from flax import jax_utils, traverse_util
18
+ from flax.core.frozen_dict import FrozenDict
19
+ from flax.training.train_state import TrainState
20
+ from flax.training.common_utils import shard
21
+
22
+ # convert 2D -> 3D
23
+ from diffusers import FlaxUNet2DConditionModel
24
+
25
+ # inference test, run on these on cpu
26
+ from diffusers import AutoencoderKL
27
+ from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+ from PIL import Image
30
+
31
+
32
+ from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
33
+
34
+
35
+ def seed_all(seed: int) -> jax.random.PRNGKeyArray:
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ rng = jax.random.PRNGKey(seed)
40
+ return rng
41
+
42
+ def count_params(
43
+ params: Union[Dict[str, Any],
44
+ FrozenDict[str, Any]],
45
+ filter_name: Optional[str] = None
46
+ ) -> int:
47
+ p: Dict[Tuple[str], jax.Array] = traverse_util.flatten_dict(params)
48
+ cc = 0
49
+ for k in p:
50
+ if filter_name is not None:
51
+ if filter_name in ' '.join(k):
52
+ cc += len(p[k].flatten())
53
+ else:
54
+ cc += len(p[k].flatten())
55
+ return cc
56
+
57
+ def map_2d_to_pseudo3d(
58
+ params2d: Dict[str, Any],
59
+ params3d: Dict[str, Any],
60
+ verbose: bool = True
61
+ ) -> Dict[str, Any]:
62
+ params2d = traverse_util.flatten_dict(params2d)
63
+ params3d = traverse_util.flatten_dict(params3d)
64
+ new_params = dict()
65
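+ # spatial conv weights are taken from the 2D UNet (the 'spatial_conv' segment is dropped from the key to find the 2D counterpart); keys missing from the 2D model, e.g. temporal layers, keep their fresh 3D initialisation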
+ for k in params3d:
66
+ if 'spatial_conv' in k:
67
+ k2d = list(k)
68
+ k2d.remove('spatial_conv')
69
+ k2d = tuple(k2d)
70
+ if verbose:
71
+ tqdm.write(f'Spatial: {k} <- {k2d}')
72
+ p = params2d[k2d]
73
+ elif k not in params2d:
74
+ if verbose:
75
+ tqdm.write(f'Missing: {k}')
76
+ p = params3d[k]
77
+ else:
78
+ p = params2d[k]
79
+ assert p.shape == params3d[k].shape, f'shape mismatch: {k}: {p.shape} != {params3d[k].shape}'
80
+ new_params[k] = p
81
+ new_params = traverse_util.unflatten_dict(new_params)
82
+ return new_params
83
+
84
+
85
+ class FlaxTrainerUNetPseudo3D:
86
+ def __init__(self,
87
+ model_path: str,
88
+ from_pt: bool = True,
89
+ convert2d: bool = False,
90
+ sample_size: Tuple[int, int] = (64, 64),
91
+ seed: int = 0,
92
+ dtype: str = 'float32',
93
+ param_dtype: str = 'float32',
94
+ only_temporal: bool = True,
95
+ use_memory_efficient_attention = False,
96
+ verbose: bool = True
97
+ ) -> None:
98
+ self.verbose = verbose
99
+ self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
100
+ self._use_wandb: bool = False
101
+ self._tracker_meta: Dict[str, Union[float, int]] = {
102
+ 't00': 0.0,
103
+ 't0': 0.0,
104
+ 'step0': 0
105
+ }
106
+
107
+ self.log('Init JAX')
108
+ self.num_devices = jax.device_count()
109
+ self.log(f'Device count: {self.num_devices}')
110
+
111
+ self.seed = seed
112
+ self.rng: jax.random.PRNGKeyArray = seed_all(self.seed)
113
+
114
+ self.sample_size = sample_size
115
+ if dtype == 'float32':
116
+ self.dtype = jnp.float32
117
+ elif dtype == 'bfloat16':
118
+ self.dtype = jnp.bfloat16
119
+ elif dtype == 'float16':
120
+ self.dtype = jnp.float16
121
+ else:
122
+ raise ValueError(f'unknown type: {dtype}')
123
+ self.dtype_str: str = dtype
124
+ if param_dtype not in ['float32', 'bfloat16', 'float16']:
125
+ raise ValueError(f'unknown parameter type: {param_dtype}')
126
+ self.param_dtype = param_dtype
127
+ self._load_models(
128
+ model_path = model_path,
129
+ convert2d = convert2d,
130
+ from_pt = from_pt,
131
+ use_memory_efficient_attention = use_memory_efficient_attention
132
+ )
133
+ self._mark_parameters(only_temporal = only_temporal)
134
+ # optionally for validation + sampling
135
+ self.tokenizer: Optional[CLIPTokenizer] = None
136
+ self.text_encoder: Optional[CLIPTextModel] = None
137
+ self.vae: Optional[AutoencoderKL] = None
138
+ self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None
139
+
140
+ def log(self, message: Any) -> None:
141
+ if self.verbose and jax.process_index() == 0:
142
+ tqdm.write(str(message))
143
+
144
+ def log_metrics(self, metrics: dict, step: int, epoch: int) -> None:
145
+ if jax.process_index() > 0 or (not self.verbose and self.tracker is None):
146
+ return
147
+ now = time.monotonic()
148
+ log_data = {
149
+ 'train/step': step,
150
+ 'train/epoch': epoch,
151
+ 'train/steps_per_sec': (step - self._tracker_meta['step0']) / (now - self._tracker_meta['t0']),
152
+ **{ f'train/{k}': v for k, v in metrics.items() }
153
+ }
154
+ self._tracker_meta['t0'] = now
155
+ self._tracker_meta['step0'] = step
156
+ self.log(log_data)
157
+ if self.tracker is not None:
158
+ self.tracker.log(log_data, step = step)
159
+
160
+
161
+ def enable_wandb(self, enable: bool = True) -> None:
162
+ self._use_wandb = enable
163
+
164
+ def _setup_wandb(self, config: Dict[str, Any] = dict()) -> None:
165
+ import wandb
166
+ import wandb.sdk
167
+ self.tracker: wandb.sdk.wandb_run.Run = wandb.init(
168
+ config = config,
169
+ settings = wandb.sdk.Settings(
170
+ username = 'anon',
171
+ host = 'anon',
172
+ email = 'anon',
173
+ root_dir = 'anon',
174
+ _executable = 'anon',
175
+ _disable_stats = True,
176
+ _disable_meta = True,
177
+ disable_code = True,
178
+ disable_git = True
179
+ ) # pls don't log sensitive data like system user names. also, fuck you for even trying.
180
+ )
181
+
182
+ def _init_tracker_meta(self) -> None:
183
+ now = time.monotonic()
184
+ self._tracker_meta = {
185
+ 't00': now,
186
+ 't0': now,
187
+ 'step0': 0
188
+ }
189
+
190
+ def _load_models(self,
191
+ model_path: str,
192
+ convert2d: bool,
193
+ from_pt: bool,
194
+ use_memory_efficient_attention: bool
195
+ ) -> None:
196
+ self.log(f'Load pretrained from {model_path}')
197
+ if convert2d:
198
+ self.log(' Convert 2D model to Pseudo3D')
199
+ self.log(' Initiate Pseudo3D model')
200
+ config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
201
+ model = UNetPseudo3DConditionModel.from_config(
202
+ config,
203
+ sample_size = self.sample_size,
204
+ dtype = self.dtype,
205
+ param_dtype = self.param_dtype,
206
+ use_memory_efficient_attention = use_memory_efficient_attention
207
+ )
208
+ params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
209
+ self.log(' Load 2D model')
210
+ model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
211
+ model_path,
212
+ subfolder = 'unet',
213
+ dtype = self.dtype,
214
+ from_pt = from_pt
215
+ )
216
+ self.log(' Map 2D -> 3D')
217
+ params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
218
+ del params2d
219
+ del model2d
220
+ del config
221
+ else:
222
+ model, params = UNetPseudo3DConditionModel.from_pretrained(
223
+ model_path,
224
+ subfolder = 'unet',
225
+ from_pt = from_pt,
226
+ sample_size = self.sample_size,
227
+ dtype = self.dtype,
228
+ param_dtype = self.param_dtype,
229
+ use_memory_efficient_attention = use_memory_efficient_attention
230
+ )
231
+ self.log(f'Cast parameters to {model.param_dtype}')
232
+ if model.param_dtype == 'float32':
233
+ params = model.to_fp32(params)
234
+ elif model.param_dtype == 'float16':
235
+ params = model.to_fp16(params)
236
+ elif model.param_dtype == 'bfloat16':
237
+ params = model.to_bf16(params)
238
+ self.pretrained_model = model_path
239
+ self.model: UNetPseudo3DConditionModel = model
240
+ self.params: FrozenDict[str, Any] = FrozenDict(params)
241
+
242
+ def _mark_parameters(self, only_temporal: bool) -> None:
243
+ self.log('Mark training parameters')
244
+ if only_temporal:
245
+ self.log('Only training temporal layers')
246
+ if only_temporal:
247
+ param_partitions = traverse_util.path_aware_map(
248
+ lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params
249
+ )
250
+ else:
251
+ param_partitions = traverse_util.path_aware_map(
252
+ lambda *_: 'trainable', self.params
253
+ )
254
+ self.only_temporal = only_temporal
255
+ self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions)
256
+ self.log(f'Total parameters: {count_params(self.params)}')
257
+ self.log(f'Temporal parameters: {count_params(self.params, "temporal")}')
258
+
259
+ def _load_inference_models(self) -> None:
260
+ assert jax.process_index() == 0, 'not main process'
261
+ if self.text_encoder is None:
262
+ self.log('Load text encoder')
263
+ self.text_encoder = CLIPTextModel.from_pretrained(
264
+ self.pretrained_model,
265
+ subfolder = 'text_encoder'
266
+ )
267
+ if self.tokenizer is None:
268
+ self.log('Load tokenizer')
269
+ self.tokenizer = CLIPTokenizer.from_pretrained(
270
+ self.pretrained_model,
271
+ subfolder = 'tokenizer'
272
+ )
273
+ if self.vae is None:
274
+ self.log('Load vae')
275
+ self.vae = AutoencoderKL.from_pretrained(
276
+ self.pretrained_model,
277
+ subfolder = 'vae'
278
+ )
279
+ if self.ddim is None:
280
+ self.log('Load ddim scheduler')
281
+ # tuple(scheduler , scheduler state)
282
+ self.ddim = FlaxDDIMScheduler.from_pretrained(
283
+ self.pretrained_model,
284
+ subfolder = 'scheduler',
285
+ from_pt = True
286
+ )
287
+
288
+ def _unload_inference_models(self) -> None:
289
+ self.text_encoder = None
290
+ self.tokenizer = None
291
+ self.vae = None
292
+ self.ddim = None
293
+
294
+ def sample(self,
295
+ params: Union[Dict[str, Any], FrozenDict[str, Any]],
296
+ prompt: str,
297
+ image_path: str,
298
+ num_frames: int,
299
+ replicate_params: bool = True,
300
+ neg_prompt: str = '',
301
+ steps: int = 50,
302
+ cfg: float = 9.0,
303
+ unload_after_usage: bool = False
304
+ ) -> List[Image.Image]:
305
+ assert jax.process_index() == 0, 'not main process'
306
+ self.log('Sample')
307
+ self._load_inference_models()
308
+ with torch.no_grad():
309
+ tokens = self.tokenizer(
310
+ [ prompt ],
311
+ truncation = True,
312
+ return_overflowing_tokens = False,
313
+ padding = 'max_length',
314
+ return_tensors = 'pt'
315
+ ).input_ids
316
+ neg_tokens = self.tokenizer(
317
+ [ neg_prompt ],
318
+ truncation = True,
319
+ return_overflowing_tokens = False,
320
+ padding = 'max_length',
321
+ return_tensors = 'pt'
322
+ ).input_ids
323
+ encoded_prompt = self.text_encoder(input_ids = tokens).last_hidden_state
324
+ encoded_neg_prompt = self.text_encoder(input_ids = neg_tokens).last_hidden_state
325
+ hint_latent = torch.tensor(np.asarray(Image.open(image_path))).permute(2,0,1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
326
+ hint_latent = self.vae.encode(hint_latent).latent_dist.mean * self.vae.config.scaling_factor #0.18215 # deterministic
327
+ hint_latent = hint_latent.unsqueeze(2).repeat_interleave(num_frames, 2)
328
+ mask = torch.zeros_like(hint_latent[:,0:1,:,:,:]) # zero mask, e.g. skip masking for now
329
+ init_latent = torch.randn_like(hint_latent)
330
+ # move to devices
331
+ encoded_prompt = jnp.array(encoded_prompt.numpy())
332
+ encoded_neg_prompt = jnp.array(encoded_neg_prompt.numpy())
333
+ hint_latent = jnp.array(hint_latent.numpy())
334
+ mask = jnp.array(mask.numpy())
335
+ init_latent = init_latent.repeat(jax.device_count(), 1, 1, 1, 1)
336
+ init_latent = jnp.array(init_latent.numpy())
337
+ self.ddim = (self.ddim[0], self.ddim[0].set_timesteps(self.ddim[1], steps))
338
+ timesteps = self.ddim[1].timesteps
339
+ if replicate_params:
340
+ params = jax_utils.replicate(params)
341
+ ddim_state = jax_utils.replicate(self.ddim[1])
342
+ encoded_prompt = jax_utils.replicate(encoded_prompt)
343
+ encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
344
+ hint_latent = jax_utils.replicate(hint_latent)
345
+ mask = jax_utils.replicate(mask)
346
+ # sampling fun
347
+ def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
348
+ latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
349
+ pred = self.model.apply(
350
+ { 'params': params },
351
+ latent_model_input,
352
+ t,
353
+ encoded_prompt
354
+ ).sample
355
+ if cfg != 1.0:
356
+ neg_pred = self.model.apply(
357
+ { 'params': params },
358
+ latent_model_input,
359
+ t,
360
+ encoded_neg_prompt
361
+ ).sample
362
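+ # classifier-free guidance: push the conditional prediction away from the unconditional one by the guidance scale cfg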
+ pred = neg_pred + cfg * (pred - neg_pred)
363
+ # TODO check if noise is added at the right dimension
364
+ init_latent, ddim_state = self.ddim[0].step(ddim_state, pred, t, init_latent).to_tuple()
365
+ return init_latent, ddim_state
366
+ p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
367
+ pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
368
+ init_latent = shard(init_latent)
369
+ for i in pbar_sample:
370
+ t = timesteps[i].repeat(self.num_devices)
371
+ t = shard(t)
372
+ init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)
373
+ # decode
374
+ self.log('Decode')
375
+ init_latent = torch.tensor(np.array(init_latent))
376
+ init_latent = init_latent / self.vae.config.scaling_factor
377
+ # d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
378
+ init_latent = init_latent.permute(0, 1, 3, 2, 4, 5)
379
+ images = []
380
+ pbar_decode = trange(len(init_latent), desc = 'Decode', dynamic_ncols = True)
381
+ for sample in init_latent:
382
+ ims = self.vae.decode(sample.squeeze()).sample
383
+ ims = ims.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0,2,3,1).numpy()
384
+ ims = [ Image.fromarray(x) for x in ims ]
385
+ for im in ims:
386
+ images.append(im)
387
+ pbar_decode.update(1)
388
+ if unload_after_usage:
389
+ self._unload_inference_models()
390
+ return images
391
+
392
+ def get_params_from_state(self, state: TrainState) -> FrozenDict[str, Any]:
393
+ return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)))
394
+
395
+ def train(self,
396
+ dataloader: DataLoader,
397
+ lr: float,
398
+ num_frames: int,
399
+ log_every_step: int = 10,
400
+ save_every_epoch: int = 1,
401
+ sample_every_epoch: int = 1,
402
+ output_dir: str = 'output',
403
+ warmup: float = 0,
404
+ decay: float = 0,
405
+ epochs: int = 10,
406
+ weight_decay: float = 1e-2
407
+ ) -> None:
408
+ eps = 1e-8
409
+ total_steps = len(dataloader) * epochs
410
+ warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
411
+ decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
412
+ self.log(f'Total steps: {total_steps}')
413
+ self.log(f'Warmup steps: {warmup_steps}')
414
+ self.log(f'Decay steps: {decay_steps - warmup_steps}')
415
+ if warmup > 0 or decay > 0:
416
+ if not decay > 0:
417
+ # only warmup, keep peak lr until end
418
+ self.log('Warmup schedule')
419
+ end_lr = lr
420
+ else:
421
+ # warmup + annealing to end lr
422
+ self.log('Warmup + cosine annealing schedule')
423
+ end_lr = eps
424
+ lr_schedule = optax.warmup_cosine_decay_schedule(
425
+ init_value = 0.0,
426
+ peak_value = lr,
427
+ warmup_steps = warmup_steps,
428
+ decay_steps = decay_steps,
429
+ end_value = end_lr
430
+ )
431
+ else:
432
+ # no warmup or decay -> constant lr
433
+ self.log('constant schedule')
434
+ lr_schedule = optax.constant_schedule(value = lr)
435
+ adamw = optax.adamw(
436
+ learning_rate = lr_schedule,
437
+ b1 = 0.9,
438
+ b2 = 0.999,
439
+ eps = eps,
440
+ weight_decay = weight_decay #0.01 # 0.0001
441
+ )
442
+ optim = optax.chain(
443
+ optax.clip_by_global_norm(max_norm = 1.0),
444
+ adamw
445
+ )
446
+ partition_optimizers = {
447
+ 'trainable': optim,
448
+ 'frozen': optax.set_to_zero()
449
+ }
450
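+ # adamw is only applied to parameters labelled 'trainable'; 'frozen' parameters get set_to_zero(), so with only_temporal enabled just the temporal layers are updated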
+ tx = optax.multi_transform(partition_optimizers, self.param_partitions)
451
+ state = TrainState.create(
452
+ apply_fn = self.model.__call__,
453
+ params = self.params,
454
+ tx = tx
455
+ )
456
+ validation_rng, train_rngs = jax.random.split(self.rng)
457
+ train_rngs = jax.random.split(train_rngs, jax.local_device_count())
458
+
459
+ def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.random.PRNGKeyArray):
460
+ def compute_loss(
461
+ params: Dict[str, Any],
462
+ batch: Dict[str, jax.Array],
463
+ sample_rng: jax.random.PRNGKeyArray # unused, dataloader provides everything
464
+ ) -> jax.Array:
465
+ # 'latent_model_input': latent_model_input
466
+ # 'encoder_hidden_states': encoder_hidden_states
467
+ # 'timesteps': timesteps
468
+ # 'noise': noise
469
+ latent_model_input = batch['latent_model_input']
470
+ encoder_hidden_states = batch['encoder_hidden_states']
471
+ timesteps = batch['timesteps']
472
+ noise = batch['noise']
473
+ model_pred = self.model.apply(
474
+ { 'params': params },
475
+ latent_model_input,
476
+ timesteps,
477
+ encoder_hidden_states
478
+ ).sample
479
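+ # epsilon-prediction objective: mean squared error between the sampled noise and the model prediction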
+ loss = (noise - model_pred) ** 2
480
+ loss = loss.mean()
481
+ return loss
482
+ grad_fn = jax.value_and_grad(compute_loss)
483
+
484
+ def loss_and_grad(
485
+ train_rng: jax.random.PRNGKeyArray
486
+ ) -> Tuple[jax.Array, Any, jax.random.PRNGKeyArray]:
487
+ sample_rng, train_rng = jax.random.split(train_rng, 2)
488
+ loss, grad = grad_fn(state.params, batch, sample_rng)
489
+ return loss, grad, train_rng
490
+
491
+ loss, grad, new_train_rng = loss_and_grad(train_rng)
492
+ # self.log(grad) # NOTE uncomment to visualize gradient
493
+ grad = jax.lax.pmean(grad, axis_name = 'batch')
494
+ new_state = state.apply_gradients(grads = grad)
495
+ metrics: Dict[str, Any] = { 'loss': loss }
496
+ metrics = jax.lax.pmean(metrics, axis_name = 'batch')
497
+ def l2(xs) -> jax.Array:
498
+ return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
499
+ metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
500
+
501
+ return new_state, metrics, new_train_rng
502
+
503
+ p_train_step = jax.pmap(fun = train_step, axis_name = 'batch', donate_argnums = (0, ))
504
+ state = jax_utils.replicate(state)
505
+
506
+ train_metrics = []
507
+ train_metric = None
508
+
509
+ global_step: int = 0
510
+
511
+ if jax.process_index() == 0:
512
+ self._init_tracker_meta()
513
+ hyper_params = {
514
+ 'lr': lr,
515
+ 'lr_warmup': warmup,
516
+ 'lr_decay': decay,
517
+ 'weight_decay': weight_decay,
518
+ 'total_steps': total_steps,
519
+ 'batch_size': dataloader.batch_size // self.num_devices,
520
+ 'num_frames': num_frames,
521
+ 'sample_size': self.sample_size,
522
+ 'num_devices': self.num_devices,
523
+ 'seed': self.seed,
524
+ 'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
525
+ 'only_temporal': self.only_temporal,
526
+ 'dtype': self.dtype_str,
527
+ 'param_dtype': self.param_dtype,
528
+ 'pretrained_model': self.pretrained_model,
529
+ 'model_config': self.model.config
530
+ }
531
+ if self._use_wandb:
532
+ self.log('Setting up wandb')
533
+ self._setup_wandb(hyper_params)
534
+ self.log(hyper_params)
535
+ output_path = os.path.join(output_dir, str(global_step), 'unet')
536
+ self.log(f'saving checkpoint to {output_path}')
537
+ self.model.save_pretrained(
538
+ save_directory = output_path,
539
+ params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
540
+ is_main_process = True
541
+ )
542
+
543
+ pbar_epoch = tqdm(
544
+ total = epochs,
545
+ desc = 'Epochs',
546
+ smoothing = 1,
547
+ position = 0,
548
+ dynamic_ncols = True,
549
+ leave = True,
550
+ disable = jax.process_index() > 0
551
+ )
552
+ steps_per_epoch = len(dataloader) # TODO dataloader
553
+ for epoch in range(epochs):
554
+ pbar_steps = tqdm(
555
+ total = steps_per_epoch,
556
+ desc = 'Steps',
557
+ position = 1,
558
+ smoothing = 0.1,
559
+ dynamic_ncols = True,
560
+ leave = True,
561
+ disable = jax.process_index() > 0
562
+ )
563
+ for batch in dataloader:
564
+ # keep input + gt as float32, results in fp32 loss and grad
565
+ # otherwise uncomment the following to cast to the model dtype
566
+ # batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
567
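+ # shard() splits the leading batch dimension across the local devices for pmap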
+ batch = shard(batch)
568
+ state, train_metric, train_rngs = p_train_step(
569
+ state, batch, train_rngs
570
+ )
571
+ train_metrics.append(train_metric)
572
+ if global_step % log_every_step == 0 and jax.process_index() == 0:
573
+ train_metrics = jax_utils.unreplicate(train_metrics)
574
+ train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
575
+ if global_step == 0:
576
+ self.log(f'grad dtype: {train_metrics["l2_grads"].dtype}')
577
+ self.log(f'loss dtype: {train_metrics["loss"].dtype}')
578
+ train_metrics_dict = { k: v.item() for k, v in train_metrics.items() }
579
+ train_metrics_dict['lr'] = lr_schedule(global_step).item()
580
+ self.log_metrics(train_metrics_dict, step = global_step, epoch = epoch)
581
+ train_metrics = []
582
+ pbar_steps.update(1)
583
+ global_step += 1
584
+ if epoch % save_every_epoch == 0 and jax.process_index() == 0:
585
+ output_path = os.path.join(output_dir, str(global_step), 'unet')
586
+ self.log(f'saving checkpoint to {output_path}')
587
+ self.model.save_pretrained(
588
+ save_directory = output_path,
589
+ params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
590
+ is_main_process = True
591
+ )
592
+ self.log(f'checkpoint saved ')
593
+ if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
594
+ images = self.sample(
595
+ params = state.params,
596
+ replicate_params = False,
597
+ prompt = 'dancing person',
598
+ image_path = 'testimage.png',
599
+ num_frames = num_frames,
600
+ steps = 50,
601
+ cfg = 9.0,
602
+ unload_after_usage = False
603
+ )
604
+ os.makedirs(os.path.join('image_output', str(epoch)), exist_ok = True)
605
+ for i, im in enumerate(images):
606
+ im.save(os.path.join('image_output', str(epoch), str(i).zfill(5) + '.png'), optimize = True)
607
+ pbar_epoch.update(1)
608
+
makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py ADDED
@@ -0,0 +1,254 @@
1
+
2
+ from typing import Tuple
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ from .flax_attention_pseudo3d import TransformerPseudo3DModel
9
+ from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D
10
+
11
+
12
+ class UNetMidBlockPseudo3DCrossAttn(nn.Module):
13
+ in_channels: int
14
+ num_layers: int = 1
15
+ attn_num_head_channels: int = 1
16
+ use_memory_efficient_attention: bool = False
17
+ dtype: jnp.dtype = jnp.float32
18
+
19
+ def setup(self) -> None:
20
+ resnets = [
21
+ ResnetBlockPseudo3D(
22
+ in_channels = self.in_channels,
23
+ out_channels = self.in_channels,
24
+ dtype = self.dtype
25
+ )
26
+ ]
27
+ attentions = []
28
+ for _ in range(self.num_layers):
29
+ attn_block = TransformerPseudo3DModel(
30
+ in_channels = self.in_channels,
31
+ num_attention_heads = self.attn_num_head_channels,
32
+ attention_head_dim = self.in_channels // self.attn_num_head_channels,
33
+ num_layers = 1,
34
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
35
+ dtype = self.dtype
36
+ )
37
+ attentions.append(attn_block)
38
+ res_block = ResnetBlockPseudo3D(
39
+ in_channels = self.in_channels,
40
+ out_channels = self.in_channels,
41
+ dtype = self.dtype
42
+ )
43
+ resnets.append(res_block)
44
+ self.attentions = attentions
45
+ self.resnets = resnets
46
+
47
+ def __call__(self,
48
+ hidden_states: jax.Array,
49
+ temb: jax.Array,
50
+ encoder_hidden_states: jax.Array
51
+ ) -> jax.Array:
52
+ hidden_states = self.resnets[0](hidden_states, temb)
53
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
54
+ hidden_states = attn(hidden_states, encoder_hidden_states)
55
+ hidden_states = resnet(hidden_states, temb)
56
+ return hidden_states
57
+
58
+
59
+ class CrossAttnDownBlockPseudo3D(nn.Module):
60
+ in_channels: int
61
+ out_channels: int
62
+ num_layers: int = 1
63
+ attn_num_head_channels: int = 1
64
+ add_downsample: bool = True
65
+ use_memory_efficient_attention: bool = False
66
+ dtype: jnp.dtype = jnp.float32
67
+
68
+ def setup(self) -> None:
69
+ attentions = []
70
+ resnets = []
71
+ for i in range(self.num_layers):
72
+ in_channels = self.in_channels if i == 0 else self.out_channels
73
+ res_block = ResnetBlockPseudo3D(
74
+ in_channels = in_channels,
75
+ out_channels = self.out_channels,
76
+ dtype = self.dtype
77
+ )
78
+ resnets.append(res_block)
79
+ attn_block = TransformerPseudo3DModel(
80
+ in_channels = self.out_channels,
81
+ num_attention_heads = self.attn_num_head_channels,
82
+ attention_head_dim = self.out_channels // self.attn_num_head_channels,
83
+ num_layers = 1,
84
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
85
+ dtype = self.dtype
86
+ )
87
+ attentions.append(attn_block)
88
+ self.resnets = resnets
89
+ self.attentions = attentions
90
+
91
+ if self.add_downsample:
92
+ self.downsamplers_0 = DownsamplePseudo3D(
93
+ out_channels = self.out_channels,
94
+ dtype = self.dtype
95
+ )
96
+ else:
97
+ self.downsamplers_0 = None
98
+
99
+ def __call__(self,
100
+ hidden_states: jax.Array,
101
+ temb: jax.Array,
102
+ encoder_hidden_states: jax.Array
103
+ ) -> Tuple[jax.Array, jax.Array]:
104
+ output_states = ()
105
+ for resnet, attn in zip(self.resnets, self.attentions):
106
+ hidden_states = resnet(hidden_states, temb)
107
+ hidden_states = attn(hidden_states, encoder_hidden_states)
108
+ output_states += (hidden_states, )
109
+ if self.add_downsample:
110
+ hidden_states = self.downsamplers_0(hidden_states)
111
+ output_states += (hidden_states, )
112
+ return hidden_states, output_states
113
+
114
+
115
+ class DownBlockPseudo3D(nn.Module):
116
+ in_channels: int
117
+ out_channels: int
118
+ num_layers: int = 1
119
+ add_downsample: bool = True
120
+ dtype: jnp.dtype = jnp.float32
121
+
122
+ def setup(self) -> None:
123
+ resnets = []
124
+ for i in range(self.num_layers):
125
+ in_channels = self.in_channels if i == 0 else self.out_channels
126
+ res_block = ResnetBlockPseudo3D(
127
+ in_channels = in_channels,
128
+ out_channels = self.out_channels,
129
+ dtype = self.dtype
130
+ )
131
+ resnets.append(res_block)
132
+ self.resnets = resnets
133
+ if self.add_downsample:
134
+ self.downsamplers_0 = DownsamplePseudo3D(
135
+ out_channels = self.out_channels,
136
+ dtype = self.dtype
137
+ )
138
+ else:
139
+ self.downsamplers_0 = None
140
+
141
+ def __call__(self,
142
+ hidden_states: jax.Array,
143
+ temb: jax.Array
144
+ ) -> Tuple[jax.Array, jax.Array]:
145
+ output_states = ()
146
+ for resnet in self.resnets:
147
+ hidden_states = resnet(hidden_states, temb)
148
+ output_states += (hidden_states, )
149
+ if self.add_downsample:
150
+ hidden_states = self.downsamplers_0(hidden_states)
151
+ output_states += (hidden_states, )
152
+ return hidden_states, output_states
153
+
154
+
155
+ class CrossAttnUpBlockPseudo3D(nn.Module):
156
+ in_channels: int
157
+ out_channels: int
158
+ prev_output_channels: int
159
+ num_layers: int = 1
160
+ attn_num_head_channels: int = 1
161
+ add_upsample: bool = True
162
+ use_memory_efficient_attention: bool = False
163
+ dtype: jnp.dtype = jnp.float32
164
+
165
+ def setup(self) -> None:
166
+ resnets = []
167
+ attentions = []
168
+ for i in range(self.num_layers):
169
+ res_skip_channels = self.in_channels if (i == self.num_layers -1) else self.out_channels
170
+ resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
171
+ res_block = ResnetBlockPseudo3D(
172
+ in_channels = resnet_in_channels + res_skip_channels,
173
+ out_channels = self.out_channels,
174
+ dtype = self.dtype
175
+ )
176
+ resnets.append(res_block)
177
+ attn_block = TransformerPseudo3DModel(
178
+ in_channels = self.out_channels,
179
+ num_attention_heads = self.attn_num_head_channels,
180
+ attention_head_dim = self.out_channels // self.attn_num_head_channels,
181
+ num_layers = 1,
182
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
183
+ dtype = self.dtype
184
+ )
185
+ attentions.append(attn_block)
186
+ self.resnets = resnets
187
+ self.attentions = attentions
188
+ if self.add_upsample:
189
+ self.upsamplers_0 = UpsamplePseudo3D(
190
+ out_channels = self.out_channels,
191
+ dtype = self.dtype
192
+ )
193
+ else:
194
+ self.upsamplers_0 = None
195
+
196
+ def __call__(self,
197
+ hidden_states: jax.Array,
198
+ res_hidden_states_tuple: Tuple[jax.Array, ...],
199
+ temb: jax.Array,
200
+ encoder_hidden_states: jax.Array
201
+ ) -> jax.Array:
202
+ for resnet, attn in zip(self.resnets, self.attentions):
203
+ res_hidden_states = res_hidden_states_tuple[-1]
204
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
205
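+ # pop the matching skip connection from the down path and concatenate it on the channel (last) axis before the resnet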
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis = -1)
206
+ hidden_states = resnet(hidden_states, temb)
207
+ hidden_states = attn(hidden_states, encoder_hidden_states)
208
+ if self.add_upsample:
209
+ hidden_states = self.upsamplers_0(hidden_states)
210
+ return hidden_states
211
+
212
+
213
+ class UpBlockPseudo3D(nn.Module):
214
+ in_channels: int
215
+ out_channels: int
216
+ prev_output_channels: int
217
+ num_layers: int = 1
218
+ add_upsample: bool = True
219
+ dtype: jnp.dtype = jnp.float32
220
+
221
+ def setup(self) -> None:
222
+ resnets = []
223
+ for i in range(self.num_layers):
224
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
225
+ resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
226
+ res_block = ResnetBlockPseudo3D(
227
+ in_channels = resnet_in_channels + res_skip_channels,
228
+ out_channels = self.out_channels,
229
+ dtype = self.dtype
230
+ )
231
+ resnets.append(res_block)
232
+ self.resnets = resnets
233
+ if self.add_upsample:
234
+ self.upsamplers_0 = UpsamplePseudo3D(
235
+ out_channels = self.out_channels,
236
+ dtype = self.dtype
237
+ )
238
+ else:
239
+ self.upsamplers_0 = None
240
+
241
+ def __call__(self,
242
+ hidden_states: jax.Array,
243
+ res_hidden_states_tuple: Tuple[jax.Array, ...],
244
+ temb: jax.Array
245
+ ) -> jax.Array:
246
+ for resnet in self.resnets:
247
+ res_hidden_states = res_hidden_states_tuple[-1]
248
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
249
+ hidden_states = jnp.concatenate([hidden_states, res_hidden_states], axis = -1)
250
+ hidden_states = resnet(hidden_states, temb)
251
+ if self.add_upsample:
252
+ hidden_states = self.upsamplers_0(hidden_states)
253
+ return hidden_states
254
+
makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py ADDED
@@ -0,0 +1,251 @@
1
+
2
+ from typing import Tuple, Union
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+ from flax.core.frozen_dict import FrozenDict
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
10
+ from diffusers.models.modeling_flax_utils import FlaxModelMixin
11
+ from diffusers.utils import BaseOutput
12
+
13
+ from .flax_unet_pseudo3d_blocks import (
14
+ CrossAttnDownBlockPseudo3D,
15
+ CrossAttnUpBlockPseudo3D,
16
+ DownBlockPseudo3D,
17
+ UpBlockPseudo3D,
18
+ UNetMidBlockPseudo3DCrossAttn
19
+ )
20
+ #from flax_embeddings import (
21
+ # TimestepEmbedding,
22
+ # Timesteps
23
+ #)
24
+ from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .flax_resnet_pseudo3d import ConvPseudo3D
26
+
27
+
28
+ class UNetPseudo3DConditionOutput(BaseOutput):
29
+ sample: jax.Array
30
+
31
+
32
+ @flax_register_to_config
33
+ class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
34
+ sample_size: Union[int, Tuple[int, int]] = (64, 64)
35
+ in_channels: int = 4
36
+ out_channels: int = 4
37
+ down_block_types: Tuple[str] = (
38
+ "CrossAttnDownBlockPseudo3D",
39
+ "CrossAttnDownBlockPseudo3D",
40
+ "CrossAttnDownBlockPseudo3D",
41
+ "DownBlockPseudo3D"
42
+ )
43
+ up_block_types: Tuple[str] = (
44
+ "UpBlockPseudo3D",
45
+ "CrossAttnUpBlockPseudo3D",
46
+ "CrossAttnUpBlockPseudo3D",
47
+ "CrossAttnUpBlockPseudo3D"
48
+ )
49
+ block_out_channels: Tuple[int] = (
50
+ 320,
51
+ 640,
52
+ 1280,
53
+ 1280
54
+ )
55
+ layers_per_block: int = 2
56
+ attention_head_dim: Union[int, Tuple[int]] = 8
57
+ cross_attention_dim: int = 768
58
+ flip_sin_to_cos: bool = True
59
+ freq_shift: int = 0
60
+ use_memory_efficient_attention: bool = False
61
+ dtype: jnp.dtype = jnp.float32
62
+ param_dtype: str = 'float32'
63
+
64
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
65
+ if self.param_dtype == 'bfloat16':
66
+ param_dtype = jnp.bfloat16
67
+ elif self.param_dtype == 'float16':
68
+ param_dtype = jnp.float16
69
+ elif self.param_dtype == 'float32':
70
+ param_dtype = jnp.float32
71
+ else:
72
+ raise ValueError(f'unknown parameter type: {self.param_dtype}')
73
+ sample_size = self.sample_size
74
+ if isinstance(sample_size, int):
75
+ sample_size = (sample_size, sample_size)
76
+ sample_shape = (1, self.in_channels, 1, *sample_size)
77
+ sample = jnp.zeros(sample_shape, dtype = param_dtype)
78
+ timesteps = jnp.ones((1, ), dtype = jnp.int32)
79
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
80
+ params_rng, dropout_rng = jax.random.split(rng)
81
+ rngs = { "params": params_rng, "dropout": dropout_rng }
82
+ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
83
+
84
+ def setup(self) -> None:
85
+ if isinstance(self.attention_head_dim, int):
86
+ attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
87
+ else:
88
+ attention_head_dim = self.attention_head_dim
89
+ time_embed_dim = self.block_out_channels[0] * 4
90
+ self.conv_in = ConvPseudo3D(
91
+ features = self.block_out_channels[0],
92
+ kernel_size = (3, 3),
93
+ strides = (1, 1),
94
+ padding = ((1, 1), (1, 1)),
95
+ dtype = self.dtype
96
+ )
97
+ self.time_proj = FlaxTimesteps(
98
+ dim = self.block_out_channels[0],
99
+ flip_sin_to_cos = self.flip_sin_to_cos,
100
+ freq_shift = self.freq_shift
101
+ )
102
+ self.time_embedding = FlaxTimestepEmbedding(
103
+ time_embed_dim = time_embed_dim,
104
+ dtype = self.dtype
105
+ )
106
+ down_blocks = []
107
+ output_channels = self.block_out_channels[0]
108
+ for i, down_block_type in enumerate(self.down_block_types):
109
+ input_channels = output_channels
110
+ output_channels = self.block_out_channels[i]
111
+ is_final_block = i == len(self.block_out_channels) - 1
112
+ # allows loading 3d models with old layer type names in their configs
113
+ # eg. 2D instead of Pseudo3D, like lxj's timelapse model
114
+ if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
115
+ down_block = CrossAttnDownBlockPseudo3D(
116
+ in_channels = input_channels,
117
+ out_channels = output_channels,
118
+ num_layers = self.layers_per_block,
119
+ attn_num_head_channels = attention_head_dim[i],
120
+ add_downsample = not is_final_block,
121
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
122
+ dtype = self.dtype
123
+ )
124
+ elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
125
+ down_block = DownBlockPseudo3D(
126
+ in_channels = input_channels,
127
+ out_channels = output_channels,
128
+ num_layers = self.layers_per_block,
129
+ add_downsample = not is_final_block,
130
+ dtype = self.dtype
131
+ )
132
+ else:
133
+ raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
134
+ down_blocks.append(down_block)
135
+ self.down_blocks = down_blocks
136
+ self.mid_block = UNetMidBlockPseudo3DCrossAttn(
137
+ in_channels = self.block_out_channels[-1],
138
+ attn_num_head_channels = attention_head_dim[-1],
139
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
140
+ dtype = self.dtype
141
+ )
142
+ up_blocks = []
143
+ reversed_block_out_channels = list(reversed(self.block_out_channels))
144
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
145
+ output_channels = reversed_block_out_channels[0]
146
+ for i, up_block_type in enumerate(self.up_block_types):
147
+ prev_output_channels = output_channels
148
+ output_channels = reversed_block_out_channels[i]
149
+ input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
150
+ is_final_block = i == len(self.block_out_channels) - 1
151
+ if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
152
+ up_block = CrossAttnUpBlockPseudo3D(
153
+ in_channels = input_channels,
154
+ out_channels = output_channels,
155
+ prev_output_channels = prev_output_channels,
156
+ num_layers = self.layers_per_block + 1,
157
+ attn_num_head_channels = reversed_attention_head_dim[i],
158
+ add_upsample = not is_final_block,
159
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
160
+ dtype = self.dtype
161
+ )
162
+ elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
163
+ up_block = UpBlockPseudo3D(
164
+ in_channels = input_channels,
165
+ out_channels = output_channels,
166
+ prev_output_channels = prev_output_channels,
167
+ num_layers = self.layers_per_block + 1,
168
+ add_upsample = not is_final_block,
169
+ dtype = self.dtype
170
+ )
171
+ else:
172
+ raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
173
+ up_blocks.append(up_block)
174
+ self.up_blocks = up_blocks
175
+ self.conv_norm_out = nn.GroupNorm(
176
+ num_groups = 32,
177
+ epsilon = 1e-5
178
+ )
179
+ self.conv_out = ConvPseudo3D(
180
+ features = self.out_channels,
181
+ kernel_size = (3, 3),
182
+ strides = (1, 1),
183
+ padding = ((1, 1), (1, 1)),
184
+ dtype = self.dtype
185
+ )
186
+
187
+ def __call__(self,
188
+ sample: jax.Array,
189
+ timesteps: jax.Array,
190
+ encoder_hidden_states: jax.Array,
191
+ return_dict: bool = True
192
+ ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
193
+ if timesteps.dtype != jnp.float32:
194
+ timesteps = timesteps.astype(dtype = jnp.float32)
195
+ if len(timesteps.shape) == 0:
196
+ timesteps = jnp.expand_dims(timesteps, 0)
197
+ # b,c,f,h,w -> b,f,h,w,c
198
+ sample = sample.transpose((0, 2, 3, 4, 1))
199
+
200
+ t_emb = self.time_proj(timesteps)
201
+ t_emb = self.time_embedding(t_emb)
202
+ sample = self.conv_in(sample)
203
+ down_block_res_samples = (sample, )
204
+ for down_block in self.down_blocks:
205
+ if isinstance(down_block, CrossAttnDownBlockPseudo3D):
206
+ sample, res_samples = down_block(
207
+ hidden_states = sample,
208
+ temb = t_emb,
209
+ encoder_hidden_states = encoder_hidden_states
210
+ )
211
+ elif isinstance(down_block, DownBlockPseudo3D):
212
+ sample, res_samples = down_block(
213
+ hidden_states = sample,
214
+ temb = t_emb
215
+ )
216
+ else:
217
+ raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
218
+ down_block_res_samples += res_samples
219
+ sample = self.mid_block(
220
+ hidden_states = sample,
221
+ temb = t_emb,
222
+ encoder_hidden_states = encoder_hidden_states
223
+ )
224
+ for up_block in self.up_blocks:
225
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
226
+ down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
227
+ if isinstance(up_block, CrossAttnUpBlockPseudo3D):
228
+ sample = up_block(
229
+ hidden_states = sample,
230
+ temb = t_emb,
231
+ encoder_hidden_states = encoder_hidden_states,
232
+ res_hidden_states_tuple = res_samples
233
+ )
234
+ elif isinstance(up_block, UpBlockPseudo3D):
235
+ sample = up_block(
236
+ hidden_states = sample,
237
+ temb = t_emb,
238
+ res_hidden_states_tuple = res_samples
239
+ )
240
+ else:
241
+ raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
242
+ sample = self.conv_norm_out(sample)
243
+ sample = nn.silu(sample)
244
+ sample = self.conv_out(sample)
245
+
246
+ # b,f,h,w,c -> b,c,f,h,w
247
+ sample = sample.transpose((0, 4, 1, 2, 3))
248
+ if not return_dict:
249
+ return (sample, )
250
+ return UNetPseudo3DConditionOutput(sample = sample)
251
+
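For reference, a minimal sketch (not part of this commit) of instantiating the condition model above and running one forward pass. The shapes mirror init_weights; the frame count, the use of the class defaults (in_channels = 4, i.e. no hint/mask channels) and the assumption that the attention blocks run without a dropout rng are illustrative, not taken from the repository.

import jax
import jax.numpy as jnp
from makeavid_sd.flax_impl.flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel

unet = UNetPseudo3DConditionModel(sample_size = (64, 64), in_channels = 4, out_channels = 4)
params = unet.init_weights(jax.random.PRNGKey(0))
sample = jnp.zeros((1, 4, 8, 64, 64))            # b, c, f, h, w
timesteps = jnp.ones((1, ), dtype = jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 768))  # CLIP text embeddings
out = unet.apply({ 'params': params }, sample, timesteps, encoder_hidden_states).sample
# out keeps the input layout: (1, 4, 8, 64, 64)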
makeavid_sd/flax_impl/train.py ADDED
@@ -0,0 +1,143 @@
1
+
2
+ import jax
3
+ _ = jax.device_count() # ugly hack to prevent tpu comms to lock/race or smth smh
4
+
5
+ from typing import Tuple, Optional
6
+ import os
7
+ from argparse import ArgumentParser
8
+
9
+ from flax_trainer import FlaxTrainerUNetPseudo3D
10
+ from dataset import load_dataset
11
+
12
+ def train(
13
+ dataset_path: str,
14
+ model_path: str,
15
+ output_dir: str,
16
+ dataset_cache_dir: Optional[str] = None,
17
+ from_pt: bool = True,
18
+ convert2d: bool = False,
19
+ only_temporal: bool = True,
20
+ sample_size: Tuple[int, int] = (64, 64),
21
+ lr: float = 5e-5,
22
+ batch_size: int = 1,
23
+ num_frames: int = 24,
24
+ epochs: int = 10,
25
+ warmup: float = 0.1,
26
+ decay: float = 0.0,
27
+ weight_decay: float = 1e-2,
28
+ log_every_step: int = 50,
29
+ save_every_epoch: int = 1,
30
+ sample_every_epoch: int = 1,
31
+ seed: int = 0,
32
+ dtype: str = 'bfloat16',
33
+ param_dtype: str = 'float32',
34
+ use_memory_efficient_attention: bool = True,
35
+ verbose: bool = True,
36
+ use_wandb: bool = False
37
+ ) -> None:
38
+ log = lambda x: print(x) if verbose else None
39
+ log('\n----------------')
40
+ log('Init trainer')
41
+ trainer = FlaxTrainerUNetPseudo3D(
42
+ model_path = model_path,
43
+ from_pt = from_pt,
44
+ convert2d = convert2d,
45
+ sample_size = sample_size,
46
+ seed = seed,
47
+ dtype = dtype,
48
+ param_dtype = param_dtype,
49
+ use_memory_efficient_attention = use_memory_efficient_attention,
50
+ verbose = verbose,
51
+ only_temporal = only_temporal
52
+ )
53
+ log('\n----------------')
54
+ log('Init dataset')
55
+ dataloader = load_dataset(
56
+ dataset_path = dataset_path,
57
+ model_path = model_path,
58
+ cache_dir = dataset_cache_dir,
59
+ batch_size = batch_size * trainer.num_devices,
60
+ num_frames = num_frames,
61
+ num_workers = min(trainer.num_devices * 2, os.cpu_count() - 1),
62
+ as_numpy = True,
63
+ shuffle = True
64
+ )
65
+ log('\n----------------')
66
+ log('Train')
67
+ if use_wandb:
68
+ trainer.enable_wandb()
69
+ trainer.train(
70
+ dataloader = dataloader,
71
+ epochs = epochs,
72
+ num_frames = num_frames,
73
+ log_every_step = log_every_step,
74
+ save_every_epoch = save_every_epoch,
75
+ sample_every_epoch = sample_every_epoch,
76
+ lr = lr,
77
+ warmup = warmup,
78
+ decay = decay,
79
+ weight_decay = weight_decay,
80
+ output_dir = output_dir
81
+ )
82
+ log('\n----------------')
83
+ log('Done')
84
+
85
+
86
+ if __name__ == '__main__':
87
+ parser = ArgumentParser()
88
+ bool_type = lambda x: x.lower() in ['true', '1', 'yes']
89
+ parser.add_argument('-v', '--verbose', type = bool_type, default = True)
90
+ parser.add_argument('-d', '--dataset_path', required = True)
91
+ parser.add_argument('-m', '--model_path', required = True)
92
+ parser.add_argument('-o', '--output_dir', required = True)
93
+ parser.add_argument('-b', '--batch_size', type = int, default = 1)
94
+ parser.add_argument('-f', '--num_frames', type = int, default = 24)
95
+ parser.add_argument('-e', '--epochs', type = int, default = 2)
96
+ parser.add_argument('--only_temporal', type = bool_type, default = True)
97
+ parser.add_argument('--dataset_cache_dir', type = str, default = None)
98
+ parser.add_argument('--from_pt', type = bool_type, default = True)
99
+ parser.add_argument('--convert2d', type = bool_type, default = False)
100
+ parser.add_argument('--lr', type = float, default = 1e-4)
101
+ parser.add_argument('--warmup', type = float, default = 0.1)
102
+ parser.add_argument('--decay', type = float, default = 0.0)
103
+ parser.add_argument('--weight_decay', type = float, default = 1e-2)
104
+ parser.add_argument('--sample_size', type = int, nargs = 2, default = [64, 64])
105
+ parser.add_argument('--log_every_step', type = int, default = 250)
106
+ parser.add_argument('--save_every_epoch', type = int, default = 1)
107
+ parser.add_argument('--sample_every_epoch', type = int, default = 1)
108
+ parser.add_argument('--seed', type = int, default = 0)
109
+ parser.add_argument('--use_memory_efficient_attention', type = bool_type, default = True)
110
+ parser.add_argument('--dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'bfloat16')
111
+ parser.add_argument('--param_dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'float32')
112
+ parser.add_argument('--wandb', type = bool_type, default = False)
113
+ args = parser.parse_args()
114
+ args.sample_size = tuple(args.sample_size)
115
+ if args.verbose:
116
+ print(args)
117
+ train(
118
+ dataset_path = args.dataset_path,
119
+ model_path = args.model_path,
120
+ from_pt = args.from_pt,
121
+ convert2d = args.convert2d,
122
+ only_temporal = args.only_temporal,
123
+ output_dir = args.output_dir,
124
+ dataset_cache_dir = args.dataset_cache_dir,
125
+ batch_size = args.batch_size,
126
+ num_frames = args.num_frames,
127
+ epochs = args.epochs,
128
+ lr = args.lr,
129
+ warmup = args.warmup,
130
+ decay = args.decay,
131
+ weight_decay = args.weight_decay,
132
+ sample_size = args.sample_size,
133
+ seed = args.seed,
134
+ dtype = args.dtype,
135
+ param_dtype = args.param_dtype,
136
+ use_memory_efficient_attention = args.use_memory_efficient_attention,
137
+ log_every_step = args.log_every_step,
138
+ save_every_epoch = args.save_every_epoch,
139
+ sample_every_epoch = args.sample_every_epoch,
140
+ verbose = args.verbose,
141
+ use_wandb = args.wandb
142
+ )
143
+
makeavid_sd/flax_impl/train.sh ADDED
@@ -0,0 +1,34 @@
1
+ #!/bin/sh
2
+
3
+ #export WANDB_API_KEY="your_api_key"
4
+ export WANDB_ENTITY="tempofunk"
5
+ export WANDB_JOB_TYPE="train"
6
+ export WANDB_PROJECT="makeavid-sd-tpu"
7
+
8
+ python train.py \
9
+ --dataset_path ../storage/dataset/tempofunk-s \
10
+ --model_path ../storage/trained_models/ep20 \
11
+ --from_pt False \
12
+ --convert2d False \
13
+ --only_temporal True \
14
+ --output_dir ../storage/output \
15
+ --batch_size 1 \
16
+ --num_frames 24 \
17
+ --epochs 20 \
18
+ --lr 0.00005 \
19
+ --warmup 0.1 \
20
+ --decay 0.0 \
21
+ --sample_size 64 64 \
22
+ --log_every_step 50 \
23
+ --save_every_epoch 1 \
24
+ --sample_every_epoch 1 \
25
+ --seed 2 \
26
+ --use_memory_efficient_attention True \
27
+ --dtype bfloat16 \
28
+ --param_dtype float32 \
29
+ --verbose True \
30
+ --dataset_cache_dir ../storage/cache/hf/datasets \
31
+ --wandb True
32
+
33
+ # sudo rm /tmp/libtpu_lockfile
34
+
makeavid_sd/inference.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ from typing import Any, Union, Optional, Tuple, List, Dict
3
+ import os
4
+ import gc
5
+ from functools import partial
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+
11
+ from flax.core.frozen_dict import FrozenDict
12
+ from flax import jax_utils
13
+ from flax.training.common_utils import shard
14
+ from PIL import Image
15
+ import einops
16
+
17
+ from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
18
+ from diffusers import (
19
+ FlaxDDIMScheduler,
20
+ FlaxPNDMScheduler,
21
+ FlaxLMSDiscreteScheduler,
22
+ FlaxDPMSolverMultistepScheduler,
23
+ )
24
+ from diffusers.schedulers.scheduling_ddim_flax import DDIMSchedulerState
25
+ from diffusers.schedulers.scheduling_pndm_flax import PNDMSchedulerState
26
+ from diffusers.schedulers.scheduling_lms_discrete_flax import LMSDiscreteSchedulerState
27
+ from diffusers.schedulers.scheduling_dpmsolver_multistep_flax import DPMSolverMultistepSchedulerState
28
+
29
+ from transformers import FlaxCLIPTextModel, CLIPTokenizer
30
+
31
+ from .flax_impl.flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
32
+
33
+ SchedulerType = Union[
34
+ FlaxDDIMScheduler,
35
+ FlaxPNDMScheduler,
36
+ FlaxLMSDiscreteScheduler,
37
+ FlaxDPMSolverMultistepScheduler,
38
+ ]
39
+
40
+ SchedulerStateType = Union[
41
+ DDIMSchedulerState,
42
+ PNDMSchedulerState,
43
+ LMSDiscreteSchedulerState,
44
+ DPMSolverMultistepSchedulerState,
45
+ ]
46
+
47
+ SCHEDULERS: Dict[str, SchedulerType] = {
48
+ 'dpm': FlaxDPMSolverMultistepScheduler, # husbando
49
+ 'ddim': FlaxDDIMScheduler,
50
+ #'PLMS': FlaxPNDMScheduler, # its not correctly implemented in diffusers, output is bad, but at least it "works"
51
+ #'LMS': FlaxLMSDiscreteScheduler, # borked
52
+ # image_latents, image_scheduler_state = scheduler.step(
53
+ # File "/mnt/work1/make_a_vid/makeavid-space/.venv/lib/python3.10/site-packages/diffusers/schedulers/scheduling_lms_discrete_flax.py", line 255, in step
54
+ # order = min(timestep + 1, order)
55
+ # jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/1)>
56
+ # The problem arose with the `bool` function.
57
+ # The error occurred while tracing the function scanned_fun at /mnt/work1/make_a_vid/makeavid-space/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1668 for scan. This concrete value was not available in Python because it depends on the values of the arguments loop_carry[0] and loop_carry[1][1].timesteps
58
+ }
59
+
60
+ def dtypestr(x: jnp.dtype):
61
+ if x == jnp.float32: return 'float32'
62
+ elif x == jnp.float16: return 'float16'
63
+ elif x == jnp.bfloat16: return 'bfloat16'
64
+ else: raise ValueError('unsupported dtype')
65
+ def castto(dtype, m, x):
66
+ if dtype == jnp.float32: return m.to_fp32(x)
67
+ elif dtype == jnp.float16: return m.to_fp16(x)
68
+ elif dtype == jnp.bfloat16: return m.to_bf16(x)
69
+ else: raise ValueError('unsupported dtype')
70
+
71
+ class InferenceUNetPseudo3D:
72
+ def __init__(self,
73
+ model_path: str,
74
+ dtype: jnp.dtype = jnp.float16,
75
+ hf_auth_token: Union[str, None] = None
76
+ ) -> None:
77
+ self.dtype = dtype
78
+ self.model_path = model_path
79
+ self.hf_auth_token = hf_auth_token
80
+
81
+ self.params: Dict[str, FrozenDict[str, Any]] = {}
82
+ try:
83
+ import traceback
84
+ print('initializing unet')
85
+ unet, unet_params = UNetPseudo3DConditionModel.from_pretrained(
86
+ self.model_path,
87
+ subfolder = 'unet',
88
+ from_pt = False,
89
+ sample_size = (64, 64),
90
+ dtype = self.dtype,
91
+ param_dtype = dtypestr(self.dtype),
92
+ use_memory_efficient_attention = True,
93
+ use_auth_token = self.hf_auth_token
94
+ )
95
+ self.unet: UNetPseudo3DConditionModel = unet
96
+ print('casting unet params')
97
+ unet_params = castto(self.dtype, self.unet, unet_params)
98
+ print('storing unet params')
99
+ self.params['unet'] = FrozenDict(unet_params)
100
+ print('deleting unet params')
101
+ del unet_params
102
+ except Exception as e:
103
+ print(e)
104
+ self.failed = ''.join(traceback.format_exception(None, e, e.__traceback__))
105
+ traceback.print_exc()
106
+ return
107
+ self.failed = False
108
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
109
+ self.model_path,
110
+ subfolder = 'vae',
111
+ from_pt = True,
112
+ dtype = self.dtype,
113
+ use_auth_token = self.hf_auth_token
114
+ )
115
+ self.vae: FlaxAutoencoderKL = vae
116
+ vae_params = castto(self.dtype, self.vae, vae_params)
117
+ self.params['vae'] = FrozenDict(vae_params)
118
+ del vae_params
119
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
120
+ self.model_path,
121
+ subfolder = 'text_encoder',
122
+ from_pt = True,
123
+ dtype = self.dtype,
124
+ use_auth_token = self.hf_auth_token
125
+ )
126
+ text_encoder_params = text_encoder.params
127
+ del text_encoder._params
128
+ text_encoder_params = castto(self.dtype, text_encoder, text_encoder_params)
129
+ self.text_encoder: FlaxCLIPTextModel = text_encoder
130
+ self.params['text_encoder'] = FrozenDict(text_encoder_params)
131
+ del text_encoder_params
132
+ imunet, imunet_params = FlaxUNet2DConditionModel.from_pretrained(
133
+ 'runwayml/stable-diffusion-v1-5',
134
+ subfolder = 'unet',
135
+ from_pt = True,
136
+ dtype = self.dtype,
137
+ use_memory_efficient_attention = True,
138
+ use_auth_token = self.hf_auth_token
139
+ )
140
+ imunet_params = castto(self.dtype, imunet, imunet_params)
141
+ self.imunet: FlaxUNet2DConditionModel = imunet
142
+ self.params['imunet'] = FrozenDict(imunet_params)
143
+ del imunet_params
144
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
145
+ self.model_path,
146
+ subfolder = 'tokenizer',
147
+ use_auth_token = self.hf_auth_token
148
+ )
149
+ self.schedulers: Dict[str, Dict[str, SchedulerType]] = {}
150
+ for scheduler_name in SCHEDULERS:
151
+ if scheduler_name not in ['KarrasVe', 'SDEVe']:
152
+ scheduler, scheduler_state = SCHEDULERS[scheduler_name].from_pretrained(
153
+ self.model_path,
154
+ subfolder = 'scheduler',
155
+ dtype = jnp.float32,
156
+ use_auth_token = self.hf_auth_token
157
+ )
158
+ else:
159
+ scheduler, scheduler_state = SCHEDULERS[scheduler_name].from_pretrained(
160
+ self.model_path,
161
+ subfolder = 'scheduler',
162
+ use_auth_token = self.hf_auth_token
163
+ )
164
+ self.schedulers[scheduler_name] = scheduler
165
+ self.params[scheduler_name] = scheduler_state
166
+ self.vae_scale_factor: int = int(2 ** (len(self.vae.config.block_out_channels) - 1))
167
+ self.device_count = jax.device_count()
168
+ gc.collect()
169
+
170
+ def prepare_inputs(self,
171
+ prompt: List[str],
172
+ neg_prompt: List[str],
173
+ hint_image: List[Image.Image],
174
+ mask_image: List[Image.Image],
175
+ width: int,
176
+ height: int
177
+ ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: # prompt, neg_prompt, hint_image, mask_image
178
+ tokens = self.tokenizer(
179
+ prompt,
180
+ truncation = True,
181
+ return_overflowing_tokens = False,
182
+ max_length = 77, #self.text_encoder.config.max_length defaults to 20 if its not in the config smh
183
+ padding = 'max_length',
184
+ return_tensors = 'np'
185
+ ).input_ids
186
+ tokens = jnp.array(tokens, dtype = jnp.int32)
187
+ neg_tokens = self.tokenizer(
188
+ neg_prompt,
189
+ truncation = True,
190
+ return_overflowing_tokens = False,
191
+ max_length = 77,
192
+ padding = 'max_length',
193
+ return_tensors = 'np'
194
+ ).input_ids
195
+ neg_tokens = jnp.array(neg_tokens, dtype = jnp.int32)
196
+ for i,im in enumerate(hint_image):
197
+ if im.size != (width, height):
198
+ hint_image[i] = hint_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
199
+ for i,im in enumerate(mask_image):
200
+ if im.size != (width, height):
201
+ mask_image[i] = mask_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
202
+ # b,h,w,c | c == 3
203
+ hint = jnp.concatenate(
204
+ [ jnp.expand_dims(np.asarray(x.convert('RGB')), axis = 0) for x in hint_image ],
205
+ axis = 0
206
+ ).astype(jnp.float32)
207
+ # scale -1,1
208
+ hint = (hint / 255) * 2 - 1
209
+ # b,h,w,c | c == 1
210
+ mask = jnp.concatenate(
211
+ [ jnp.expand_dims(np.asarray(x.convert('L')), axis = (0, -1)) for x in mask_image ],
212
+ axis = 0
213
+ ).astype(jnp.float32)
214
+ # scale -1,1
215
+ mask = (mask / 255) * 2 - 1
216
+ # binarize mask
217
+ mask = mask.at[mask < 0.5].set(0)
218
+ mask = mask.at[mask >= 0.5].set(1)
219
+ # mask
220
+ hint = hint * (mask < 0.5)
221
+ # b,h,w,c -> b,c,h,w
222
+ hint = hint.transpose((0,3,1,2))
223
+ mask = mask.transpose((0,3,1,2))
224
+ return tokens, neg_tokens, hint, mask
225
+
226
+ def generate(self,
227
+ prompt: Union[str, List[str]] = '',
228
+ inference_steps: int = 20,
229
+ hint_image: Union[Image.Image, List[Image.Image], None] = None,
230
+ mask_image: Union[Image.Image, List[Image.Image], None] = None,
231
+ neg_prompt: Union[str, List[str]] = '',
232
+ cfg: float = 15.0,
233
+ cfg_image: Optional[float] = None,
234
+ num_frames: int = 24,
235
+ width: int = 512,
236
+ height: int = 512,
237
+ seed: int = 0,
238
+ scheduler_type: str = 'dpm'
239
+ ) -> List[List[Image.Image]]:
240
+ assert inference_steps > 0, f'number of inference steps must be > 0 but is {inference_steps}'
241
+ assert num_frames > 0, f'number of frames must be > 0 but is {num_frames}'
242
+ assert width % 32 == 0, f'width must be divisible by 32 but is {width}'
243
+ assert height % 32 == 0, f'height must be divisible by 32 but is {height}'
244
+ if isinstance(prompt, str):
245
+ prompt = [ prompt ]
246
+ batch_size = len(prompt)
247
+ assert batch_size % self.device_count == 0, f'batch size must be multiple of {self.device_count}'
248
+ if hint_image is None:
249
+ hint_image = Image.new('RGB', (width, height), color = (0,0,0))
250
+ use_imagegen = True
251
+ else:
252
+ use_imagegen = False
253
+ if isinstance(hint_image, Image.Image):
254
+ hint_image = [ hint_image ] * batch_size
255
+ assert len(hint_image) == batch_size, f'number of hint images must be equal to batch size {batch_size} but is {len(hint_image)}'
256
+ if mask_image is None:
257
+ mask_image = Image.new('L', hint_image[0].size, color = 0)
258
+ if isinstance(mask_image, Image.Image):
259
+ mask_image = [ mask_image ] * batch_size
260
+ assert len(mask_image) == batch_size, f'number of mask images must be equal to batch size {batch_size} but is {len(mask_image)}'
261
+ if isinstance(neg_prompt, str):
262
+ neg_prompt = [ neg_prompt ] * batch_size
263
+ assert len(neg_prompt) == batch_size, f'number of negative prompts must be equal to batch size {batch_size} but is {len(neg_prompt)}'
264
+ assert scheduler_type in SCHEDULERS, f'unknown type of noise scheduler: {scheduler_type}, must be one of {list(SCHEDULERS.keys())}'
265
+ tokens, neg_tokens, hint, mask = self.prepare_inputs(
266
+ prompt = prompt,
267
+ neg_prompt = neg_prompt,
268
+ hint_image = hint_image,
269
+ mask_image = mask_image,
270
+ width = width,
271
+ height = height
272
+ )
273
+ if cfg_image is None:
274
+ cfg_image = cfg
275
+ #params['scheduler'] = scheduler_state
276
+ # NOTE splitting rngs is not deterministic,
277
+ # running on different device counts gives different seeds
278
+ #rng = jax.random.PRNGKey(seed)
279
+ #rngs = jax.random.split(rng, self.device_count)
280
+ # manually assign seeded RNGs to devices for reproducibility
281
+ rngs = jnp.array([ jax.random.PRNGKey(seed + i) for i in range(self.device_count) ])
282
+ params = jax_utils.replicate(self.params)
283
+ tokens = shard(tokens)
284
+ neg_tokens = shard(neg_tokens)
285
+ hint = shard(hint)
286
+ mask = shard(mask)
287
+ images = _p_generate(self,
288
+ tokens,
289
+ neg_tokens,
290
+ hint,
291
+ mask,
292
+ inference_steps,
293
+ num_frames,
294
+ height,
295
+ width,
296
+ cfg,
297
+ cfg_image,
298
+ rngs,
299
+ params,
300
+ use_imagegen,
301
+ scheduler_type,
302
+ )
303
+ if images.ndim == 5:
304
+ images = einops.rearrange(images, 'd f c h w -> (d f) h w c')
305
+ else:
306
+ images = einops.rearrange(images, 'f c h w -> f h w c')
307
+ # to cpu
308
+ images = np.array(images)
309
+ images = [ Image.fromarray(x) for x in images ]
310
+ return images
311
+
312
+ def _generate(self,
313
+ tokens: jnp.ndarray,
314
+ neg_tokens: jnp.ndarray,
315
+ hint: jnp.ndarray,
316
+ mask: jnp.ndarray,
317
+ inference_steps: int,
318
+ num_frames,
319
+ height,
320
+ width,
321
+ cfg: float,
322
+ cfg_image: float,
323
+ rng: jax.random.KeyArray,
324
+ params: Union[Dict[str, Any], FrozenDict[str, Any]],
325
+ use_imagegen: bool,
326
+ scheduler_type: str
327
+ ) -> List[Image.Image]:
328
+ batch_size = tokens.shape[0]
329
+ latent_h = height // self.vae_scale_factor
330
+ latent_w = width // self.vae_scale_factor
331
+ latent_shape = (
332
+ batch_size,
333
+ self.vae.config.latent_channels,
334
+ num_frames,
335
+ latent_h,
336
+ latent_w
337
+ )
338
+ encoded_prompt = self.text_encoder(tokens, params = params['text_encoder'])[0]
339
+ encoded_neg_prompt = self.text_encoder(neg_tokens, params = params['text_encoder'])[0]
340
+
341
+ scheduler = self.schedulers[scheduler_type]
342
+ scheduler_state = params[scheduler_type]
343
+
344
+ if use_imagegen:
345
+ image_latent_shape = (batch_size, self.vae.config.latent_channels, latent_h, latent_w)
346
+ image_latents = jax.random.normal(
347
+ rng,
348
+ shape = image_latent_shape,
349
+ dtype = jnp.float32
350
+ ) * scheduler_state.init_noise_sigma
351
+ image_scheduler_state = scheduler.set_timesteps(
352
+ scheduler_state,
353
+ num_inference_steps = inference_steps,
354
+ shape = image_latents.shape
355
+ )
356
+ def image_sample_loop(step, args):
357
+ image_latents, image_scheduler_state = args
358
+ t = image_scheduler_state.timesteps[step]
359
+ tt = jnp.broadcast_to(t, image_latents.shape[0])
360
+ latents_input = scheduler.scale_model_input(image_scheduler_state, image_latents, t)
361
+ noise_pred = self.imunet.apply(
362
+ { 'params': params['imunet']} ,
363
+ latents_input,
364
+ tt,
365
+ encoder_hidden_states = encoded_prompt
366
+ ).sample
367
+ noise_pred_uncond = self.imunet.apply(
368
+ { 'params': params['imunet'] },
369
+ latents_input,
370
+ tt,
371
+ encoder_hidden_states = encoded_neg_prompt
372
+ ).sample
373
+ noise_pred = noise_pred_uncond + cfg_image * (noise_pred - noise_pred_uncond)
374
+ image_latents, image_scheduler_state = scheduler.step(
375
+ image_scheduler_state,
376
+ noise_pred.astype(jnp.float32),
377
+ t,
378
+ image_latents
379
+ ).to_tuple()
380
+ return image_latents, image_scheduler_state
381
+ image_latents, _ = jax.lax.fori_loop(
382
+ 0, inference_steps,
383
+ image_sample_loop,
384
+ (image_latents, image_scheduler_state)
385
+ )
386
+ hint = image_latents
387
+ else:
388
+ hint = self.vae.apply(
389
+ { 'params': params['vae'] },
390
+ hint,
391
+ method = self.vae.encode
392
+ ).latent_dist.mean * self.vae.config.scaling_factor
393
+ # NOTE vae keeps channels last for encode, but rearranges to channels first for decode
394
+ # b0 h1 w2 c3 -> b0 c3 h1 w2
395
+ hint = hint.transpose((0, 3, 1, 2))
396
+
397
+ hint = jnp.expand_dims(hint, axis = 2).repeat(num_frames, axis = 2)
398
+ mask = jax.image.resize(mask, (*mask.shape[:-2], *hint.shape[-2:]), method = 'nearest')
399
+ mask = jnp.expand_dims(mask, axis = 2).repeat(num_frames, axis = 2)
400
+ # NOTE jax normal distribution is shit with float16 + bfloat16
401
+ # SEE https://github.com/google/jax/discussions/13798
402
+ # generate random at float32
403
+ latents = jax.random.normal(
404
+ rng,
405
+ shape = latent_shape,
406
+ dtype = jnp.float32
407
+ ) * scheduler_state.init_noise_sigma
408
+ scheduler_state = scheduler.set_timesteps(
409
+ scheduler_state,
410
+ num_inference_steps = inference_steps,
411
+ shape = latents.shape
412
+ )
413
+
414
+ def sample_loop(step, args):
415
+ latents, scheduler_state = args
416
+ t = scheduler_state.timesteps[step]#jnp.array(scheduler_state.timesteps, dtype = jnp.int32)[step]
417
+ tt = jnp.broadcast_to(t, latents.shape[0])
418
+ latents_input = scheduler.scale_model_input(scheduler_state, latents, t)
419
+ latents_input = jnp.concatenate([latents_input, mask, hint], axis = 1)
420
+ noise_pred = self.unet.apply(
421
+ { 'params': params['unet'] },
422
+ latents_input,
423
+ tt,
424
+ encoded_prompt
425
+ ).sample
426
+ noise_pred_uncond = self.unet.apply(
427
+ { 'params': params['unet'] },
428
+ latents_input,
429
+ tt,
430
+ encoded_neg_prompt
431
+ ).sample
432
+ noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
433
+ latents, scheduler_state = scheduler.step(
434
+ scheduler_state,
435
+ noise_pred.astype(jnp.float32),
436
+ t,
437
+ latents
438
+ ).to_tuple()
439
+ return latents, scheduler_state
440
+
441
+ latents, _ = jax.lax.fori_loop(
442
+ 0, inference_steps,
443
+ sample_loop,
444
+ (latents, scheduler_state)
445
+ )
446
+ latents = 1 / self.vae.config.scaling_factor * latents
447
+ latents = einops.rearrange(latents, 'b c f h w -> (b f) c h w')
448
+ num_images = len(latents)
449
+ images_out = jnp.zeros(
450
+ (
451
+ num_images,
452
+ self.vae.config.out_channels,
453
+ height,
454
+ width
455
+ ),
456
+ dtype = self.dtype
457
+ )
458
+ def decode_loop(step, images_out):
459
+ # NOTE vae keeps channels last for encode, but rearranges to channels first for decode
460
+ im = self.vae.apply(
461
+ { 'params': params['vae'] },
462
+ jnp.expand_dims(latents[step], axis = 0),
463
+ method = self.vae.decode
464
+ ).sample
465
+ images_out = images_out.at[step].set(im[0])
466
+ return images_out
467
+ images_out = jax.lax.fori_loop(0, num_images, decode_loop, images_out)
468
+ images_out = ((images_out / 2 + 0.5) * 255).round().clip(0, 255).astype(jnp.uint8)
469
+ return images_out
470
+
471
+
472
+ @partial(
473
+ jax.pmap,
474
+ in_axes = ( # 0 -> split across batch dim, None -> duplicate
475
+ None, # 0 inference_class
476
+ 0, # 1 tokens
477
+ 0, # 2 neg_tokens
478
+ 0, # 3 hint
479
+ 0, # 4 mask
480
+ None, # 5 inference_steps
481
+ None, # 6 num_frames
482
+ None, # 7 height
483
+ None, # 8 width
484
+ None, # 9 cfg
485
+ None, # 10 cfg_image
486
+ 0, # 11 rng
487
+ 0, # 12 params
488
+ None, # 13 use_imagegen
489
+ None, # 14 scheduler_type
490
+ ),
491
+ static_broadcasted_argnums = ( # trigger recompilation on change
492
+ 0, # inference_class
493
+ 5, # inference_steps
494
+ 6, # num_frames
495
+ 7, # height
496
+ 8, # width
497
+ 13, # use_imagegen
498
+ 14, # scheduler_type
499
+ )
500
+ )
501
+ def _p_generate(
502
+ inference_class: InferenceUNetPseudo3D,
503
+ tokens,
504
+ neg_tokens,
505
+ hint,
506
+ mask,
507
+ inference_steps: int,
508
+ num_frames: int,
509
+ height: int,
510
+ width: int,
511
+ cfg: float,
512
+ cfg_image: float,
513
+ rng,
514
+ params,
515
+ use_imagegen: bool,
516
+ scheduler_type: str
517
+ ):
518
+ return inference_class._generate(
519
+ tokens,
520
+ neg_tokens,
521
+ hint,
522
+ mask,
523
+ inference_steps,
524
+ num_frames,
525
+ height,
526
+ width,
527
+ cfg,
528
+ cfg_image,
529
+ rng,
530
+ params,
531
+ use_imagegen,
532
+ scheduler_type
533
+ )
534
+
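A usage sketch for the pipeline above, not part of the commit: the model path, prompt and GIF settings are placeholders, and a single accelerator device is assumed (the batch size, a single prompt here, must be a multiple of jax.device_count()).

import jax.numpy as jnp
from makeavid_sd.inference import InferenceUNetPseudo3D

pipe = InferenceUNetPseudo3D(model_path = 'path/to/makeavid-sd-model', dtype = jnp.float16)
frames = pipe.generate(
    prompt = 'a knight riding a horse through a misty forest',
    neg_prompt = 'low quality, blurry',
    num_frames = 24,
    inference_steps = 20,
    width = 512,
    height = 512,
    cfg = 15.0,
    seed = 0,
    scheduler_type = 'dpm'
)
# generate returns a flat list of PIL images, one per frame
frames[0].save('clip.gif', save_all = True, append_images = frames[1:], duration = 1000 // 12, loop = 0)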
makeavid_sd/torch_impl/__init__.py ADDED
File without changes
makeavid_sd/torch_impl/torch_attention_pseudo3d.py ADDED
@@ -0,0 +1,294 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from einops import rearrange
8
+
9
+ from diffusers.models.attention_processor import Attention as CrossAttention
10
+ #from torch_cross_attention import CrossAttention
11
+
12
+
13
+ class TransformerPseudo3DModelOutput:
14
+ def __init__(self, sample: torch.FloatTensor) -> None:
15
+ self.sample = sample
16
+
17
+
18
+ class TransformerPseudo3DModel(nn.Module):
19
+ def __init__(self,
20
+ num_attention_heads: int = 16,
21
+ attention_head_dim: int = 88,
22
+ in_channels: Optional[int] = None,
23
+ num_layers: int = 1,
24
+ dropout: float = 0.0,
25
+ norm_num_groups: int = 32,
26
+ cross_attention_dim: Optional[int] = None,
27
+ attention_bias: bool = False
28
+ ) -> None:
29
+ super().__init__()
30
+ self.num_attention_heads = num_attention_heads
31
+ self.attention_head_dim = attention_head_dim
32
+ inner_dim = num_attention_heads * attention_head_dim
33
+
34
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
35
+ # Define whether input is continuous or discrete depending on configuration
36
+ # it's continuous
37
+
38
+ # 2. Define input layers
39
+ self.in_channels = in_channels
40
+
41
+ self.norm = torch.nn.GroupNorm(
42
+ num_groups = norm_num_groups,
43
+ num_channels = in_channels,
44
+ eps = 1e-6,
45
+ affine = True
46
+ )
47
+ self.proj_in = nn.Conv2d(
48
+ in_channels,
49
+ inner_dim,
50
+ kernel_size = 1,
51
+ stride = 1,
52
+ padding = 0
53
+ )
54
+
55
+ # 3. Define transformers blocks
56
+ self.transformer_blocks = nn.ModuleList(
57
+ [
58
+ BasicTransformerBlock(
59
+ inner_dim,
60
+ num_attention_heads,
61
+ attention_head_dim,
62
+ dropout = dropout,
63
+ cross_attention_dim = cross_attention_dim,
64
+ attention_bias = attention_bias,
65
+ )
66
+ for _ in range(num_layers)
67
+ ]
68
+ )
69
+
70
+ # 4. Define output layers
71
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size = 1, stride = 1, padding = 0)
72
+
73
+ def forward(self,
74
+ hidden_states: torch.Tensor,
75
+ encoder_hidden_states: Optional[torch.Tensor] = None,
76
+ timestep: torch.long = None
77
+ ) -> TransformerPseudo3DModelOutput:
78
+ """
79
+ Args:
80
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
81
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
82
+ hidden_states
83
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
84
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
85
+ self-attention.
86
+ timestep ( `torch.long`, *optional*):
87
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
88
+
+ Returns:
+ [`TransformerPseudo3DModelOutput`]: the transformed `sample` tensor.
95
+ """
96
+ b, c, *_, h, w = hidden_states.shape
97
+ is_video = hidden_states.ndim == 5
98
+ f = None
99
+ if is_video:
100
+ b, c, f, h, w = hidden_states.shape
101
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
102
+ #encoder_hidden_states = encoder_hidden_states.repeat_interleave(f, 0)
103
+
104
+ # 1. Input
105
+ batch, channel, height, weight = hidden_states.shape
106
+ residual = hidden_states
107
+ hidden_states = self.norm(hidden_states)
108
+ hidden_states = self.proj_in(hidden_states)
109
+ inner_dim = hidden_states.shape[1]
110
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
111
+
112
+ # 2. Blocks
113
+ for block in self.transformer_blocks:
114
+ hidden_states = block(
115
+ hidden_states,
116
+ context = encoder_hidden_states,
117
+ timestep = timestep,
118
+ frames_length = f,
119
+ height = height,
120
+ weight = weight
121
+ )
122
+
123
+ # 3. Output
124
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
125
+ hidden_states = self.proj_out(hidden_states)
126
+ output = hidden_states + residual
127
+
128
+ if is_video:
129
+ output = rearrange(output, '(b f) c h w -> b c f h w', b = b)
130
+
131
+ return TransformerPseudo3DModelOutput(sample = output)
132
+
133
+
134
+
135
+ class BasicTransformerBlock(nn.Module):
136
+ r"""
137
+ A basic Transformer block.
138
+
139
+ Parameters:
140
+ dim (`int`): The number of channels in the input and output.
141
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
142
+ attention_head_dim (`int`): The number of channels in each head.
143
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
144
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
145
+ num_embeds_ada_norm (:
146
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
147
+ attention_bias (:
148
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
149
+ """
150
+
151
+ def __init__(self,
152
+ dim: int,
153
+ num_attention_heads: int,
154
+ attention_head_dim: int,
155
+ dropout: float = 0.0,
156
+ cross_attention_dim: Optional[int] = None,
157
+ attention_bias: bool = False,
158
+ ) -> None:
159
+ super().__init__()
160
+ self.attn1 = CrossAttention(
161
+ query_dim = dim,
162
+ heads = num_attention_heads,
163
+ dim_head = attention_head_dim,
164
+ dropout = dropout,
165
+ bias = attention_bias
166
+ ) # is a self-attention
167
+ self.ff = FeedForward(dim, dropout = dropout)
168
+ self.attn2 = CrossAttention(
169
+ query_dim = dim,
170
+ cross_attention_dim = cross_attention_dim,
171
+ heads = num_attention_heads,
172
+ dim_head = attention_head_dim,
173
+ dropout = dropout,
174
+ bias = attention_bias
175
+ ) # is self-attn if context is none
176
+ self.attn_temporal = CrossAttention(
177
+ query_dim = dim,
178
+ heads = num_attention_heads,
179
+ dim_head = attention_head_dim,
180
+ dropout = dropout,
181
+ bias = attention_bias
182
+ ) # is a self-attention
183
+
184
+ # layer norms
185
+ self.norm1 = nn.LayerNorm(dim)
186
+ self.norm2 = nn.LayerNorm(dim)
187
+ self.norm_temporal = nn.LayerNorm(dim)
188
+ self.norm3 = nn.LayerNorm(dim)
189
+
190
+ def forward(self,
191
+ hidden_states: torch.Tensor,
192
+ context: Optional[torch.Tensor] = None,
193
+ timestep: torch.int64 = None,
194
+ frames_length: Optional[int] = None,
195
+ height: Optional[int] = None,
196
+ weight: Optional[int] = None
197
+ ) -> torch.Tensor:
198
+ if context is not None and frames_length is not None:
199
+ context = context.repeat_interleave(frames_length, 0)
200
+ # 1. Self-Attention
201
+ norm_hidden_states = (
202
+ self.norm1(hidden_states)
203
+ )
204
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
205
+
206
+ # 2. Cross-Attention
207
+ norm_hidden_states = (
208
+ self.norm2(hidden_states)
209
+ )
210
+ hidden_states = self.attn2(
211
+ norm_hidden_states,
212
+ encoder_hidden_states = context
213
+ ) + hidden_states
214
+
215
+ # append temporal attention
216
+ if frames_length is not None:
217
+ hidden_states = rearrange(
218
+ hidden_states,
219
+ '(b f) (h w) c -> (b h w) f c',
220
+ f = frames_length,
221
+ h = height,
222
+ w = weight
223
+ )
224
+ norm_hidden_states = (
225
+ self.norm_temporal(hidden_states)
226
+ )
227
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
228
+ hidden_states = rearrange(
229
+ hidden_states,
230
+ '(b h w) f c -> (b f) (h w) c',
231
+ f = frames_length,
232
+ h = height,
233
+ w = weight
234
+ )
235
+
236
+ # 3. Feed-forward
237
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
238
+ return hidden_states
239
+
240
+
241
+ class FeedForward(nn.Module):
242
+ r"""
243
+ A feed-forward layer.
244
+
245
+ Parameters:
246
+ dim (`int`): The number of channels in the input.
247
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
248
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
249
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
250
+ """
251
+
252
+ def __init__(self,
253
+ dim: int,
254
+ dim_out: Optional[int] = None,
255
+ mult: int = 4,
256
+ dropout: float = 0.0
257
+ ) -> None:
258
+ super().__init__()
259
+ inner_dim = int(dim * mult)
260
+ dim_out = dim_out if dim_out is not None else dim
261
+
262
+ geglu = GEGLU(dim, inner_dim)
263
+
264
+ self.net = nn.ModuleList([])
265
+ # project in
266
+ self.net.append(geglu)
267
+ # project dropout
268
+ self.net.append(nn.Dropout(dropout))
269
+ # project out
270
+ self.net.append(nn.Linear(inner_dim, dim_out))
271
+
272
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
273
+ for module in self.net:
274
+ hidden_states = module(hidden_states)
275
+ return hidden_states
276
+
277
+
278
+ # feedforward
279
+ class GEGLU(nn.Module):
280
+ r"""
281
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
282
+
283
+ Parameters:
284
+ dim_in (`int`): The number of channels in the input.
285
+ dim_out (`int`): The number of channels in the output.
286
+ """
287
+
288
+ def __init__(self, dim_in: int, dim_out: int) -> None:
289
+ super().__init__()
290
+ self.proj = nn.Linear(dim_in, dim_out * 2)
291
+
292
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
293
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim = -1)
294
+ return hidden_states * F.gelu(gate)
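A small sketch (not part of the commit) of pushing a video feature map through the pseudo-3D transformer defined above: spatial and cross attention run per frame, and the extra temporal attention runs per pixel across frames. The head count, channel sizes and tensor shapes are illustrative assumptions.

import torch
from makeavid_sd.torch_impl.torch_attention_pseudo3d import TransformerPseudo3DModel

model = TransformerPseudo3DModel(
    num_attention_heads = 8,
    attention_head_dim = 40,
    in_channels = 320,
    cross_attention_dim = 768
)
video = torch.randn(1, 320, 8, 16, 16)  # b, c, f, h, w
text = torch.randn(1, 77, 768)          # CLIP encoder hidden states
out = model(video, encoder_hidden_states = text).sample
# out has the same shape as the input: (1, 320, 8, 16, 16)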
makeavid_sd/torch_impl/torch_cross_attention.py ADDED
@@ -0,0 +1,171 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class CrossAttention(nn.Module):
6
+ r"""
7
+ A cross attention layer.
8
+
9
+ Parameters:
10
+ query_dim (`int`): The number of channels in the query.
11
+ cross_attention_dim (`int`, *optional*):
12
+ The number of channels in the context. If not given, defaults to `query_dim`.
13
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
14
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
15
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
16
+ bias (`bool`, *optional*, defaults to False):
17
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
18
+ """
19
+
20
+ def __init__(self,
21
+ query_dim: int,
22
+ cross_attention_dim: Optional[int] = None,
23
+ heads: int = 8,
24
+ dim_head: int = 64,
25
+ dropout: float = 0.0,
26
+ bias: bool = False
27
+ ):
28
+ super().__init__()
29
+ inner_dim = dim_head * heads
30
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
31
+
32
+ self.scale = dim_head**-0.5
33
+ self.heads = heads
34
+ self.n_heads = heads
35
+ self.d_head = dim_head
36
+
37
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = bias)
38
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
39
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
40
+
41
+ self.to_out = nn.ModuleList([])
42
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
43
+ self.to_out.append(nn.Dropout(dropout))
44
+ try:
45
+ # You can install flash attention by cloning their Github repo,
46
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
47
+ # and then running `python setup.py install`
48
+ from flash_attn.flash_attention import FlashAttention
49
+ self.flash = FlashAttention()
50
+ # Set the scale for scaled dot-product attention.
51
+ self.flash.softmax_scale = self.scale
52
+ # Set to `None` if it's not installed
53
+ except ImportError:
54
+ self.flash = None
55
+
56
+ def reshape_heads_to_batch_dim(self, tensor):
57
+ batch_size, seq_len, dim = tensor.shape
58
+ head_size = self.heads
59
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
60
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
61
+ return tensor
62
+
63
+ def reshape_batch_dim_to_heads(self, tensor):
64
+ batch_size, seq_len, dim = tensor.shape
65
+ head_size = self.heads
66
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
67
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
68
+ return tensor
69
+
70
+ def forward(self,
71
+ hidden_states: torch.Tensor,
72
+ encoder_hidden_states: Optional[torch.Tensor] = None,
73
+ mask: Optional[torch.Tensor] = None
74
+ ) -> torch.Tensor:
75
+ batch_size, sequence_length, _ = hidden_states.shape
76
+ is_self = encoder_hidden_states is None
77
+ # attention, what we cannot get enough of
78
+ query = self.to_q(hidden_states)
79
+ has_cond = encoder_hidden_states is not None
80
+
81
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
82
+ key = self.to_k(encoder_hidden_states)
83
+ value = self.to_v(encoder_hidden_states)
84
+
85
+ dim = query.shape[-1]
86
+
87
+ if self.flash is not None and not has_cond and self.d_head <= 64:
88
+ hidden_states = self.flash_attention(query, key, value)
89
+ else:
90
+ hidden_states = self.normal_attention(query, key, value, is_self)
91
+
92
+ # linear proj
93
+ hidden_states = self.to_out[0](hidden_states)
94
+ # dropout
95
+ hidden_states = self.to_out[1](hidden_states)
96
+ return hidden_states
97
+
98
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
99
+ """
100
+ #### Flash Attention
101
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
102
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
103
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
104
+ """
105
+
106
+ # Get batch size and number of elements along sequence axis (`width * height`)
107
+ batch_size, seq_len, _ = q.shape
108
+
109
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
110
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
111
+ qkv = torch.stack((q, k, v), dim = 2)
112
+ # Split the heads
113
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
114
+
115
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
116
+ # fit this size.
117
+ if self.d_head <= 32:
118
+ pad = 32 - self.d_head
119
+ elif self.d_head <= 64:
120
+ pad = 64 - self.d_head
121
+ elif self.d_head <= 128:
122
+ pad = 128 - self.d_head
123
+ else:
124
+ raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
125
+
126
+ # Pad the heads
127
+ if pad:
128
+ qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim = -1)
129
+
130
+ # Compute attention
131
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
132
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
133
+ out, _ = self.flash(qkv)
134
+ # Truncate the extra head size
135
+ out = out[:, :, :, :self.d_head]
136
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
137
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
138
+
139
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
140
+ return out
141
+
142
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_self: bool):
143
+ """
144
+ #### Normal Attention
145
+
146
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
147
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
148
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
149
+ """
150
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
151
+ q = q.view(*q.shape[:2], self.n_heads, -1)
152
+ k = k.view(*k.shape[:2], self.n_heads, -1)
153
+ v = v.view(*v.shape[:2], self.n_heads, -1)
154
+
155
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
156
+ attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
157
+ # Compute softmax
158
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
159
+ half = attn.shape[0] // 2
160
+ attn[half:] = attn[half:].softmax(dim = -1)
161
+ attn[:half] = attn[:half].softmax(dim = -1)
162
+
163
+ # Compute attention output
164
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
165
+ out = torch.einsum('bhij,bjhd->bihd', attn, v)
166
+
167
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
168
+ out = out.reshape(*out.shape[:2], -1)
169
+
170
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
171
+ return out
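A standalone sketch (not part of the commit) of the fallback attention layer above; it works without flash-attn installed, and the dimensions below are illustrative assumptions.

import torch
from makeavid_sd.torch_impl.torch_cross_attention import CrossAttention

attn = CrossAttention(query_dim = 320, cross_attention_dim = 768, heads = 8, dim_head = 40)
x = torch.randn(2, 256, 320)    # (batch, tokens of a 16x16 feature map, query_dim)
ctx = torch.randn(2, 77, 768)   # text conditioning
y = attn(x, encoder_hidden_states = ctx)  # (2, 256, 320)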
makeavid_sd/torch_impl/torch_embeddings.py ADDED
@@ -0,0 +1,92 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+ def get_timestep_embedding(
6
+ timesteps: torch.Tensor,
7
+ embedding_dim: int,
8
+ flip_sin_to_cos: bool = False,
9
+ downscale_freq_shift: float = 1,
10
+ scale: float = 1,
11
+ max_period: int = 10000,
12
+ ) -> torch.Tensor:
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
15
+
16
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
17
+ These may be fractional.
18
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings.
19
+ :return: an [N x dim] Tensor of positional embeddings.
20
+ """
21
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
22
+
23
+ half_dim = embedding_dim // 2
24
+ exponent = -math.log(max_period) * torch.arange(
25
+ start = 0,
26
+ end = half_dim,
27
+ dtype = torch.float32,
28
+ device = timesteps.device
29
+ )
30
+ exponent = exponent / (half_dim - downscale_freq_shift)
31
+
32
+ emb = torch.exp(exponent)
33
+ emb = timesteps[:, None].float() * emb[None, :]
34
+
35
+ # scale embeddings
36
+ emb = scale * emb
37
+
38
+ # concat sine and cosine embeddings
39
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = -1)
40
+
41
+ # flip sine and cosine embeddings
42
+ if flip_sin_to_cos:
43
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim = -1)
44
+
45
+ # zero pad
46
+ if embedding_dim % 2 == 1:
47
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
48
+ return emb
49
+
50
+
51
+ class TimestepEmbedding(nn.Module):
52
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
53
+ super().__init__()
54
+
55
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
56
+ self.act = None
57
+ if act_fn == "silu":
58
+ self.act = nn.SiLU()
59
+ elif act_fn == "mish":
60
+ self.act = nn.Mish()
61
+
62
+ if out_dim is not None:
63
+ time_embed_dim_out = out_dim
64
+ else:
65
+ time_embed_dim_out = time_embed_dim
66
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
67
+
68
+ def forward(self, sample):
69
+ sample = self.linear_1(sample)
70
+
71
+ if self.act is not None:
72
+ sample = self.act(sample)
73
+
74
+ sample = self.linear_2(sample)
75
+ return sample
76
+
77
+
78
+ class Timesteps(nn.Module):
79
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
80
+ super().__init__()
81
+ self.num_channels = num_channels
82
+ self.flip_sin_to_cos = flip_sin_to_cos
83
+ self.downscale_freq_shift = downscale_freq_shift
84
+
85
+ def forward(self, timesteps):
86
+ t_emb = get_timestep_embedding(
87
+ timesteps,
88
+ self.num_channels,
89
+ flip_sin_to_cos=self.flip_sin_to_cos,
90
+ downscale_freq_shift=self.downscale_freq_shift,
91
+ )
92
+ return t_emb
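A short usage sketch for the sinusoidal timestep embedding above. It assumes the file is importable as `torch_embeddings`, matching the flat imports used by the other modules in this directory; all sizes are illustrative:

import torch
from torch_embeddings import Timesteps, TimestepEmbedding  # assumed flat module path

time_proj = Timesteps(num_channels = 320, flip_sin_to_cos = True, downscale_freq_shift = 0)
time_embedding = TimestepEmbedding(in_channels = 320, time_embed_dim = 1280)

timesteps = torch.tensor([0, 250, 500, 999])  # one timestep per batch element
t_emb = time_proj(timesteps)                  # [4, 320] fixed sinusoidal features
emb = time_embedding(t_emb)                   # [4, 1280] learned projection
assert t_emb.shape == (4, 320) and emb.shape == (4, 1280)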
makeavid_sd/torch_impl/torch_resnet_pseudo3d.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ class Pseudo3DConv(nn.Module):
7
+ def __init__(
8
+ self,
9
+ dim,
10
+ dim_out,
11
+ kernel_size,
12
+ **kwargs
13
+ ):
14
+ super().__init__()
15
+
16
+ self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size, **kwargs)
17
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size, padding=kernel_size // 2)
18
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, 3, padding=1)
19
+
20
+ nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
21
+ nn.init.zeros_(self.temporal_conv.bias.data)
22
+
23
+ def forward(
24
+ self,
25
+ x,
26
+ convolve_across_time = True
27
+ ):
28
+ b, c, *_, h, w = x.shape
29
+
30
+ is_video = x.ndim == 5
31
+ convolve_across_time &= is_video
32
+
33
+ if is_video:
34
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
35
+
36
+ #with torch.no_grad():
37
+ # x = self.spatial_conv(x)
38
+ x = self.spatial_conv(x)
39
+
40
+ if is_video:
41
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
42
+ b, c, *_, h, w = x.shape
43
+
44
+ if not convolve_across_time:
45
+ return x
46
+
47
+ if is_video:
48
+ x = rearrange(x, 'b c f h w -> (b h w) c f')
49
+ x = self.temporal_conv(x)
50
+ x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
51
+ return x
52
+
53
+ class Upsample2D(nn.Module):
54
+ """
55
+ An upsampling layer with an optional convolution.
56
+
57
+ Parameters:
58
+ channels: channels in the inputs and outputs.
59
+ use_conv: a bool determining if a convolution is applied.
60
+ use_conv_transpose:
61
+ out_channels:
62
+ """
63
+
64
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
65
+ super().__init__()
66
+ self.channels = channels
67
+ self.out_channels = out_channels or channels
68
+ self.use_conv = use_conv
69
+ self.use_conv_transpose = use_conv_transpose
70
+ self.name = name
71
+
72
+ conv = None
73
+ if use_conv_transpose:
74
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
75
+ elif use_conv:
76
+ conv = Pseudo3DConv(self.channels, self.out_channels, 3, padding=1)
77
+
78
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
79
+ if name == "conv":
80
+ self.conv = conv
81
+ else:
82
+ self.Conv2d_0 = conv
83
+
84
+ def forward(self, hidden_states, output_size=None):
85
+ assert hidden_states.shape[1] == self.channels
86
+
87
+ if self.use_conv_transpose:
88
+ return self.conv(hidden_states)
89
+
90
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
91
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
92
+ # https://github.com/pytorch/pytorch/issues/86679
93
+ dtype = hidden_states.dtype
94
+ if dtype == torch.bfloat16:
95
+ hidden_states = hidden_states.to(torch.float32)
96
+
97
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
98
+ if hidden_states.shape[0] >= 64:
99
+ hidden_states = hidden_states.contiguous()
100
+
101
+ b, c, *_, h, w = hidden_states.shape
102
+
103
+ is_video = hidden_states.ndim == 5
104
+
105
+ if is_video:
106
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
107
+
108
+ # if `output_size` is passed we force the interpolation output
109
+ # size and do not make use of `scale_factor=2`
110
+ if output_size is None:
111
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
112
+ else:
113
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
114
+
115
+ if is_video:
116
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
117
+
118
+ # If the input is bfloat16, we cast back to bfloat16
119
+ if dtype == torch.bfloat16:
120
+ hidden_states = hidden_states.to(dtype)
121
+
122
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
123
+ if self.use_conv:
124
+ if self.name == "conv":
125
+ hidden_states = self.conv(hidden_states)
126
+ else:
127
+ hidden_states = self.Conv2d_0(hidden_states)
128
+
129
+ return hidden_states
130
+
131
+
132
+ class Downsample2D(nn.Module):
133
+ """
134
+ A downsampling layer with an optional convolution.
135
+
136
+ Parameters:
137
+ channels: channels in the inputs and outputs.
138
+ use_conv: a bool determining if a convolution is applied.
139
+ out_channels:
140
+ padding:
141
+ """
142
+
143
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.padding = padding
149
+ stride = 2
150
+ self.name = name
151
+
152
+ if use_conv:
153
+ conv = Pseudo3DConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
157
+
158
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
159
+ if name == "conv":
160
+ self.Conv2d_0 = conv
161
+ self.conv = conv
162
+ elif name == "Conv2d_0":
163
+ self.conv = conv
164
+ else:
165
+ self.conv = conv
166
+
167
+ def forward(self, hidden_states):
168
+ assert hidden_states.shape[1] == self.channels
169
+ if self.use_conv and self.padding == 0:
170
+ pad = (0, 1, 0, 1)
171
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
172
+
173
+ assert hidden_states.shape[1] == self.channels
174
+ if self.use_conv:
175
+ hidden_states = self.conv(hidden_states)
176
+ else:
177
+ b, c, *_, h, w = hidden_states.shape
178
+ is_video = hidden_states.ndim == 5
179
+ if is_video:
180
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
181
+ hidden_states = self.conv(hidden_states)
182
+ if is_video:
183
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
184
+
185
+ return hidden_states
186
+
187
+
188
+ class ResnetBlockPseudo3D(nn.Module):
189
+ def __init__(
190
+ self,
191
+ *,
192
+ in_channels,
193
+ out_channels=None,
194
+ conv_shortcut=False,
195
+ dropout=0.0,
196
+ temb_channels=512,
197
+ groups=32,
198
+ groups_out=None,
199
+ pre_norm=True,
200
+ eps=1e-6,
201
+ time_embedding_norm="default",
202
+ kernel=None,
203
+ output_scale_factor=1.0,
204
+ use_in_shortcut=None,
205
+ up=False,
206
+ down=False,
207
+ ):
208
+ super().__init__()
209
+ self.pre_norm = pre_norm
210
+ self.pre_norm = True
211
+ self.in_channels = in_channels
212
+ out_channels = in_channels if out_channels is None else out_channels
213
+ self.out_channels = out_channels
214
+ self.use_conv_shortcut = conv_shortcut
215
+ self.time_embedding_norm = time_embedding_norm
216
+ self.up = up
217
+ self.down = down
218
+ self.output_scale_factor = output_scale_factor
219
+ print('OUTPUT_SCALE_FACTOR:', output_scale_factor)
220
+
221
+ if groups_out is None:
222
+ groups_out = groups
223
+
224
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
225
+
226
+ self.conv1 = Pseudo3DConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
227
+
228
+ if temb_channels is not None:
229
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
230
+ else:
231
+ self.time_emb_proj = None
232
+
233
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
234
+ self.dropout = torch.nn.Dropout(dropout)
235
+ self.conv2 = Pseudo3DConv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
236
+
237
+ self.nonlinearity = nn.SiLU()
238
+
239
+ self.upsample = self.downsample = None
240
+ if self.up:
241
+ self.upsample = Upsample2D(in_channels, use_conv=False)
242
+ elif self.down:
243
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
244
+
245
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
246
+
247
+ self.conv_shortcut = None
248
+ if self.use_in_shortcut:
249
+ self.conv_shortcut = Pseudo3DConv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
250
+
251
+ def forward(self, input_tensor, temb):
252
+ hidden_states = input_tensor
253
+
254
+ hidden_states = self.norm1(hidden_states)
255
+ hidden_states = self.nonlinearity(hidden_states)
256
+
257
+ if self.upsample is not None:
258
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
259
+ if hidden_states.shape[0] >= 64:
260
+ input_tensor = input_tensor.contiguous()
261
+ hidden_states = hidden_states.contiguous()
262
+ input_tensor = self.upsample(input_tensor)
263
+ hidden_states = self.upsample(hidden_states)
264
+ elif self.downsample is not None:
265
+ input_tensor = self.downsample(input_tensor)
266
+ hidden_states = self.downsample(hidden_states)
267
+
268
+ hidden_states = self.conv1(hidden_states)
269
+
270
+ if temb is not None:
271
+ b, c, *_, h, w = hidden_states.shape
272
+ is_video = hidden_states.ndim == 5
273
+ if is_video:
274
+ b, c, f, h, w = hidden_states.shape
275
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
276
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
277
+ hidden_states = hidden_states + temb.repeat_interleave(f, 0)
278
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b=b)
279
+ else:
280
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
281
+ hidden_states = hidden_states + temb
282
+
283
+ hidden_states = self.norm2(hidden_states)
284
+ hidden_states = self.nonlinearity(hidden_states)
285
+
286
+ hidden_states = self.dropout(hidden_states)
287
+ hidden_states = self.conv2(hidden_states)
288
+
289
+ if self.conv_shortcut is not None:
290
+ input_tensor = self.conv_shortcut(input_tensor)
291
+
292
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
293
+
294
+ return output_tensor
295
+
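A shape-check sketch for the factorised spatial/temporal convolution defined above (`Pseudo3DConv`). The flat import path is an assumption, as is every size used:

import torch
from torch_resnet_pseudo3d import Pseudo3DConv  # assumed flat module path

conv = Pseudo3DConv(dim = 8, dim_out = 16, kernel_size = 3, padding = 1)

# image input: a plain 2D convolution, no temporal mixing
image = torch.randn(2, 8, 32, 32)
assert conv(image).shape == (2, 16, 32, 32)

# video input: spatial conv over (b f) frames, then a 1D temporal conv over the frame axis
video = torch.randn(2, 8, 5, 32, 32)  # [batch, channels, frames, height, width]
assert conv(video).shape == (2, 16, 5, 32, 32)

# the temporal conv is initialised as an identity (dirac weight, zero bias),
# so a freshly initialised layer matches the purely frame-wise spatial conv
frame_wise = conv(video, convolve_across_time = False)
assert torch.allclose(conv(video), frame_wise, atol = 1e-5)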
makeavid_sd/torch_impl/torch_unet_pseudo3d_blocks.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Optional
2
+ import torch
3
+ from torch import nn
4
+
5
+ from torch_attention_pseudo3d import TransformerPseudo3DModel
6
+ from torch_resnet_pseudo3d import Downsample2D, ResnetBlockPseudo3D, Upsample2D
7
+
8
+
9
+ class UNetMidBlock2DCrossAttn(nn.Module):
10
+ def __init__(self,
11
+ in_channels: int,
12
+ temb_channels: int,
13
+ dropout: float = 0.0,
14
+ num_layers: int = 1,
15
+ resnet_eps: float = 1e-6,
16
+ resnet_time_scale_shift: str = "default",
17
+ resnet_act_fn: str = "swish",
18
+ resnet_groups: Optional[int] = 32,
19
+ resnet_pre_norm: bool = True,
20
+ attn_num_head_channels: int = 1,
21
+ attention_type: str = "default",
22
+ output_scale_factor: float = 1.0,
23
+ cross_attention_dim: int = 1280,
24
+ **kwargs
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.attention_type = attention_type
29
+ self.attn_num_head_channels = attn_num_head_channels
30
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
31
+
32
+ # there is always at least one resnet
33
+ resnets = [
34
+ ResnetBlockPseudo3D(
35
+ in_channels = in_channels,
36
+ out_channels = in_channels,
37
+ temb_channels = temb_channels,
38
+ eps = resnet_eps,
39
+ groups = resnet_groups,
40
+ dropout = dropout,
41
+ time_embedding_norm = resnet_time_scale_shift,
42
+ #non_linearity = resnet_act_fn,
43
+ output_scale_factor = output_scale_factor,
44
+ pre_norm = resnet_pre_norm
45
+ )
46
+ ]
47
+ attentions = []
48
+
49
+ for _ in range(num_layers):
50
+ attentions.append(
51
+ TransformerPseudo3DModel(
52
+ in_channels = in_channels,
53
+ num_attention_heads = attn_num_head_channels,
54
+ attention_head_dim = in_channels // attn_num_head_channels,
55
+ num_layers = 1,
56
+ cross_attention_dim = cross_attention_dim,
57
+ norm_num_groups = resnet_groups
58
+ )
59
+ )
60
+ resnets.append(
61
+ ResnetBlockPseudo3D(
62
+ in_channels = in_channels,
63
+ out_channels = in_channels,
64
+ temb_channels = temb_channels,
65
+ eps = resnet_eps,
66
+ groups = resnet_groups,
67
+ dropout = dropout,
68
+ time_embedding_norm = resnet_time_scale_shift,
69
+ #non_linearity = resnet_act_fn,
70
+ output_scale_factor = output_scale_factor,
71
+ pre_norm = resnet_pre_norm
72
+ )
73
+ )
74
+
75
+ self.attentions = nn.ModuleList(attentions)
76
+ self.resnets = nn.ModuleList(resnets)
77
+
78
+ def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
79
+ hidden_states = self.resnets[0](hidden_states, temb)
80
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
81
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
82
+ hidden_states = resnet(hidden_states, temb)
83
+
84
+ return hidden_states
85
+
86
+
87
+ class CrossAttnDownBlock2D(nn.Module):
88
+ def __init__(self,
89
+ in_channels: int,
90
+ out_channels: int,
91
+ temb_channels: int,
92
+ dropout: float = 0.0,
93
+ num_layers: int = 1,
94
+ resnet_eps: float = 1e-6,
95
+ resnet_time_scale_shift: str = "default",
96
+ resnet_act_fn: str = "swish",
97
+ resnet_groups: int = 32,
98
+ resnet_pre_norm: bool = True,
99
+ attn_num_head_channels: int = 1,
100
+ cross_attention_dim: int = 1280,
101
+ attention_type: str = "default",
102
+ output_scale_factor: float = 1.0,
103
+ downsample_padding: int = 1,
104
+ add_downsample: bool = True
105
+ ):
106
+ super().__init__()
107
+ resnets = []
108
+ attentions = []
109
+
110
+ self.attention_type = attention_type
111
+ self.attn_num_head_channels = attn_num_head_channels
112
+
113
+ for i in range(num_layers):
114
+ in_channels = in_channels if i == 0 else out_channels
115
+ resnets.append(
116
+ ResnetBlockPseudo3D(
117
+ in_channels = in_channels,
118
+ out_channels = out_channels,
119
+ temb_channels = temb_channels,
120
+ eps = resnet_eps,
121
+ groups = resnet_groups,
122
+ dropout = dropout,
123
+ time_embedding_norm = resnet_time_scale_shift,
124
+ #non_linearity = resnet_act_fn,
125
+ output_scale_factor = output_scale_factor,
126
+ pre_norm = resnet_pre_norm
127
+ )
128
+ )
129
+ attentions.append(
130
+ TransformerPseudo3DModel(
131
+ in_channels = out_channels,
132
+ num_attention_heads = attn_num_head_channels,
133
+ attention_head_dim = out_channels // attn_num_head_channels,
134
+ num_layers = 1,
135
+ cross_attention_dim = cross_attention_dim,
136
+ norm_num_groups = resnet_groups
137
+ )
138
+ )
139
+ self.attentions = nn.ModuleList(attentions)
140
+ self.resnets = nn.ModuleList(resnets)
141
+
142
+ if add_downsample:
143
+ self.downsamplers = nn.ModuleList(
144
+ [
145
+ Downsample2D(
146
+ out_channels,
147
+ use_conv = True,
148
+ out_channels = out_channels,
149
+ padding = downsample_padding,
150
+ name = "op"
151
+ )
152
+ ]
153
+ )
154
+ else:
155
+ self.downsamplers = None
156
+
157
+ def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
158
+ output_states = ()
159
+
160
+ for resnet, attn in zip(self.resnets, self.attentions):
161
+ hidden_states = resnet(hidden_states, temb)
162
+ hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
163
+
164
+ output_states += (hidden_states,)
165
+
166
+ if self.downsamplers is not None:
167
+ for downsampler in self.downsamplers:
168
+ hidden_states = downsampler(hidden_states)
169
+
170
+ output_states += (hidden_states,)
171
+
172
+ return hidden_states, output_states
173
+
174
+
175
+ class DownBlock2D(nn.Module):
176
+ def __init__(self,
177
+ in_channels: int,
178
+ out_channels: int,
179
+ temb_channels: int,
180
+ dropout: float = 0.0,
181
+ num_layers: int = 1,
182
+ resnet_eps: float = 1e-6,
183
+ resnet_time_scale_shift: str = "default",
184
+ resnet_act_fn: str = "swish",
185
+ resnet_groups: int = 32,
186
+ resnet_pre_norm: bool = True,
187
+ output_scale_factor: float = 1.0,
188
+ add_downsample: bool = True,
189
+ downsample_padding: int = 1
190
+ ) -> None:
191
+ super().__init__()
192
+ resnets = []
193
+
194
+ for i in range(num_layers):
195
+ in_channels = in_channels if i == 0 else out_channels
196
+ resnets.append(
197
+ ResnetBlockPseudo3D(
198
+ in_channels = in_channels,
199
+ out_channels = out_channels,
200
+ temb_channels = temb_channels,
201
+ eps = resnet_eps,
202
+ groups = resnet_groups,
203
+ dropout = dropout,
204
+ time_embedding_norm = resnet_time_scale_shift,
205
+ #non_linearity = resnet_act_fn,
206
+ output_scale_factor = output_scale_factor,
207
+ pre_norm = resnet_pre_norm
208
+ )
209
+ )
210
+
211
+ self.resnets = nn.ModuleList(resnets)
212
+
213
+ if add_downsample:
214
+ self.downsamplers = nn.ModuleList(
215
+ [
216
+ Downsample2D(
217
+ out_channels,
218
+ use_conv = True,
219
+ out_channels = out_channels,
220
+ padding = downsample_padding,
221
+ name = "op"
222
+ )
223
+ ]
224
+ )
225
+ else:
226
+ self.downsamplers = None
227
+
228
+
229
+ def forward(self, hidden_states, temb = None):
230
+ output_states = ()
231
+
232
+ for resnet in self.resnets:
233
+ hidden_states = resnet(hidden_states, temb)
234
+
235
+ output_states += (hidden_states,)
236
+
237
+ if self.downsamplers is not None:
238
+ for downsampler in self.downsamplers:
239
+ hidden_states = downsampler(hidden_states)
240
+
241
+ output_states += (hidden_states,)
242
+
243
+ return hidden_states, output_states
244
+
245
+
246
+ class CrossAttnUpBlock2D(nn.Module):
247
+ def __init__(self,
248
+ in_channels: int,
249
+ out_channels: int,
250
+ prev_output_channel: int,
251
+ temb_channels: int,
252
+ dropout: float = 0.0,
253
+ num_layers: int = 1,
254
+ resnet_eps: float = 1e-6,
255
+ resnet_time_scale_shift: str = "default",
256
+ resnet_act_fn: str = "swish",
257
+ resnet_groups: int = 32,
258
+ resnet_pre_norm: bool = True,
259
+ attn_num_head_channels: int = 1,
260
+ cross_attention_dim: int = 1280,
261
+ attention_type: str = "default",
262
+ output_scale_factor: float = 1.0,
263
+ add_upsample: bool = True
264
+ ) -> None:
265
+ super().__init__()
266
+ resnets = []
267
+ attentions = []
268
+
269
+ self.attention_type = attention_type
270
+ self.attn_num_head_channels = attn_num_head_channels
271
+
272
+ for i in range(num_layers):
273
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
274
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
275
+
276
+ resnets.append(
277
+ ResnetBlockPseudo3D(
278
+ in_channels = resnet_in_channels + res_skip_channels,
279
+ out_channels = out_channels,
280
+ temb_channels = temb_channels,
281
+ eps = resnet_eps,
282
+ groups = resnet_groups,
283
+ dropout = dropout,
284
+ time_embedding_norm = resnet_time_scale_shift,
285
+ #non_linearity = resnet_act_fn,
286
+ output_scale_factor = output_scale_factor,
287
+ pre_norm = resnet_pre_norm
288
+ )
289
+ )
290
+ attentions.append(
291
+ TransformerPseudo3DModel(
292
+ in_channels = out_channels,
293
+ num_attention_heads = attn_num_head_channels,
294
+ attention_head_dim = out_channels // attn_num_head_channels,
295
+ num_layers = 1,
296
+ cross_attention_dim = cross_attention_dim,
297
+ norm_num_groups = resnet_groups
298
+ )
299
+ )
300
+ self.attentions = nn.ModuleList(attentions)
301
+ self.resnets = nn.ModuleList(resnets)
302
+
303
+ if add_upsample:
304
+ self.upsamplers = nn.ModuleList([
305
+ Upsample2D(
306
+ out_channels,
307
+ use_conv = True,
308
+ out_channels = out_channels
309
+ )
310
+ ])
311
+ else:
312
+ self.upsamplers = None
313
+
314
+ def forward(self,
315
+ hidden_states,
316
+ res_hidden_states_tuple,
317
+ temb = None,
318
+ encoder_hidden_states = None,
319
+ upsample_size = None
320
+ ):
321
+ for resnet, attn in zip(self.resnets, self.attentions):
322
+ # pop res hidden states
323
+ res_hidden_states = res_hidden_states_tuple[-1]
324
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
325
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
326
+ hidden_states = resnet(hidden_states, temb)
327
+ hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
328
+
329
+ if self.upsamplers is not None:
330
+ for upsampler in self.upsamplers:
331
+ hidden_states = upsampler(hidden_states, upsample_size)
332
+
333
+ return hidden_states
334
+
335
+
336
+ class UpBlock2D(nn.Module):
337
+ def __init__(self,
338
+ in_channels: int,
339
+ prev_output_channel: int,
340
+ out_channels: int,
341
+ temb_channels: int,
342
+ dropout: float = 0.0,
343
+ num_layers: int = 1,
344
+ resnet_eps: float = 1e-6,
345
+ resnet_time_scale_shift: str = "default",
346
+ resnet_act_fn: str = "swish",
347
+ resnet_groups: int = 32,
348
+ resnet_pre_norm: bool = True,
349
+ output_scale_factor: float = 1.0,
350
+ add_upsample: bool = True
351
+ ) -> None:
352
+ super().__init__()
353
+ resnets = []
354
+
355
+ for i in range(num_layers):
356
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
357
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
358
+
359
+ resnets.append(
360
+ ResnetBlockPseudo3D(
361
+ in_channels = resnet_in_channels + res_skip_channels,
362
+ out_channels = out_channels,
363
+ temb_channels = temb_channels,
364
+ eps = resnet_eps,
365
+ groups = resnet_groups,
366
+ dropout = dropout,
367
+ time_embedding_norm = resnet_time_scale_shift,
368
+ #non_linearity = resnet_act_fn,
369
+ output_scale_factor = output_scale_factor,
370
+ pre_norm = resnet_pre_norm
371
+ )
372
+ )
373
+
374
+ self.resnets = nn.ModuleList(resnets)
375
+
376
+ if add_upsample:
377
+ self.upsamplers = nn.ModuleList([
378
+ Upsample2D(
379
+ out_channels,
380
+ use_conv = True,
381
+ out_channels = out_channels
382
+ )
383
+ ])
384
+ else:
385
+ self.upsamplers = None
386
+
387
+
388
+ def forward(self, hidden_states, res_hidden_states_tuple, temb = None, upsample_size = None):
389
+ for resnet in self.resnets:
390
+ # pop res hidden states
391
+ res_hidden_states = res_hidden_states_tuple[-1]
392
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
393
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
394
+ hidden_states = resnet(hidden_states, temb)
395
+
396
+ if self.upsamplers is not None:
397
+ for upsampler in self.upsamplers:
398
+ hidden_states = upsampler(hidden_states, upsample_size)
399
+
400
+ return hidden_states
401
+
402
+
403
+ def get_down_block(
404
+ down_block_type: str,
405
+ num_layers: int,
406
+ in_channels: int,
407
+ out_channels: int,
408
+ temb_channels: int,
409
+ add_downsample: bool,
410
+ resnet_eps: float,
411
+ resnet_act_fn: str,
412
+ attn_num_head_channels: int,
413
+ resnet_groups: Optional[int] = None,
414
+ cross_attention_dim: Optional[int] = None,
415
+ downsample_padding: Optional[int] = None,
416
+ ) -> Union[DownBlock2D, CrossAttnDownBlock2D]:
417
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
418
+ if down_block_type == "DownBlock2D":
419
+ return DownBlock2D(
420
+ num_layers = num_layers,
421
+ in_channels = in_channels,
422
+ out_channels = out_channels,
423
+ temb_channels = temb_channels,
424
+ add_downsample = add_downsample,
425
+ resnet_eps = resnet_eps,
426
+ resnet_act_fn = resnet_act_fn,
427
+ resnet_groups = resnet_groups,
428
+ downsample_padding = downsample_padding
429
+ )
430
+ elif down_block_type == "CrossAttnDownBlock2D":
431
+ if cross_attention_dim is None:
432
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
433
+ return CrossAttnDownBlock2D(
434
+ num_layers = num_layers,
435
+ in_channels = in_channels,
436
+ out_channels = out_channels,
437
+ temb_channels = temb_channels,
438
+ add_downsample = add_downsample,
439
+ resnet_eps = resnet_eps,
440
+ resnet_act_fn = resnet_act_fn,
441
+ resnet_groups = resnet_groups,
442
+ downsample_padding = downsample_padding,
443
+ cross_attention_dim = cross_attention_dim,
444
+ attn_num_head_channels = attn_num_head_channels
445
+ )
446
+ raise ValueError(f"{down_block_type} does not exist.")
447
+
448
+
449
+ def get_up_block(
450
+ up_block_type: str,
451
+ num_layers,
452
+ in_channels,
453
+ out_channels,
454
+ prev_output_channel,
455
+ temb_channels,
456
+ add_upsample,
457
+ resnet_eps,
458
+ resnet_act_fn,
459
+ attn_num_head_channels,
460
+ resnet_groups = None,
461
+ cross_attention_dim = None,
462
+ ) -> Union[UpBlock2D, CrossAttnUpBlock2D]:
463
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
464
+ if up_block_type == "UpBlock2D":
465
+ return UpBlock2D(
466
+ num_layers = num_layers,
467
+ in_channels = in_channels,
468
+ out_channels = out_channels,
469
+ prev_output_channel = prev_output_channel,
470
+ temb_channels = temb_channels,
471
+ add_upsample = add_upsample,
472
+ resnet_eps = resnet_eps,
473
+ resnet_act_fn = resnet_act_fn,
474
+ resnet_groups = resnet_groups
475
+ )
476
+ elif up_block_type == "CrossAttnUpBlock2D":
477
+ if cross_attention_dim is None:
478
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
479
+ return CrossAttnUpBlock2D(
480
+ num_layers = num_layers,
481
+ in_channels = in_channels,
482
+ out_channels = out_channels,
483
+ prev_output_channel = prev_output_channel,
484
+ temb_channels = temb_channels,
485
+ add_upsample = add_upsample,
486
+ resnet_eps = resnet_eps,
487
+ resnet_act_fn = resnet_act_fn,
488
+ resnet_groups = resnet_groups,
489
+ cross_attention_dim = cross_attention_dim,
490
+ attn_num_head_channels = attn_num_head_channels
491
+ )
492
+ raise ValueError(f"{up_block_type} does not exist.")
493
+
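A small sketch of building one of the blocks above through `get_down_block` and pushing a video tensor through it. Channel counts and group sizes are illustrative, and the flat import path is assumed:

import torch
from torch_unet_pseudo3d_blocks import get_down_block  # assumed flat module path

block = get_down_block(
    "DownBlock2D",
    num_layers = 1,
    in_channels = 32,
    out_channels = 32,
    temb_channels = 128,
    add_downsample = True,
    resnet_eps = 1e-5,
    resnet_act_fn = "silu",
    attn_num_head_channels = 8,
    resnet_groups = 8,
    downsample_padding = 1,
)

video = torch.randn(1, 32, 4, 16, 16)  # [batch, channels, frames, height, width]
temb = torch.randn(1, 128)             # one time embedding per batch element
sample, skip_states = block(video, temb)
# the downsampler halves the spatial resolution, frames are untouched
assert sample.shape == (1, 32, 4, 8, 8)
assert len(skip_states) == 2           # one per resnet plus the downsampled output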
makeavid_sd/torch_impl/torch_unet_pseudo3d_condition.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ from torch_embeddings import TimestepEmbedding, Timesteps
8
+ from torch_unet_pseudo3d_blocks import (
9
+ UNetMidBlock2DCrossAttn,
10
+ get_down_block,
11
+ get_up_block,
12
+ )
13
+
14
+ from torch_resnet_pseudo3d import Pseudo3DConv
15
+
16
+ class UNetPseudo3DConditionOutput:
17
+ sample: torch.FloatTensor
18
+ def __init__(self, sample: torch.FloatTensor) -> None:
19
+ self.sample = sample
20
+
21
+
22
+ class UNetPseudo3DConditionModel(nn.Module):
23
+ def __init__(self,
24
+ sample_size: Optional[int] = None,
25
+ in_channels: int = 9,
26
+ out_channels: int = 4,
27
+ flip_sin_to_cos: bool = True,
28
+ freq_shift: int = 0,
29
+ down_block_types: Tuple[str] = (
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "DownBlock2D",
34
+ ),
35
+ up_block_types: Tuple[str] = (
36
+ "UpBlock2D",
37
+ "CrossAttnUpBlock2D",
38
+ "CrossAttnUpBlock2D",
39
+ "CrossAttnUpBlock2D"
40
+ ),
41
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
42
+ layers_per_block: int = 2,
43
+ downsample_padding: int = 1,
44
+ mid_block_scale_factor: float = 1,
45
+ act_fn: str = "silu",
46
+ norm_num_groups: int = 32,
47
+ norm_eps: float = 1e-5,
48
+ cross_attention_dim: int = 768,
49
+ attention_head_dim: int = 8,
50
+ **kwargs
51
+ ) -> None:
52
+ super().__init__()
53
+ self.dtype = torch.float32
54
+ self.sample_size = sample_size
55
+ time_embed_dim = block_out_channels[0] * 4
56
+
57
+ # input
58
+ self.conv_in = Pseudo3DConv(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
59
+
60
+ # time
61
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
62
+ timestep_input_dim = block_out_channels[0]
63
+
64
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
65
+
66
+ self.down_blocks = nn.ModuleList([])
67
+ self.mid_block = None
68
+ self.up_blocks = nn.ModuleList([])
69
+
70
+ # down
71
+ output_channel = block_out_channels[0]
72
+ for i, down_block_type in enumerate(down_block_types):
73
+ input_channel = output_channel
74
+ output_channel = block_out_channels[i]
75
+ is_final_block = i == len(block_out_channels) - 1
76
+
77
+ down_block = get_down_block(
78
+ down_block_type,
79
+ num_layers = layers_per_block,
80
+ in_channels = input_channel,
81
+ out_channels = output_channel,
82
+ temb_channels = time_embed_dim,
83
+ add_downsample = not is_final_block,
84
+ resnet_eps = norm_eps,
85
+ resnet_act_fn = act_fn,
86
+ resnet_groups = norm_num_groups,
87
+ cross_attention_dim = cross_attention_dim,
88
+ attn_num_head_channels = attention_head_dim,
89
+ downsample_padding = downsample_padding
90
+ )
91
+ self.down_blocks.append(down_block)
92
+
93
+ # mid
94
+ self.mid_block = UNetMidBlock2DCrossAttn(
95
+ in_channels = block_out_channels[-1],
96
+ temb_channels = time_embed_dim,
97
+ resnet_eps = norm_eps,
98
+ resnet_act_fn = act_fn,
99
+ output_scale_factor = mid_block_scale_factor,
100
+ resnet_time_scale_shift = "default",
101
+ cross_attention_dim = cross_attention_dim,
102
+ attn_num_head_channels = attention_head_dim,
103
+ resnet_groups = norm_num_groups
104
+ )
105
+
106
+ # count how many layers upsample the images
107
+ self.num_upsamplers = 0
108
+
109
+ # up
110
+ reversed_block_out_channels = list(reversed(block_out_channels))
111
+ output_channel = reversed_block_out_channels[0]
112
+ for i, up_block_type in enumerate(up_block_types):
113
+ is_final_block = i == len(block_out_channels) - 1
114
+
115
+ prev_output_channel = output_channel
116
+ output_channel = reversed_block_out_channels[i]
117
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
118
+
119
+ # add upsample block for all BUT final layer
120
+ if not is_final_block:
121
+ add_upsample = True
122
+ self.num_upsamplers += 1
123
+ else:
124
+ add_upsample = False
125
+
126
+ up_block = get_up_block(
127
+ up_block_type,
128
+ num_layers = layers_per_block + 1,
129
+ in_channels = input_channel,
130
+ out_channels = output_channel,
131
+ prev_output_channel = prev_output_channel,
132
+ temb_channels = time_embed_dim,
133
+ add_upsample = add_upsample,
134
+ resnet_eps = norm_eps,
135
+ resnet_act_fn = act_fn,
136
+ resnet_groups = norm_num_groups,
137
+ cross_attention_dim = cross_attention_dim,
138
+ attn_num_head_channels = attention_head_dim
139
+ )
140
+ self.up_blocks.append(up_block)
141
+ prev_output_channel = output_channel
142
+
143
+ # out
144
+ self.conv_norm_out = nn.GroupNorm(
145
+ num_channels = block_out_channels[0],
146
+ num_groups = norm_num_groups,
147
+ eps = norm_eps
148
+ )
149
+ self.conv_act = nn.SiLU()
150
+ self.conv_out = Pseudo3DConv(block_out_channels[0], out_channels, 3, padding = 1)
151
+
152
+
153
+ def forward(
154
+ self,
155
+ sample: torch.FloatTensor,
156
+ timesteps: Union[torch.Tensor, float, int],
157
+ encoder_hidden_states: torch.Tensor
158
+ ) -> Union[UNetPseudo3DConditionOutput, Tuple]:
159
+ # By default samples have to be at least a multiple of the overall upsampling factor.
160
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
161
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
162
+ # on the fly if necessary.
163
+ default_overall_up_factor = 2**self.num_upsamplers
164
+
165
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
166
+ forward_upsample_size = False
167
+ upsample_size = None
168
+
169
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
170
+ forward_upsample_size = True
171
+
172
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
173
+ timesteps = timesteps.expand(sample.shape[0])
174
+
175
+ t_emb = self.time_proj(timesteps)
176
+
177
+ # timesteps does not contain any weights and will always return f32 tensors
178
+ # but time_embedding might actually be running in fp16. so we need to cast here.
179
+ # there might be better ways to encapsulate this.
180
+ t_emb = t_emb.to(dtype=self.dtype)
181
+ emb = self.time_embedding(t_emb)
182
+
183
+ # 2. pre-process
184
+ sample = self.conv_in(sample)
185
+
186
+ # 3. down
187
+ down_block_res_samples = (sample,)
188
+ for downsample_block in self.down_blocks:
189
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
190
+ sample, res_samples = downsample_block(
191
+ hidden_states = sample,
192
+ temb = emb,
193
+ encoder_hidden_states = encoder_hidden_states,
194
+ )
195
+ else:
196
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
197
+
198
+ down_block_res_samples += res_samples
199
+
200
+ # 4. mid
201
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
202
+
203
+ # 5. up
204
+ for i, upsample_block in enumerate(self.up_blocks):
205
+ is_final_block = i == len(self.up_blocks) - 1
206
+
207
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
208
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
209
+
210
+ # if we have not reached the final block and need to forward the
211
+ # upsample size, we do it here
212
+ if not is_final_block and forward_upsample_size:
213
+ upsample_size = down_block_res_samples[-1].shape[2:]
214
+
215
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
216
+ sample = upsample_block(
217
+ hidden_states = sample,
218
+ temb = emb,
219
+ res_hidden_states_tuple = res_samples,
220
+ encoder_hidden_states = encoder_hidden_states,
221
+ upsample_size = upsample_size,
222
+ )
223
+ else:
224
+ sample = upsample_block(
225
+ hidden_states = sample,
226
+ temb = emb,
227
+ res_hidden_states_tuple = res_samples,
228
+ upsample_size = upsample_size
229
+ )
230
+ # 6. post-process
231
+ sample = self.conv_norm_out(sample)
232
+ sample = self.conv_act(sample)
233
+ sample = self.conv_out(sample)
234
+
235
+ return UNetPseudo3DConditionOutput(sample = sample)
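The spatial-size bookkeeping in the forward pass above can be illustrated without instantiating the full model. A minimal standalone sketch of the `num_upsamplers` / `forward_upsample_size` logic (values are illustrative):

# with 4 up blocks, all but the final one contain an upsampler
num_upsamplers = 4 - 1
default_overall_up_factor = 2 ** num_upsamplers  # 8

def needs_forced_upsample_size(height: int, width: int) -> bool:
    # mirrors the check in UNetPseudo3DConditionModel.forward:
    # any spatial dim not divisible by the overall factor forces explicit upsample sizes
    return any(s % default_overall_up_factor != 0 for s in (height, width))

assert not needs_forced_upsample_size(64, 64)  # 64 = 8 * 8, no forcing needed
assert needs_forced_upsample_size(68, 64)      # 68 is not a multiple of 8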
packages.txt ADDED
File without changes
pre-requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pip
2
+ setuptools
3
+ wheel
4
+ ninja
5
+ cmake
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pillow
3
+ transformers
4
+ diffusers
5
+ einops
6
+ -f https://download.pytorch.org/whl/cpu/torch
7
+ torch[cpu]
8
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
9
+ jax[cuda11_pip] #jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
10
+ flax