sky24h commited on
Commit
2f3aac0
·
1 Parent(s): 2a2a7e5

init commit

Browse files
.gitattributes CHANGED
@@ -17,6 +17,7 @@
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
 
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
 
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
21
  *.pickle filter=lfs diff=lfs merge=lfs -text
22
  *.pkl filter=lfs diff=lfs merge=lfs -text
23
  *.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: FLATTEN Unofficial
3
- emoji: 😻
4
  colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: FLATTEN
3
+ emoji: 📽️
4
  colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -1,7 +1,57 @@
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
  import gradio as gr
3
+ from inference_utils import inference
4
 
 
 
5
 
6
+ @spaces.GPU(duration=120)
7
+ def send_to_model(source_video, prompt, neg_prompt, guidance_scale, video_length, old_qk):
8
+ return inference(prompt=prompt, neg_prompt=neg_prompt, guidance_scale=guidance_scale, video_length=video_length, video_path=source_video, old_qk=old_qk)
9
+
10
+
11
+ if __name__ == "__main__":
12
+ with gr.Blocks() as demo:
13
+ gr.HTML(
14
+ """
15
+ <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
16
+ FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing
17
+ </h1>
18
+ <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
19
+ <a style="text-align: center; display:inline-block"
20
+ href="https://flatten-video-editing.github.io/">
21
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
22
+ alt="Paper Page">
23
+ </a>
24
+ <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/FLATTEN-unofficial?duplicate=true">
25
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
26
+ </a>
27
+ </p>
28
+ """
29
+ )
30
+ gr.Interface(
31
+ fn=send_to_model,
32
+ inputs=[
33
+ gr.Video(value=None, label="Source Image"),
34
+ gr.Textbox(value="", label="Prompt"),
35
+ gr.Textbox(value="", label="Negative Prompt"),
36
+ gr.Slider(
37
+ value = 15,
38
+ minimum = 10,
39
+ maximum = 30,
40
+ step = 1,
41
+ label = "guidance_scale",
42
+ info = "The scale of the guidance field.",
43
+ ),
44
+ gr.Textbox(value=16, label="Video Length", info="The length of the video, must be less than 16 frames in the online demo to avoid timeout. However, you can run the model locally to process longer videos."),
45
+ gr.Dropdown(value=0, choices=[0, 1], label="Choose Option", info="Select 0 or 1."),
46
+ ],
47
+ outputs=[gr.Video(label="output", autoplay=True)],
48
+ allow_flagging="never",
49
+ description="This is an unofficial demo for the paper 'FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing'.",
50
+ examples=[
51
+ ["./data/puff.mp4", "A Tiger, high quality", "a cat with big eyes, deformed", 20, 16, 0],
52
+ ["./data/background.mp4", "pointillism painting, detailed", "", 25, 16, 1],
53
+ ["./data/trucks-race.mp4", "Wooden trucks drive on a racetrack.", "", 15, 16, 1],
54
+ ],
55
+ cache_examples=True,
56
+ )
57
+ demo.queue(max_size=10).launch()
checkpoints/unet/config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.10.0.dev0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": [
6
+ 5,
7
+ 10,
8
+ 20,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "cross_attention_dim": 1024,
19
+ "down_block_types": [
20
+ "CrossAttnDownBlock2D",
21
+ "CrossAttnDownBlock2D",
22
+ "CrossAttnDownBlock2D",
23
+ "DownBlock2D"
24
+ ],
25
+ "downsample_padding": 1,
26
+ "dual_cross_attention": false,
27
+ "flip_sin_to_cos": true,
28
+ "freq_shift": 0,
29
+ "in_channels": 4,
30
+ "layers_per_block": 2,
31
+ "mid_block_scale_factor": 1,
32
+ "norm_eps": 1e-05,
33
+ "norm_num_groups": 32,
34
+ "num_class_embeds": null,
35
+ "only_cross_attention": false,
36
+ "out_channels": 4,
37
+ "sample_size": 64,
38
+ "up_block_types": [
39
+ "UpBlock2D",
40
+ "CrossAttnUpBlock2D",
41
+ "CrossAttnUpBlock2D",
42
+ "CrossAttnUpBlock2D"
43
+ ],
44
+ "use_linear_projection": true
45
+ }
checkpoints/unet/diffusion_pytorch_model.fp16.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da861650fc7df96390abffcbb9bcf67b7c5566422fda5af9bc003605be65c5f3
3
+ size 1732107093
data/background.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a17a3627a373b49a377e27c2efc33da739e6114ee919cbf016305683a47dacf
3
+ size 207581
data/puff.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a730d21ba80ec72d8fd19cbf55b19372210f97e2a3005a3d424bec0c81c1c8e4
3
+ size 203311
data/trucks-race.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ee14a4157cb65432b147e2c3c63d07c49d9dc79ba3d1ee696d068d133b9beba
3
+ size 1232961
gradio_cached_examples/19/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ output,flag,username,timestamp
2
+ "{""video"": {""path"": ""gradio_cached_examples/19/output/d89cb02e15074188f8f6/A-Tiger-high-quality_a-cat-with-big-eyes-deformed_20_1727813741.3332028.mp4"", ""url"": ""/file=/tmp/gradio/17794950654f68456a056c68191682fee58f63e5/A-Tiger-high-quality_a-cat-with-big-eyes-deformed_20_1727813741.3332028.mp4"", ""size"": null, ""orig_name"": ""A-Tiger,-high-quality_a-cat-with-big-eyes,-deformed_20_1727813741.3332028.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-10-01 20:15:41.884688
3
+ "{""video"": {""path"": ""gradio_cached_examples/19/output/d9fa9e82120d2c0f1786/pointillism-painting-detailed__25_1727813866.2198303.mp4"", ""url"": ""/file=/tmp/gradio/164e0b9a788d053c6e89e46accc931aa4ecc036d/pointillism-painting-detailed__25_1727813866.2198303.mp4"", ""size"": null, ""orig_name"": ""pointillism-painting,-detailed__25_1727813866.2198303.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-10-01 20:17:46.857907
4
+ "{""video"": {""path"": ""gradio_cached_examples/19/output/2be6ecf324aed22bf7b6/Wooden-trucks-drive-on-a-racetrack.__15_1727813992.8895943.mp4"", ""url"": ""/file=/tmp/gradio/76a9de37ebb4bb2027a17530fced1febf5100b0b/Wooden-trucks-drive-on-a-racetrack.__15_1727813992.8895943.mp4"", ""size"": null, ""orig_name"": ""Wooden-trucks-drive-on-a-racetrack.__15_1727813992.8895943.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-10-01 20:19:53.549070
gradio_cached_examples/19/output/2be6ecf324aed22bf7b6/Wooden-trucks-drive-on-a-racetrack.__15_1727813992.8895943.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc7f6ff21e460e91e5de667278b64b83bc706007b412a8d263709a30e3be8898
3
+ size 92054
gradio_cached_examples/19/output/d89cb02e15074188f8f6/A-Tiger-high-quality_a-cat-with-big-eyes-deformed_20_1727813741.3332028.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3746fd58911320cabcb611c7e311b21bdee0273ea5e9801a715cfd400428d37
3
+ size 273838
gradio_cached_examples/19/output/d9fa9e82120d2c0f1786/pointillism-painting-detailed__25_1727813866.2198303.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0395774a685b6e042d183c1d016771be4a79f1a1782a510e49099267ce24ebe1
3
+ size 701932
inference_utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import torch
5
+ import imageio
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ from diffusers import DDIMScheduler, AutoencoderKL, DDIMInverseScheduler
11
+
12
+ from models.pipeline_flatten import FlattenPipeline
13
+ from models.util import sample_trajectories
14
+ from models.unet import UNet3DConditionModel
15
+
16
+
17
+ def init_pipeline(device):
18
+ dtype = torch.float16
19
+ sd_path = "stabilityai/stable-diffusion-2-1-base"
20
+ UNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints", "unet")
21
+ unet = UNet3DConditionModel.from_pretrained_2d(UNET_PATH, dtype=torch.float16)
22
+ # unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=torch.float16)
23
+
24
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=torch.float16)
25
+ tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", dtype=dtype)
26
+ text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
27
+ scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
28
+ inverse = DDIMInverseScheduler.from_pretrained(sd_path, subfolder="scheduler")
29
+
30
+ pipe = FlattenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, inverse_scheduler=inverse)
31
+ pipe.enable_vae_slicing()
32
+ pipe.enable_xformers_memory_efficient_attention()
33
+ pipe.to(device)
34
+ return pipe
35
+
36
+
37
+ height = 512
38
+ width = 512
39
+ sample_steps = 50
40
+ inject_step = 40
41
+
42
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
43
+ pipe = init_pipeline(device)
44
+
45
+
46
+ def inference(
47
+ seed = 66,
48
+ prompt = None,
49
+ neg_prompt = "",
50
+ guidance_scale = 10.0,
51
+ video_length = 16,
52
+ video_path = None,
53
+ output_dir = None,
54
+ frame_rate = 1,
55
+ fps = 15,
56
+ old_qk = 0,
57
+ ):
58
+ generator = torch.Generator(device=device)
59
+ generator.manual_seed(seed)
60
+
61
+ # read the source video
62
+ video_reader = imageio.get_reader(video_path, "ffmpeg")
63
+ video = []
64
+ for frame in video_reader:
65
+ if len(video) >= video_length:
66
+ break
67
+ video.append(cv2.resize(frame, (width, height))) # .transpose(2, 0, 1))
68
+ real_frames = [Image.fromarray(frame) for frame in video]
69
+
70
+ # compute optical flows and sample trajectories
71
+ trajectories = sample_trajectories(torch.tensor(np.array(video)).permute(0, 3, 1, 2), device)
72
+ torch.cuda.empty_cache()
73
+
74
+ for k in trajectories.keys():
75
+ trajectories[k] = trajectories[k].to(device)
76
+ sample = (pipe(
77
+ prompt,
78
+ video_length = video_length,
79
+ frames = real_frames,
80
+ num_inference_steps = sample_steps,
81
+ generator = generator,
82
+ guidance_scale = guidance_scale,
83
+ negative_prompt = neg_prompt,
84
+ width = width,
85
+ height = height,
86
+ trajs = trajectories,
87
+ output_dir = "tmp/",
88
+ inject_step = inject_step,
89
+ old_qk = old_qk,
90
+ ).videos[0].permute(1, 2, 3, 0).cpu().numpy() * 255).astype(np.uint8)
91
+ temp_video_name = f"/tmp/{prompt}_{neg_prompt}_{str(guidance_scale)}_{time.time()}.mp4".replace(" ", "-")
92
+ video_writer = imageio.get_writer(temp_video_name, fps=fps)
93
+ for frame in sample:
94
+ video_writer.append_data(frame)
95
+ print(f"Saving video to {temp_video_name}, sample shape: {sample.shape}")
96
+ return temp_video_name
97
+
98
+
99
+ if __name__ == "__main__":
100
+ video_path = "./data/puff.mp4"
101
+ inference(
102
+ video_path = video_path,
103
+ prompt = "A Tiger, high quality",
104
+ neg_prompt = "a cat with big eyes, deformed",
105
+ guidance_scale = 20,
106
+ old_qk = 0,
107
+ )
models/__init__.py ADDED
File without changes
models/attention.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Callable
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers import ModelMixin
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
14
+ from diffusers.models.cross_attention import CrossAttention
15
+ from einops import rearrange, repeat
16
+
17
+ @dataclass
18
+ class Transformer3DModelOutput(BaseOutput):
19
+ sample: torch.FloatTensor
20
+
21
+
22
+ if is_xformers_available():
23
+ import xformers
24
+ import xformers.ops
25
+ else:
26
+ xformers = None
27
+
28
+
29
+ class Transformer3DModel(ModelMixin, ConfigMixin):
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ ):
47
+ super().__init__()
48
+ self.use_linear_projection = use_linear_projection
49
+ self.num_attention_heads = num_attention_heads
50
+ self.attention_head_dim = attention_head_dim
51
+ inner_dim = num_attention_heads * attention_head_dim
52
+
53
+ # Define input layers
54
+ self.in_channels = in_channels
55
+
56
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
57
+ if use_linear_projection:
58
+ self.proj_in = nn.Linear(in_channels, inner_dim)
59
+ else:
60
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
61
+
62
+ # Define transformers blocks
63
+ self.transformer_blocks = nn.ModuleList(
64
+ [
65
+ BasicTransformerBlock(
66
+ inner_dim,
67
+ num_attention_heads,
68
+ attention_head_dim,
69
+ dropout=dropout,
70
+ cross_attention_dim=cross_attention_dim,
71
+ activation_fn=activation_fn,
72
+ num_embeds_ada_norm=num_embeds_ada_norm,
73
+ attention_bias=attention_bias,
74
+ only_cross_attention=only_cross_attention,
75
+ upcast_attention=upcast_attention,
76
+ )
77
+ for d in range(num_layers)
78
+ ]
79
+ )
80
+
81
+ # 4. Define output layers
82
+ if use_linear_projection:
83
+ self.proj_out = nn.Linear(in_channels, inner_dim)
84
+ else:
85
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
86
+
87
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, \
88
+ inter_frame=False, **kwargs):
89
+ # Input
90
+
91
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
92
+ video_length = hidden_states.shape[2]
93
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
94
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
95
+
96
+ batch, channel, height, weight = hidden_states.shape
97
+ residual = hidden_states
98
+
99
+ # check resolution
100
+ resolu = hidden_states.shape[-1]
101
+ trajs = {}
102
+ trajs["traj"] = kwargs["trajs"]["traj{}".format(resolu)]
103
+ trajs["mask"] = kwargs["trajs"]["mask{}".format(resolu)]
104
+ trajs["t"] = kwargs["t"]
105
+ trajs["old_qk"] = kwargs["old_qk"]
106
+
107
+ hidden_states = self.norm(hidden_states)
108
+ if not self.use_linear_projection:
109
+ hidden_states = self.proj_in(hidden_states)
110
+ inner_dim = hidden_states.shape[1]
111
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
112
+ else:
113
+ inner_dim = hidden_states.shape[1]
114
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
115
+ hidden_states = self.proj_in(hidden_states)
116
+
117
+ # Blocks
118
+ for block in self.transformer_blocks:
119
+ hidden_states = block(
120
+ hidden_states,
121
+ encoder_hidden_states=encoder_hidden_states,
122
+ timestep=timestep,
123
+ video_length=video_length,
124
+ inter_frame=inter_frame,
125
+ **trajs
126
+ )
127
+
128
+ # Output
129
+ if not self.use_linear_projection:
130
+ hidden_states = (
131
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
132
+ )
133
+ hidden_states = self.proj_out(hidden_states)
134
+ else:
135
+ hidden_states = self.proj_out(hidden_states)
136
+ hidden_states = (
137
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
138
+ )
139
+
140
+ output = hidden_states + residual
141
+
142
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
143
+ if not return_dict:
144
+ return (output,)
145
+
146
+ return Transformer3DModelOutput(sample=output)
147
+
148
+
149
+ class BasicTransformerBlock(nn.Module):
150
+ def __init__(
151
+ self,
152
+ dim: int,
153
+ num_attention_heads: int,
154
+ attention_head_dim: int,
155
+ dropout=0.0,
156
+ cross_attention_dim: Optional[int] = None,
157
+ activation_fn: str = "geglu",
158
+ num_embeds_ada_norm: Optional[int] = None,
159
+ attention_bias: bool = False,
160
+ only_cross_attention: bool = False,
161
+ upcast_attention: bool = False,
162
+ ):
163
+ super().__init__()
164
+ self.only_cross_attention = only_cross_attention
165
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
166
+
167
+ # Fully
168
+ self.attn1 = FullyFrameAttention(
169
+ query_dim=dim,
170
+ heads=num_attention_heads,
171
+ dim_head=attention_head_dim,
172
+ dropout=dropout,
173
+ bias=attention_bias,
174
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
175
+ upcast_attention=upcast_attention,
176
+ )
177
+
178
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
179
+
180
+ # Cross-Attn
181
+ if cross_attention_dim is not None:
182
+ self.attn2 = CrossAttention(
183
+ query_dim=dim,
184
+ cross_attention_dim=cross_attention_dim,
185
+ heads=num_attention_heads,
186
+ dim_head=attention_head_dim,
187
+ dropout=dropout,
188
+ bias=attention_bias,
189
+ upcast_attention=upcast_attention,
190
+ )
191
+ else:
192
+ self.attn2 = None
193
+
194
+ if cross_attention_dim is not None:
195
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
196
+ else:
197
+ self.norm2 = None
198
+
199
+ # Feed-forward
200
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
201
+ self.norm3 = nn.LayerNorm(dim)
202
+
203
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
204
+ if not is_xformers_available():
205
+ print("Here is how to install it")
206
+ raise ModuleNotFoundError(
207
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
208
+ " xformers",
209
+ name="xformers",
210
+ )
211
+ elif not torch.cuda.is_available():
212
+ raise ValueError(
213
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
214
+ " available for GPU "
215
+ )
216
+ else:
217
+ try:
218
+ # Make sure we can run the memory efficient attention
219
+ _ = xformers.ops.memory_efficient_attention(
220
+ torch.randn((1, 2, 40), device="cuda"),
221
+ torch.randn((1, 2, 40), device="cuda"),
222
+ torch.randn((1, 2, 40), device="cuda"),
223
+ )
224
+ except Exception as e:
225
+ raise e
226
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
227
+ if self.attn2 is not None:
228
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
229
+
230
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, \
231
+ inter_frame=False, **kwargs):
232
+ # SparseCausal-Attention
233
+ norm_hidden_states = (
234
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
235
+ )
236
+
237
+ if self.only_cross_attention:
238
+ hidden_states = (
239
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask, inter_frame=inter_frame, **kwargs) + hidden_states
240
+ )
241
+ else:
242
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length, inter_frame=inter_frame, **kwargs) + hidden_states
243
+
244
+ if self.attn2 is not None:
245
+ # Cross-Attention
246
+ norm_hidden_states = (
247
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
248
+ )
249
+ hidden_states = (
250
+ self.attn2(
251
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ # Feed-forward
257
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
258
+
259
+ return hidden_states
260
+
261
+ class FullyFrameAttention(nn.Module):
262
+ r"""
263
+ A cross attention layer.
264
+
265
+ Parameters:
266
+ query_dim (`int`): The number of channels in the query.
267
+ cross_attention_dim (`int`, *optional*):
268
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
269
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
270
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
271
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
272
+ bias (`bool`, *optional*, defaults to False):
273
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ query_dim: int,
279
+ cross_attention_dim: Optional[int] = None,
280
+ heads: int = 8,
281
+ dim_head: int = 64,
282
+ dropout: float = 0.0,
283
+ bias=False,
284
+ upcast_attention: bool = False,
285
+ upcast_softmax: bool = False,
286
+ added_kv_proj_dim: Optional[int] = None,
287
+ norm_num_groups: Optional[int] = None,
288
+ ):
289
+ super().__init__()
290
+ inner_dim = dim_head * heads
291
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
292
+ self.upcast_attention = upcast_attention
293
+ self.upcast_softmax = upcast_softmax
294
+
295
+ self.scale = dim_head**-0.5
296
+
297
+ self.heads = heads
298
+ # for slice_size > 0 the attention score computation
299
+ # is split across the batch axis to save memory
300
+ # You can set slice_size with `set_attention_slice`
301
+ self.sliceable_head_dim = heads
302
+ self._slice_size = None
303
+ self._use_memory_efficient_attention_xformers = False
304
+ self.added_kv_proj_dim = added_kv_proj_dim
305
+
306
+ if norm_num_groups is not None:
307
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
308
+ else:
309
+ self.group_norm = None
310
+
311
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
312
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
313
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
314
+
315
+ if self.added_kv_proj_dim is not None:
316
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
317
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
318
+
319
+ self.to_out = nn.ModuleList([])
320
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
321
+ self.to_out.append(nn.Dropout(dropout))
322
+
323
+ self.q = None
324
+ self.inject_q = None
325
+ self.k = None
326
+ self.inject_k = None
327
+
328
+
329
+ def reshape_heads_to_batch_dim(self, tensor):
330
+ batch_size, seq_len, dim = tensor.shape
331
+ head_size = self.heads
332
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
333
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
334
+ return tensor
335
+
336
+ def reshape_heads_to_batch_dim2(self, tensor):
337
+ batch_size, seq_len, dim = tensor.shape
338
+ head_size = self.heads
339
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
340
+ tensor = tensor.permute(0, 2, 1, 3)
341
+ return tensor
342
+
343
+ def reshape_heads_to_batch_dim3(self, tensor):
344
+ batch_size1, batch_size2, seq_len, dim = tensor.shape
345
+ head_size = self.heads
346
+ tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size)
347
+ tensor = tensor.permute(0, 3, 1, 2, 4)
348
+ return tensor
349
+
350
+ def reshape_batch_dim_to_heads(self, tensor):
351
+ batch_size, seq_len, dim = tensor.shape
352
+ head_size = self.heads
353
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
354
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
355
+ return tensor
356
+
357
+ def set_attention_slice(self, slice_size):
358
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
359
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
360
+
361
+ self._slice_size = slice_size
362
+
363
+ def _attention(self, query, key, value, attention_mask=None):
364
+ if self.upcast_attention:
365
+ query = query.float()
366
+ key = key.float()
367
+
368
+ attention_scores = torch.baddbmm(
369
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
370
+ query,
371
+ key.transpose(-1, -2),
372
+ beta=0,
373
+ alpha=self.scale,
374
+ )
375
+ if attention_mask is not None:
376
+ attention_scores = attention_scores + attention_mask
377
+
378
+ if self.upcast_softmax:
379
+ attention_scores = attention_scores.float()
380
+
381
+ attention_probs = attention_scores.softmax(dim=-1)
382
+
383
+ # cast back to the original dtype
384
+ attention_probs = attention_probs.to(value.dtype)
385
+
386
+ # compute attention output
387
+ hidden_states = torch.bmm(attention_probs, value)
388
+
389
+ # reshape hidden_states
390
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
391
+ return hidden_states
392
+
393
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
394
+ batch_size_attention = query.shape[0]
395
+ hidden_states = torch.zeros(
396
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
397
+ )
398
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
399
+ for i in range(hidden_states.shape[0] // slice_size):
400
+ start_idx = i * slice_size
401
+ end_idx = (i + 1) * slice_size
402
+
403
+ query_slice = query[start_idx:end_idx]
404
+ key_slice = key[start_idx:end_idx]
405
+
406
+ if self.upcast_attention:
407
+ query_slice = query_slice.float()
408
+ key_slice = key_slice.float()
409
+
410
+ attn_slice = torch.baddbmm(
411
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
412
+ query_slice,
413
+ key_slice.transpose(-1, -2),
414
+ beta=0,
415
+ alpha=self.scale,
416
+ )
417
+
418
+ if attention_mask is not None:
419
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
420
+
421
+ if self.upcast_softmax:
422
+ attn_slice = attn_slice.float()
423
+
424
+ attn_slice = attn_slice.softmax(dim=-1)
425
+
426
+ # cast back to the original dtype
427
+ attn_slice = attn_slice.to(value.dtype)
428
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
429
+
430
+ hidden_states[start_idx:end_idx] = attn_slice
431
+
432
+ # reshape hidden_states
433
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
434
+ return hidden_states
435
+
436
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
437
+ # TODO attention_mask
438
+ query = query.contiguous()
439
+ key = key.contiguous()
440
+ value = value.contiguous()
441
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
442
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
443
+ return hidden_states
444
+
445
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, inter_frame=False, **kwargs):
446
+ batch_size, sequence_length, _ = hidden_states.shape
447
+
448
+ encoder_hidden_states = encoder_hidden_states
449
+
450
+ h = w = int(math.sqrt(sequence_length))
451
+ if self.group_norm is not None:
452
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
453
+
454
+ query = self.to_q(hidden_states) # (bf) x d(hw) x c
455
+ self.q = query
456
+ if self.inject_q is not None:
457
+ query = self.inject_q
458
+ dim = query.shape[-1]
459
+ query_old = query.clone()
460
+
461
+ # All frames
462
+ query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
463
+
464
+ query = self.reshape_heads_to_batch_dim(query)
465
+
466
+ if self.added_kv_proj_dim is not None:
467
+ raise NotImplementedError
468
+
469
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
470
+ key = self.to_k(encoder_hidden_states)
471
+ self.k = key
472
+ if self.inject_k is not None:
473
+ key = self.inject_k
474
+ key_old = key.clone()
475
+ value = self.to_v(encoder_hidden_states)
476
+
477
+ if inter_frame:
478
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
479
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
480
+ key = rearrange(key, "b f d c -> b (f d) c",)
481
+ value = rearrange(value, "b f d c -> b (f d) c")
482
+ else:
483
+ # All frames
484
+ key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
485
+ value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length)
486
+
487
+ key = self.reshape_heads_to_batch_dim(key)
488
+ value = self.reshape_heads_to_batch_dim(value)
489
+
490
+ if attention_mask is not None:
491
+ if attention_mask.shape[-1] != query.shape[1]:
492
+ target_length = query.shape[1]
493
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
494
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
495
+
496
+ # attention, what we cannot get enough of
497
+ if self._use_memory_efficient_attention_xformers:
498
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
499
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
500
+ hidden_states = hidden_states.to(query.dtype)
501
+ else:
502
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
503
+ hidden_states = self._attention(query, key, value, attention_mask)
504
+ else:
505
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
506
+
507
+ if h in [64]:
508
+ hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
509
+ if self.group_norm is not None:
510
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
511
+
512
+ if kwargs["old_qk"] == 1:
513
+ query = query_old
514
+ key = key_old
515
+ else:
516
+ query = hidden_states
517
+ key = hidden_states
518
+ value = hidden_states
519
+
520
+ traj = kwargs["traj"]
521
+ traj = rearrange(traj, '(f n) l d -> f n l d', f=video_length, n=sequence_length)
522
+ mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=video_length, n=sequence_length)
523
+ mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -video_length+1:]], dim=-1)
524
+
525
+ traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -video_length+1:, :]], dim=-2)
526
+ t_inds = traj_key_sequence_inds[:, :, :, 0]
527
+ x_inds = traj_key_sequence_inds[:, :, :, 1]
528
+ y_inds = traj_key_sequence_inds[:, :, :, 2]
529
+
530
+ query_tempo = query.unsqueeze(-2)
531
+ _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w)
532
+ _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w)
533
+ key_tempo = _key[:, t_inds, x_inds, y_inds]
534
+ value_tempo = _value[:, t_inds, x_inds, y_inds]
535
+ key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d')
536
+ value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d')
537
+
538
+ mask = rearrange(torch.stack([mask, mask]), 'b f n l -> (b f) n l')
539
+ mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2)
540
+ attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype) # regular zeros_like
541
+ attn_bias[~mask] = -torch.inf
542
+
543
+ # flow attention
544
+ query_tempo = self.reshape_heads_to_batch_dim3(query_tempo)
545
+ key_tempo = self.reshape_heads_to_batch_dim3(key_tempo)
546
+ value_tempo = self.reshape_heads_to_batch_dim3(value_tempo)
547
+
548
+ attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias
549
+ attn_matrix2 = F.softmax(attn_matrix2, dim=-1)
550
+ out = (attn_matrix2@value_tempo).squeeze(-2)
551
+
552
+ hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/video_length), f=video_length, h=h, w=w)
553
+
554
+ # linear proj
555
+ hidden_states = self.to_out[0](hidden_states)
556
+
557
+ # dropout
558
+ hidden_states = self.to_out[1](hidden_states)
559
+
560
+ # All frames
561
+ hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
562
+
563
+ return hidden_states
models/pipeline_flatten.py ADDED
@@ -0,0 +1,879 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ import os
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+ from dataclasses import dataclass
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ from transformers import CLIPTextModel, CLIPTokenizer
25
+
26
+ from diffusers.models import AutoencoderKL
27
+ from diffusers import ModelMixin
28
+ from diffusers.schedulers import DDIMScheduler, DDIMInverseScheduler
29
+ from diffusers.utils import (
30
+ PIL_INTERPOLATION,
31
+ is_accelerate_available,
32
+ is_accelerate_version,
33
+ logging,
34
+ randn_tensor,
35
+ BaseOutput
36
+ )
37
+ from diffusers.pipeline_utils import DiffusionPipeline
38
+ from einops import rearrange
39
+ from .unet import UNet3DConditionModel
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class FlattenPipelineOutput(BaseOutput):
46
+ videos: Union[torch.Tensor, np.ndarray]
47
+
48
+ class FlattenPipeline(DiffusionPipeline):
49
+ r"""
50
+ pipeline for FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing.
51
+
52
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
53
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
54
+
55
+ Args:
56
+ vae ([`AutoencoderKL`]):
57
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
58
+ text_encoder ([`CLIPTextModel`]):
59
+ Frozen text-encoder. Stable Diffusion uses the text portion of
60
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
61
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
62
+ tokenizer (`CLIPTokenizer`):
63
+ Tokenizer of class
64
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
65
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
66
+ scheduler ([`SchedulerMixin`]):
67
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
68
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
69
+ inverse_scheduler ([`SchedulerMixin`]):
70
+ DDIM inversion scheduler .
71
+ """
72
+ _optional_components = ["safety_checker", "feature_extractor"]
73
+
74
+ def __init__(
75
+ self,
76
+ vae: AutoencoderKL,
77
+ text_encoder: CLIPTextModel,
78
+ tokenizer: CLIPTokenizer,
79
+ unet: UNet3DConditionModel,
80
+ scheduler: DDIMScheduler,
81
+ inverse_scheduler: DDIMInverseScheduler
82
+ ):
83
+ super().__init__()
84
+
85
+ self.register_modules(
86
+ vae=vae,
87
+ text_encoder=text_encoder,
88
+ tokenizer=tokenizer,
89
+ unet=unet,
90
+ scheduler=scheduler,
91
+ inverse_scheduler=inverse_scheduler
92
+ )
93
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
94
+
95
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
96
+ def enable_vae_slicing(self):
97
+ r"""
98
+ Enable sliced VAE decoding.
99
+
100
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
101
+ steps. This is useful to save some memory and allow larger batch sizes.
102
+ """
103
+ self.vae.enable_slicing()
104
+
105
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
106
+ def disable_vae_slicing(self):
107
+ r"""
108
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
109
+ computing decoding in one step.
110
+ """
111
+ self.vae.disable_slicing()
112
+
113
+ def enable_sequential_cpu_offload(self, gpu_id=0):
114
+ r"""
115
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
116
+ text_encoder, vae, and safety checker have their state dicts saved to CPU and then are moved to a
117
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
118
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
119
+ `enable_model_cpu_offload`, but performance is lower.
120
+ """
121
+ if is_accelerate_available():
122
+ from accelerate import cpu_offload
123
+ else:
124
+ raise ImportError("Please install accelerate via `pip install accelerate`")
125
+
126
+ device = torch.device(f"cuda:{gpu_id}")
127
+
128
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
129
+ cpu_offload(cpu_offloaded_model, device)
130
+
131
+ if self.safety_checker is not None:
132
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
133
+
134
+ def enable_model_cpu_offload(self, gpu_id=0):
135
+ r"""
136
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
137
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
138
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
139
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
140
+ """
141
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
142
+ from accelerate import cpu_offload_with_hook
143
+ else:
144
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
145
+
146
+ device = torch.device(f"cuda:{gpu_id}")
147
+
148
+ hook = None
149
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
150
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
151
+
152
+ if self.safety_checker is not None:
153
+ # the safety checker can offload the vae again
154
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
155
+
156
+ # We'll offload the last model manually.
157
+ self.final_offload_hook = hook
158
+
159
+ @property
160
+ def _execution_device(self):
161
+ r"""
162
+ Returns the device on which the pipeline's models will be executed. After calling
163
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
164
+ hooks.
165
+ """
166
+ if not hasattr(self.unet, "_hf_hook"):
167
+ return self.device
168
+ for module in self.unet.modules():
169
+ if (
170
+ hasattr(module, "_hf_hook")
171
+ and hasattr(module._hf_hook, "execution_device")
172
+ and module._hf_hook.execution_device is not None
173
+ ):
174
+ return torch.device(module._hf_hook.execution_device)
175
+ return self.device
176
+
177
+ def _encode_prompt(
178
+ self,
179
+ prompt,
180
+ device,
181
+ num_videos_per_prompt,
182
+ do_classifier_free_guidance,
183
+ negative_prompt=None,
184
+ prompt_embeds: Optional[torch.FloatTensor] = None,
185
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
186
+ ):
187
+ r"""
188
+ Encodes the prompt into text encoder hidden states.
189
+
190
+ Args:
191
+ prompt (`str` or `List[str]`, *optional*):
192
+ prompt to be encoded
193
+ device: (`torch.device`):
194
+ torch device
195
+ num_videos_per_prompt (`int`):
196
+ number of images that should be generated per prompt
197
+ do_classifier_free_guidance (`bool`):
198
+ whether to use classifier free guidance or not
199
+ negative_prompt (`str` or `List[str]`, *optional*):
200
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
201
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
202
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
203
+ prompt_embeds (`torch.FloatTensor`, *optional*):
204
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
205
+ provided, text embeddings will be generated from `prompt` input argument.
206
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
207
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
208
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
209
+ argument.
210
+ """
211
+ if prompt is not None and isinstance(prompt, str):
212
+ batch_size = 1
213
+ elif prompt is not None and isinstance(prompt, list):
214
+ batch_size = len(prompt)
215
+ else:
216
+ batch_size = prompt_embeds.shape[0]
217
+
218
+ if prompt_embeds is None:
219
+ text_inputs = self.tokenizer(
220
+ prompt,
221
+ padding="max_length",
222
+ max_length=self.tokenizer.model_max_length,
223
+ truncation=True,
224
+ return_tensors="pt",
225
+ )
226
+ text_input_ids = text_inputs.input_ids
227
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
228
+
229
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
230
+ text_input_ids, untruncated_ids
231
+ ):
232
+ removed_text = self.tokenizer.batch_decode(
233
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
234
+ )
235
+ logger.warning(
236
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
237
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
238
+ )
239
+
240
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
241
+ attention_mask = text_inputs.attention_mask.to(device)
242
+ else:
243
+ attention_mask = None
244
+
245
+ prompt_embeds = self.text_encoder(
246
+ text_input_ids.to(device),
247
+ attention_mask=attention_mask,
248
+ )
249
+ prompt_embeds = prompt_embeds[0]
250
+
251
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
252
+
253
+ bs_embed, seq_len, _ = prompt_embeds.shape
254
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
255
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
256
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
257
+
258
+ # get unconditional embeddings for classifier free guidance
259
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
260
+ uncond_tokens: List[str]
261
+ if negative_prompt is None:
262
+ uncond_tokens = [""] * batch_size
263
+ elif type(prompt) is not type(negative_prompt):
264
+ raise TypeError(
265
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
266
+ f" {type(prompt)}."
267
+ )
268
+ elif isinstance(negative_prompt, str):
269
+ uncond_tokens = [negative_prompt]
270
+ elif batch_size != len(negative_prompt):
271
+ raise ValueError(
272
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
273
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
274
+ " the batch size of `prompt`."
275
+ )
276
+ else:
277
+ uncond_tokens = negative_prompt
278
+
279
+ max_length = prompt_embeds.shape[1]
280
+ uncond_input = self.tokenizer(
281
+ uncond_tokens,
282
+ padding="max_length",
283
+ max_length=max_length,
284
+ truncation=True,
285
+ return_tensors="pt",
286
+ )
287
+
288
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
289
+ attention_mask = uncond_input.attention_mask.to(device)
290
+ else:
291
+ attention_mask = None
292
+
293
+ negative_prompt_embeds = self.text_encoder(
294
+ uncond_input.input_ids.to(device),
295
+ attention_mask=attention_mask,
296
+ )
297
+ negative_prompt_embeds = negative_prompt_embeds[0]
298
+
299
+ if do_classifier_free_guidance:
300
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
301
+ seq_len = negative_prompt_embeds.shape[1]
302
+
303
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
304
+
305
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
306
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
307
+
308
+ # For classifier free guidance, we need to do two forward passes.
309
+ # Here we concatenate the unconditional and text embeddings into a single batch
310
+ # to avoid doing two forward passes
311
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
312
+
313
+ return prompt_embeds
314
+
315
+ def decode_latents(self, latents, return_tensor=False):
316
+ video_length = latents.shape[2]
317
+ latents = 1 / 0.18215 * latents
318
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
319
+ video = self.vae.decode(latents).sample
320
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
321
+ video = (video / 2 + 0.5).clamp(0, 1)
322
+ if return_tensor:
323
+ return video
324
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
325
+ video = video.cpu().float().numpy()
326
+ return video
327
+
328
+ def prepare_extra_step_kwargs(self, generator, eta):
329
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
330
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
331
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
332
+ # and should be between [0, 1]
333
+
334
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
335
+ extra_step_kwargs = {}
336
+ if accepts_eta:
337
+ extra_step_kwargs["eta"] = eta
338
+
339
+ # check if the scheduler accepts generator
340
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
341
+ if accepts_generator:
342
+ extra_step_kwargs["generator"] = generator
343
+ return extra_step_kwargs
344
+
345
+ def check_inputs(
346
+ self,
347
+ prompt,
348
+ # image,
349
+ height,
350
+ width,
351
+ callback_steps,
352
+ negative_prompt=None,
353
+ prompt_embeds=None,
354
+ negative_prompt_embeds=None,
355
+ ):
356
+ if height % 8 != 0 or width % 8 != 0:
357
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
358
+
359
+ if (callback_steps is None) or (
360
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
361
+ ):
362
+ raise ValueError(
363
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
364
+ f" {type(callback_steps)}."
365
+ )
366
+
367
+ if prompt is not None and prompt_embeds is not None:
368
+ raise ValueError(
369
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
370
+ " only forward one of the two."
371
+ )
372
+ elif prompt is None and prompt_embeds is None:
373
+ raise ValueError(
374
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
375
+ )
376
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
377
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
378
+
379
+ if negative_prompt is not None and negative_prompt_embeds is not None:
380
+ raise ValueError(
381
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
382
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
383
+ )
384
+
385
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
386
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
387
+ raise ValueError(
388
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
389
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
390
+ f" {negative_prompt_embeds.shape}."
391
+ )
392
+
393
+
394
+ def check_image(self, image, prompt, prompt_embeds):
395
+ image_is_pil = isinstance(image, PIL.Image.Image)
396
+ image_is_tensor = isinstance(image, torch.Tensor)
397
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
398
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
399
+
400
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
401
+ raise TypeError(
402
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
403
+ )
404
+
405
+ if image_is_pil:
406
+ image_batch_size = 1
407
+ elif image_is_tensor:
408
+ image_batch_size = image.shape[0]
409
+ elif image_is_pil_list:
410
+ image_batch_size = len(image)
411
+ elif image_is_tensor_list:
412
+ image_batch_size = len(image)
413
+
414
+ if prompt is not None and isinstance(prompt, str):
415
+ prompt_batch_size = 1
416
+ elif prompt is not None and isinstance(prompt, list):
417
+ prompt_batch_size = len(prompt)
418
+ elif prompt_embeds is not None:
419
+ prompt_batch_size = prompt_embeds.shape[0]
420
+
421
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
422
+ raise ValueError(
423
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
424
+ )
425
+
426
+ def prepare_image(
427
+ self, image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance
428
+ ):
429
+ if not isinstance(image, torch.Tensor):
430
+ if isinstance(image, PIL.Image.Image):
431
+ image = [image]
432
+
433
+ if isinstance(image[0], PIL.Image.Image):
434
+ images = []
435
+
436
+ for image_ in image:
437
+ image_ = image_.convert("RGB")
438
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
439
+ image_ = np.array(image_)
440
+ image_ = image_[None, :]
441
+ images.append(image_)
442
+
443
+ image = images
444
+
445
+ image = np.concatenate(image, axis=0)
446
+ image = np.array(image).astype(np.float32) / 255.0
447
+ image = image.transpose(0, 3, 1, 2)
448
+ image = 2.0 * image - 1.0
449
+ image = torch.from_numpy(image)
450
+ elif isinstance(image[0], torch.Tensor):
451
+ image = torch.cat(image, dim=0)
452
+
453
+ image_batch_size = image.shape[0]
454
+
455
+ if image_batch_size == 1:
456
+ repeat_by = batch_size
457
+ else:
458
+ # image batch size is the same as prompt batch size
459
+ repeat_by = num_videos_per_prompt
460
+
461
+ image = image.repeat_interleave(repeat_by, dim=0)
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+
465
+ return image
466
+
467
+ def prepare_video_latents(self, frames, batch_size, dtype, device, generator=None):
468
+ if not isinstance(frames, (torch.Tensor, PIL.Image.Image, list)):
469
+ raise ValueError(
470
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
471
+ )
472
+
473
+ frames = frames[0].to(device=device, dtype=dtype)
474
+ frames = rearrange(frames, "c f h w -> f c h w" )
475
+
476
+ if isinstance(generator, list) and len(generator) != batch_size:
477
+ raise ValueError(
478
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
479
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
480
+ )
481
+
482
+ if isinstance(generator, list):
483
+ latents = [
484
+ self.vae.encode(frames[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
485
+ ]
486
+ latents = torch.cat(latents, dim=0)
487
+ else:
488
+ latents = self.vae.encode(frames).latent_dist.sample(generator)
489
+
490
+ latents = self.vae.config.scaling_factor * latents
491
+
492
+ latents = rearrange(latents, "f c h w ->c f h w" )
493
+
494
+ return latents[None]
495
+
496
+ def _default_height_width(self, height, width, image):
497
+ # NOTE: It is possible that a list of images have different
498
+ # dimensions for each image, so just checking the first image
499
+ # is not _exactly_ correct, but it is simple.
500
+ while isinstance(image, list):
501
+ image = image[0]
502
+
503
+ if height is None:
504
+ if isinstance(image, PIL.Image.Image):
505
+ height = image.height
506
+ elif isinstance(image, torch.Tensor):
507
+ height = image.shape[3]
508
+
509
+ height = (height // 8) * 8 # round down to nearest multiple of 8
510
+
511
+ if width is None:
512
+ if isinstance(image, PIL.Image.Image):
513
+ width = image.width
514
+ elif isinstance(image, torch.Tensor):
515
+ width = image.shape[2]
516
+
517
+ width = (width // 8) * 8 # round down to nearest multiple of 8
518
+
519
+ return height, width
520
+
521
+ def get_alpha_prev(self, timestep):
522
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
523
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
524
+ return alpha_prod_t_prev
525
+
526
+ def get_slide_window_indices(self, video_length, window_size):
527
+ assert window_size >=3
528
+ key_frame_indices = np.arange(0, video_length, window_size-1).tolist()
529
+
530
+ # Append last index
531
+ if key_frame_indices[-1] != (video_length-1):
532
+ key_frame_indices.append(video_length-1)
533
+
534
+ slices = np.split(np.arange(video_length), key_frame_indices)
535
+ inter_frame_list = []
536
+ for s in slices:
537
+ if len(s) < 2:
538
+ continue
539
+ inter_frame_list.append(s[1:].tolist())
540
+ return key_frame_indices, inter_frame_list
541
+
542
+ def get_inverse_timesteps(self, num_inference_steps, strength, device):
543
+ # get the original timestep using init_timestep
544
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
545
+
546
+ t_start = max(num_inference_steps - init_timestep, 0)
547
+
548
+ # safety for t_start overflow to prevent empty timsteps slice
549
+ if t_start == 0:
550
+ return self.inverse_scheduler.timesteps, num_inference_steps
551
+ timesteps = self.inverse_scheduler.timesteps[:-t_start]
552
+
553
+ return timesteps, num_inference_steps - t_start
554
+
555
+ def clean_features(self):
556
+ self.unet.up_blocks[1].resnets[0].out_layers_inject_features = None
557
+ self.unet.up_blocks[1].resnets[1].out_layers_inject_features = None
558
+ self.unet.up_blocks[2].resnets[0].out_layers_inject_features = None
559
+ self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = None
560
+ self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = None
561
+ self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = None
562
+ self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = None
563
+ self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = None
564
+ self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = None
565
+ self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = None
566
+ self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = None
567
+ self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = None
568
+ self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = None
569
+ self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = None
570
+ self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = None
571
+
572
+ @torch.no_grad()
573
+ def __call__(
574
+ self,
575
+ prompt: Union[str, List[str]] = None,
576
+ video_length: Optional[int] = 1,
577
+ frames: Union[List[torch.FloatTensor], List[PIL.Image.Image], List[List[torch.FloatTensor]], List[List[PIL.Image.Image]]] = None,
578
+ height: Optional[int] = None,
579
+ width: Optional[int] = None,
580
+ num_inference_steps: int = 50,
581
+ guidance_scale: float = 7.5,
582
+ negative_prompt: Optional[Union[str, List[str]]] = None,
583
+ num_videos_per_prompt: Optional[int] = 1,
584
+ eta: float = 0.0,
585
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
586
+ latents: Optional[torch.FloatTensor] = None,
587
+ prompt_embeds: Optional[torch.FloatTensor] = None,
588
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
589
+ output_type: Optional[str] = "tensor",
590
+ return_dict: bool = True,
591
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
592
+ callback_steps: int = 1,
593
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
594
+ **kwargs,
595
+ ):
596
+ r"""
597
+ Function invoked when calling the pipeline for generation.
598
+
599
+ Args:
600
+ prompt (`str` or `List[str]`, *optional*):
601
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
602
+ instead.
603
+ frames (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
604
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
605
+ The original video frames to be edited.
606
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
607
+ The height in pixels of the generated image.
608
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
609
+ The width in pixels of the generated image.
610
+ num_inference_steps (`int`, *optional*, defaults to 50):
611
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
612
+ expense of slower inference.
613
+ guidance_scale (`float`, *optional*, defaults to 7.5):
614
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
615
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
616
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
617
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
618
+ usually at the expense of lower image quality.
619
+ negative_prompt (`str` or `List[str]`, *optional*):
620
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
621
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
622
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
623
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
624
+ The number of images to generate per prompt.
625
+ eta (`float`, *optional*, defaults to 0.0):
626
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
627
+ [`schedulers.DDIMScheduler`], will be ignored for others.
628
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
629
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
630
+ to make generation deterministic.
631
+ latents (`torch.FloatTensor`, *optional*):
632
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
633
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
634
+ tensor will ge generated by sampling using the supplied random `generator`.
635
+ prompt_embeds (`torch.FloatTensor`, *optional*):
636
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
637
+ provided, text embeddings will be generated from `prompt` input argument.
638
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
639
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
640
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
641
+ argument.
642
+ output_type (`str`, *optional*, defaults to `"pil"`):
643
+ The output format of the generate image. Choose between
644
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
645
+ return_dict (`bool`, *optional*, defaults to `True`):
646
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
647
+ plain tuple.
648
+ callback (`Callable`, *optional*):
649
+ A function that will be called every `callback_steps` steps during inference. The function will be
650
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
651
+ callback_steps (`int`, *optional*, defaults to 1):
652
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
653
+ called at every step.
654
+ cross_attention_kwargs (`dict`, *optional*):
655
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
656
+ `self.processor` in
657
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
658
+ """
659
+ height, width = self._default_height_width(height, width, frames)
660
+
661
+ self.check_inputs(
662
+ prompt,
663
+ height,
664
+ width,
665
+ callback_steps,
666
+ negative_prompt,
667
+ prompt_embeds,
668
+ negative_prompt_embeds,
669
+ )
670
+
671
+ if prompt is not None and isinstance(prompt, str):
672
+ batch_size = 1
673
+ elif prompt is not None and isinstance(prompt, list):
674
+ batch_size = len(prompt)
675
+ else:
676
+ batch_size = prompt_embeds.shape[0]
677
+
678
+ device = self._execution_device
679
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
680
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
681
+ # corresponds to doing no classifier free guidance.
682
+ do_classifier_free_guidance = guidance_scale > 1.0
683
+
684
+ # encode empty prompt
685
+ prompt_embeds = self._encode_prompt(
686
+ "",
687
+ device,
688
+ num_videos_per_prompt,
689
+ do_classifier_free_guidance=do_classifier_free_guidance,
690
+ negative_prompt=None,
691
+ prompt_embeds=prompt_embeds,
692
+ negative_prompt_embeds=negative_prompt_embeds,
693
+ )
694
+
695
+ images = []
696
+ for i_img in frames:
697
+ i_img = self.prepare_image(
698
+ image=i_img,
699
+ width=width,
700
+ height=height,
701
+ batch_size=batch_size * num_videos_per_prompt,
702
+ num_videos_per_prompt=num_videos_per_prompt,
703
+ device=device,
704
+ dtype=self.unet.dtype,
705
+ do_classifier_free_guidance=do_classifier_free_guidance,
706
+ )
707
+ images.append(i_img)
708
+ frames = torch.stack(images, dim=2) # b x c x f x h x w
709
+
710
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
711
+
712
+ latents = self.prepare_video_latents(frames, batch_size, self.unet.dtype, device, generator=generator)
713
+
714
+ saved_features0 = []
715
+ saved_features1 = []
716
+ saved_features2 = []
717
+ saved_q4 = []
718
+ saved_k4 = []
719
+ saved_q5 = []
720
+ saved_k5 = []
721
+ saved_q6 = []
722
+ saved_k6 = []
723
+ saved_q7 = []
724
+ saved_k7 = []
725
+ saved_q8 = []
726
+ saved_k8 = []
727
+ saved_q9 = []
728
+ saved_k9 = []
729
+
730
+ # ddim inverse
731
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
732
+ timesteps = self.scheduler.timesteps
733
+
734
+ num_inverse_steps = 100
735
+ self.inverse_scheduler.set_timesteps(num_inverse_steps, device=device)
736
+ inverse_timesteps, num_inverse_steps = self.get_inverse_timesteps(num_inverse_steps, 1, device)
737
+ num_warmup_steps = len(inverse_timesteps) - num_inverse_steps * self.inverse_scheduler.order
738
+
739
+ with self.progress_bar(total=num_inverse_steps-1) as progress_bar:
740
+ for i, t in enumerate(inverse_timesteps[1:]):
741
+ # expand the latents if we are doing classifier free guidance
742
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
743
+ latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
744
+
745
+ noise_pred = self.unet(
746
+ latent_model_input,
747
+ t,
748
+ encoder_hidden_states=prompt_embeds,
749
+ cross_attention_kwargs=cross_attention_kwargs,
750
+ **kwargs,
751
+ ).sample
752
+
753
+ if t in timesteps:
754
+ saved_features0.append(self.unet.up_blocks[1].resnets[0].out_layers_features.cpu())
755
+ saved_features1.append(self.unet.up_blocks[1].resnets[1].out_layers_features.cpu())
756
+ saved_features2.append(self.unet.up_blocks[2].resnets[0].out_layers_features.cpu())
757
+ saved_q4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.q.cpu())
758
+ saved_k4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.k.cpu())
759
+ saved_q5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.q.cpu())
760
+ saved_k5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.k.cpu())
761
+ saved_q6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.q.cpu())
762
+ saved_k6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.k.cpu())
763
+ saved_q7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.q.cpu())
764
+ saved_k7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.k.cpu())
765
+ saved_q8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.q.cpu())
766
+ saved_k8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.k.cpu())
767
+ saved_q9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.q.cpu())
768
+ saved_k9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.k.cpu())
769
+
770
+
771
+ if do_classifier_free_guidance:
772
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
773
+ noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_uncond)
774
+
775
+ # compute the previous noisy sample x_t -> x_t-1
776
+ latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
777
+ if i == len(inverse_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0):
778
+ progress_bar.update()
779
+
780
+ saved_features0.reverse()
781
+ saved_features1.reverse()
782
+ saved_features2.reverse()
783
+ saved_q4.reverse()
784
+ saved_k4.reverse()
785
+ saved_q5.reverse()
786
+ saved_k5.reverse()
787
+ saved_q6.reverse()
788
+ saved_k6.reverse()
789
+ saved_q7.reverse()
790
+ saved_k7.reverse()
791
+ saved_q8.reverse()
792
+ saved_k8.reverse()
793
+ saved_q9.reverse()
794
+ saved_k9.reverse()
795
+
796
+ # video sampling
797
+ prompt_embeds = self._encode_prompt(
798
+ prompt,
799
+ device,
800
+ num_videos_per_prompt,
801
+ do_classifier_free_guidance,
802
+ negative_prompt,
803
+ prompt_embeds=None,
804
+ negative_prompt_embeds=negative_prompt_embeds,
805
+ )
806
+
807
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
808
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
809
+ for i, t in enumerate(timesteps):
810
+ torch.cuda.empty_cache()
811
+
812
+ # expand the latents if we are doing classifier free guidance
813
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
814
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
815
+
816
+ # inject features
817
+ if i < kwargs["inject_step"]:
818
+ self.unet.up_blocks[1].resnets[0].out_layers_inject_features = saved_features0[i].to(device)
819
+ self.unet.up_blocks[1].resnets[1].out_layers_inject_features = saved_features1[i].to(device)
820
+ self.unet.up_blocks[2].resnets[0].out_layers_inject_features = saved_features2[i].to(device)
821
+ self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q4[i].to(device)
822
+ self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k4[i].to(device)
823
+ self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q5[i].to(device)
824
+ self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k5[i].to(device)
825
+ self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q6[i].to(device)
826
+ self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k6[i].to(device)
827
+ self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q7[i].to(device)
828
+ self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k7[i].to(device)
829
+ self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q8[i].to(device)
830
+ self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k8[i].to(device)
831
+ self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q9[i].to(device)
832
+ self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k9[i].to(device)
833
+ else:
834
+ self.clean_features()
835
+
836
+ noise_pred = self.unet(
837
+ latent_model_input,
838
+ t,
839
+ encoder_hidden_states=prompt_embeds,
840
+ cross_attention_kwargs=cross_attention_kwargs,
841
+ **kwargs,
842
+ ).sample
843
+
844
+ self.clean_features()
845
+
846
+ # perform guidance
847
+ if do_classifier_free_guidance:
848
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
849
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
850
+
851
+ # compute the previous noisy sample x_t -> x_t-1
852
+ step_dict = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
853
+ latents = step_dict.prev_sample
854
+
855
+ # call the callback, if provided
856
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
857
+ progress_bar.update()
858
+ if callback is not None and i % callback_steps == 0:
859
+ callback(i, t, latents)
860
+
861
+ # If we do sequential model offloading, let's offload unet
862
+ # manually for max memory savings
863
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
864
+ self.unet.to("cpu")
865
+ torch.cuda.empty_cache()
866
+ # Post-processing
867
+ video = self.decode_latents(latents)
868
+
869
+ # Convert to tensor
870
+ if output_type == "tensor":
871
+ video = torch.from_numpy(video)
872
+
873
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
874
+ self.final_offload_hook.offload()
875
+
876
+ if not return_dict:
877
+ return video
878
+
879
+ return FlattenPipelineOutput(videos=video)
models/resnet.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+ class TemporalConv1d(nn.Conv1d):
21
+ def forward(self, x):
22
+ b, c, f, h, w = x.shape
23
+ y = rearrange(x.clone(), "b c f h w -> (b h w) c f")
24
+ y = super().forward(y)
25
+ y = rearrange(y, "(b h w) c f -> b c f h w", b=b, h=h, w=w)
26
+ return y
27
+
28
+
29
+ class Upsample3D(nn.Module):
30
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
31
+ super().__init__()
32
+ self.channels = channels
33
+ self.out_channels = out_channels or channels
34
+ self.use_conv = use_conv
35
+ self.use_conv_transpose = use_conv_transpose
36
+ self.name = name
37
+
38
+ conv = None
39
+ if use_conv_transpose:
40
+ raise NotImplementedError
41
+ elif use_conv:
42
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
43
+
44
+ if name == "conv":
45
+ self.conv = conv
46
+ else:
47
+ self.Conv2d_0 = conv
48
+
49
+ def forward(self, hidden_states, output_size=None):
50
+ assert hidden_states.shape[1] == self.channels
51
+
52
+ if self.use_conv_transpose:
53
+ raise NotImplementedError
54
+
55
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
56
+ dtype = hidden_states.dtype
57
+ if dtype == torch.bfloat16:
58
+ hidden_states = hidden_states.to(torch.float32)
59
+
60
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
61
+ if hidden_states.shape[0] >= 64:
62
+ hidden_states = hidden_states.contiguous()
63
+
64
+ # if `output_size` is passed we force the interpolation output
65
+ # size and do not make use of `scale_factor=2`
66
+ if output_size is None:
67
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
68
+ else:
69
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
70
+
71
+ # If the input is bfloat16, we cast back to bfloat16
72
+ if dtype == torch.bfloat16:
73
+ hidden_states = hidden_states.to(dtype)
74
+
75
+ if self.use_conv:
76
+ if self.name == "conv":
77
+ hidden_states = self.conv(hidden_states)
78
+ else:
79
+ hidden_states = self.Conv2d_0(hidden_states)
80
+
81
+ return hidden_states
82
+
83
+
84
+ class Downsample3D(nn.Module):
85
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.out_channels = out_channels or channels
89
+ self.use_conv = use_conv
90
+ self.padding = padding
91
+ stride = 2
92
+ self.name = name
93
+
94
+ if use_conv:
95
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ if name == "conv":
100
+ self.Conv2d_0 = conv
101
+ self.conv = conv
102
+ elif name == "Conv2d_0":
103
+ self.conv = conv
104
+ else:
105
+ self.conv = conv
106
+
107
+ def forward(self, hidden_states):
108
+ assert hidden_states.shape[1] == self.channels
109
+ if self.use_conv and self.padding == 0:
110
+ raise NotImplementedError
111
+
112
+ assert hidden_states.shape[1] == self.channels
113
+ hidden_states = self.conv(hidden_states)
114
+
115
+ return hidden_states
116
+
117
+
118
+ class ResnetBlock3D(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ in_channels,
123
+ out_channels=None,
124
+ conv_shortcut=False,
125
+ dropout=0.0,
126
+ temb_channels=512,
127
+ groups=32,
128
+ groups_out=None,
129
+ pre_norm=True,
130
+ eps=1e-6,
131
+ non_linearity="swish",
132
+ time_embedding_norm="default",
133
+ output_scale_factor=1.0,
134
+ use_in_shortcut=None,
135
+ ):
136
+ super().__init__()
137
+ self.pre_norm = pre_norm
138
+ self.pre_norm = True
139
+ self.in_channels = in_channels
140
+ out_channels = in_channels if out_channels is None else out_channels
141
+ self.out_channels = out_channels
142
+ self.use_conv_shortcut = conv_shortcut
143
+ self.time_embedding_norm = time_embedding_norm
144
+ self.output_scale_factor = output_scale_factor
145
+
146
+ if groups_out is None:
147
+ groups_out = groups
148
+
149
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
150
+
151
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
152
+
153
+ if temb_channels is not None:
154
+ if self.time_embedding_norm == "default":
155
+ time_emb_proj_out_channels = out_channels
156
+ elif self.time_embedding_norm == "scale_shift":
157
+ time_emb_proj_out_channels = out_channels * 2
158
+ else:
159
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
160
+
161
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
162
+ else:
163
+ self.time_emb_proj = None
164
+
165
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
166
+ self.dropout = torch.nn.Dropout(dropout)
167
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
168
+
169
+ if non_linearity == "swish":
170
+ self.nonlinearity = lambda x: F.silu(x)
171
+ elif non_linearity == "mish":
172
+ self.nonlinearity = Mish()
173
+ elif non_linearity == "silu":
174
+ self.nonlinearity = nn.SiLU()
175
+
176
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
177
+
178
+ self.conv_shortcut = None
179
+ if self.use_in_shortcut:
180
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
181
+
182
+ # save features
183
+ self.out_layers_features = None
184
+ self.out_layers_inject_features = None
185
+
186
+ def forward(self, input_tensor, temb):
187
+ hidden_states = input_tensor
188
+
189
+ hidden_states = self.norm1(hidden_states)
190
+ hidden_states = self.nonlinearity(hidden_states)
191
+
192
+ hidden_states = self.conv1(hidden_states)
193
+
194
+ if temb is not None:
195
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
196
+
197
+ if temb is not None and self.time_embedding_norm == "default":
198
+ hidden_states = hidden_states + temb
199
+
200
+ hidden_states = self.norm2(hidden_states)
201
+
202
+ if temb is not None and self.time_embedding_norm == "scale_shift":
203
+ scale, shift = torch.chunk(temb, 2, dim=1)
204
+ hidden_states = hidden_states * (1 + scale) + shift
205
+
206
+ hidden_states = self.nonlinearity(hidden_states)
207
+
208
+ hidden_states = self.dropout(hidden_states)
209
+ hidden_states = self.conv2(hidden_states)
210
+
211
+ if self.conv_shortcut is not None:
212
+ input_tensor = self.conv_shortcut(input_tensor)
213
+
214
+ # save features
215
+ self.out_layers_features = hidden_states
216
+ if self.out_layers_inject_features is not None:
217
+ hidden_states = self.out_layers_inject_features
218
+
219
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
220
+
221
+ return output_tensor
222
+
223
+
224
+ class Mish(torch.nn.Module):
225
+ def forward(self, hidden_states):
226
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
models/unet.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ @dataclass
33
+ class UNet3DConditionOutput(BaseOutput):
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
38
+ _supports_gradient_checkpointing = True
39
+
40
+ @register_to_config
41
+ def __init__(
42
+ self,
43
+ sample_size: Optional[int] = None,
44
+ in_channels: int = 4,
45
+ out_channels: int = 4,
46
+ center_input_sample: bool = False,
47
+ flip_sin_to_cos: bool = True,
48
+ freq_shift: int = 0,
49
+ down_block_types: Tuple[str] = (
50
+ "CrossAttnDownBlock3D",
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "DownBlock3D",
54
+ ),
55
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
56
+ up_block_types: Tuple[str] = (
57
+ "UpBlock3D",
58
+ "CrossAttnUpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D"
61
+ ),
62
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
63
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64
+ layers_per_block: int = 2,
65
+ downsample_padding: int = 1,
66
+ mid_block_scale_factor: float = 1,
67
+ act_fn: str = "silu",
68
+ norm_num_groups: int = 32,
69
+ norm_eps: float = 1e-5,
70
+ cross_attention_dim: int = 1280,
71
+ attention_head_dim: Union[int, Tuple[int]] = 8,
72
+ dual_cross_attention: bool = False,
73
+ use_linear_projection: bool = False,
74
+ class_embed_type: Optional[str] = None,
75
+ num_class_embeds: Optional[int] = None,
76
+ upcast_attention: bool = False,
77
+ resnet_time_scale_shift: str = "default",
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_size = sample_size
82
+ time_embed_dim = block_out_channels[0] * 4
83
+
84
+ # input
85
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
+
87
+ # time
88
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89
+ timestep_input_dim = block_out_channels[0]
90
+
91
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92
+
93
+ # class embedding
94
+ if class_embed_type is None and num_class_embeds is not None:
95
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96
+ elif class_embed_type == "timestep":
97
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98
+ elif class_embed_type == "identity":
99
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100
+ else:
101
+ self.class_embedding = None
102
+
103
+ self.down_blocks = nn.ModuleList([])
104
+ self.mid_block = None
105
+ self.up_blocks = nn.ModuleList([])
106
+
107
+ if isinstance(only_cross_attention, bool):
108
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
109
+
110
+ if isinstance(attention_head_dim, int):
111
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
112
+
113
+ # down
114
+ output_channel = block_out_channels[0]
115
+ for i, down_block_type in enumerate(down_block_types):
116
+ input_channel = output_channel
117
+ output_channel = block_out_channels[i]
118
+ is_final_block = i == len(block_out_channels) - 1
119
+
120
+ down_block = get_down_block(
121
+ down_block_type,
122
+ num_layers=layers_per_block,
123
+ in_channels=input_channel,
124
+ out_channels=output_channel,
125
+ temb_channels=time_embed_dim,
126
+ add_downsample=not is_final_block,
127
+ resnet_eps=norm_eps,
128
+ resnet_act_fn=act_fn,
129
+ resnet_groups=norm_num_groups,
130
+ cross_attention_dim=cross_attention_dim,
131
+ attn_num_head_channels=attention_head_dim[i],
132
+ downsample_padding=downsample_padding,
133
+ dual_cross_attention=dual_cross_attention,
134
+ use_linear_projection=use_linear_projection,
135
+ only_cross_attention=only_cross_attention[i],
136
+ upcast_attention=upcast_attention,
137
+ resnet_time_scale_shift=resnet_time_scale_shift,
138
+ )
139
+ self.down_blocks.append(down_block)
140
+
141
+ # mid
142
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
143
+ self.mid_block = UNetMidBlock3DCrossAttn(
144
+ in_channels=block_out_channels[-1],
145
+ temb_channels=time_embed_dim,
146
+ resnet_eps=norm_eps,
147
+ resnet_act_fn=act_fn,
148
+ output_scale_factor=mid_block_scale_factor,
149
+ resnet_time_scale_shift=resnet_time_scale_shift,
150
+ cross_attention_dim=cross_attention_dim,
151
+ attn_num_head_channels=attention_head_dim[-1],
152
+ resnet_groups=norm_num_groups,
153
+ dual_cross_attention=dual_cross_attention,
154
+ use_linear_projection=use_linear_projection,
155
+ upcast_attention=upcast_attention,
156
+ )
157
+ else:
158
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
159
+
160
+ # count how many layers upsample the videos
161
+ self.num_upsamplers = 0
162
+
163
+ # up
164
+ reversed_block_out_channels = list(reversed(block_out_channels))
165
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
166
+ only_cross_attention = list(reversed(only_cross_attention))
167
+ output_channel = reversed_block_out_channels[0]
168
+ for i, up_block_type in enumerate(up_block_types):
169
+ is_final_block = i == len(block_out_channels) - 1
170
+
171
+ prev_output_channel = output_channel
172
+ output_channel = reversed_block_out_channels[i]
173
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
174
+
175
+ # add upsample block for all BUT final layer
176
+ if not is_final_block:
177
+ add_upsample = True
178
+ self.num_upsamplers += 1
179
+ else:
180
+ add_upsample = False
181
+
182
+ up_block = get_up_block(
183
+ up_block_type,
184
+ num_layers=layers_per_block + 1,
185
+ in_channels=input_channel,
186
+ out_channels=output_channel,
187
+ prev_output_channel=prev_output_channel,
188
+ temb_channels=time_embed_dim,
189
+ add_upsample=add_upsample,
190
+ resnet_eps=norm_eps,
191
+ resnet_act_fn=act_fn,
192
+ resnet_groups=norm_num_groups,
193
+ cross_attention_dim=cross_attention_dim,
194
+ attn_num_head_channels=reversed_attention_head_dim[i],
195
+ dual_cross_attention=dual_cross_attention,
196
+ use_linear_projection=use_linear_projection,
197
+ only_cross_attention=only_cross_attention[i],
198
+ upcast_attention=upcast_attention,
199
+ resnet_time_scale_shift=resnet_time_scale_shift,
200
+ )
201
+ self.up_blocks.append(up_block)
202
+ prev_output_channel = output_channel
203
+
204
+ # out
205
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
206
+ self.conv_act = nn.SiLU()
207
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
208
+
209
+ def set_attention_slice(self, slice_size):
210
+ r"""
211
+ Enable sliced attention computation.
212
+
213
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
214
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
215
+
216
+ Args:
217
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
218
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
219
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
220
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
221
+ must be a multiple of `slice_size`.
222
+ """
223
+ sliceable_head_dims = []
224
+
225
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
226
+ if hasattr(module, "set_attention_slice"):
227
+ sliceable_head_dims.append(module.sliceable_head_dim)
228
+
229
+ for child in module.children():
230
+ fn_recursive_retrieve_slicable_dims(child)
231
+
232
+ # retrieve number of attention layers
233
+ for module in self.children():
234
+ fn_recursive_retrieve_slicable_dims(module)
235
+
236
+ num_slicable_layers = len(sliceable_head_dims)
237
+
238
+ if slice_size == "auto":
239
+ # half the attention head size is usually a good trade-off between
240
+ # speed and memory
241
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
242
+ elif slice_size == "max":
243
+ # make smallest slice possible
244
+ slice_size = num_slicable_layers * [1]
245
+
246
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
247
+
248
+ if len(slice_size) != len(sliceable_head_dims):
249
+ raise ValueError(
250
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
251
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
252
+ )
253
+
254
+ for i in range(len(slice_size)):
255
+ size = slice_size[i]
256
+ dim = sliceable_head_dims[i]
257
+ if size is not None and size > dim:
258
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
259
+
260
+ # Recursively walk through all the children.
261
+ # Any children which exposes the set_attention_slice method
262
+ # gets the message
263
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
264
+ if hasattr(module, "set_attention_slice"):
265
+ module.set_attention_slice(slice_size.pop())
266
+
267
+ for child in module.children():
268
+ fn_recursive_set_attention_slice(child, slice_size)
269
+
270
+ reversed_slice_size = list(reversed(slice_size))
271
+ for module in self.children():
272
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
273
+
274
+ def _set_gradient_checkpointing(self, module, value=False):
275
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
276
+ module.gradient_checkpointing = value
277
+
278
+ def forward(
279
+ self,
280
+ sample: torch.FloatTensor,
281
+ timestep: Union[torch.Tensor, float, int],
282
+ encoder_hidden_states: torch.Tensor,
283
+ class_labels: Optional[torch.Tensor] = None,
284
+ attention_mask: Optional[torch.Tensor] = None,
285
+ return_dict: bool = True,
286
+ cross_attention_kwargs = None,
287
+ inter_frame = False,
288
+ **kwargs,
289
+ ) -> Union[UNet3DConditionOutput, Tuple]:
290
+ r"""
291
+ Args:
292
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
293
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
294
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
295
+ return_dict (`bool`, *optional*, defaults to `True`):
296
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
297
+
298
+ Returns:
299
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
300
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
301
+ returning a tuple, the first element is the sample tensor.
302
+ """
303
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
304
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
305
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
306
+ # on the fly if necessary.
307
+ default_overall_up_factor = 2**self.num_upsamplers
308
+ kwargs["t"] = timestep
309
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
310
+ forward_upsample_size = False
311
+ upsample_size = None
312
+
313
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
314
+ logger.info("Forward upsample size to force interpolation output size.")
315
+ forward_upsample_size = True
316
+
317
+ # prepare attention_mask
318
+ if attention_mask is not None:
319
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
320
+ attention_mask = attention_mask.unsqueeze(1)
321
+
322
+ # center input if necessary
323
+ if self.config.center_input_sample:
324
+ sample = 2 * sample - 1.0
325
+
326
+ # time
327
+ timesteps = timestep
328
+ if not torch.is_tensor(timesteps):
329
+ # This would be a good case for the `match` statement (Python 3.10+)
330
+ is_mps = sample.device.type == "mps"
331
+ if isinstance(timestep, float):
332
+ dtype = torch.float32 if is_mps else torch.float64
333
+ else:
334
+ dtype = torch.int32 if is_mps else torch.int64
335
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
336
+ elif len(timesteps.shape) == 0:
337
+ timesteps = timesteps[None].to(sample.device)
338
+
339
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
340
+ timesteps = timesteps.expand(sample.shape[0])
341
+
342
+ t_emb = self.time_proj(timesteps)
343
+
344
+ # timesteps does not contain any weights and will always return f32 tensors
345
+ # but time_embedding might actually be running in fp16. so we need to cast here.
346
+ # there might be better ways to encapsulate this.
347
+ t_emb = t_emb.to(dtype=self.dtype)
348
+ emb = self.time_embedding(t_emb)
349
+
350
+ if self.class_embedding is not None:
351
+ if class_labels is None:
352
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
353
+
354
+ if self.config.class_embed_type == "timestep":
355
+ class_labels = self.time_proj(class_labels)
356
+
357
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
358
+ emb = emb + class_emb
359
+
360
+ # pre-process
361
+ sample = self.conv_in(sample)
362
+
363
+ # down
364
+ down_block_res_samples = (sample,)
365
+ for downsample_block in self.down_blocks:
366
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
367
+ sample, res_samples = downsample_block(
368
+ hidden_states=sample,
369
+ temb=emb,
370
+ encoder_hidden_states=encoder_hidden_states,
371
+ attention_mask=attention_mask,
372
+ inter_frame=inter_frame,
373
+ **kwargs,
374
+ )
375
+ else:
376
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
377
+
378
+ down_block_res_samples += res_samples
379
+
380
+ # mid
381
+ sample = self.mid_block(
382
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask,
383
+ inter_frame=inter_frame,
384
+ **kwargs,
385
+ )
386
+
387
+ # up
388
+ for i, upsample_block in enumerate(self.up_blocks):
389
+ is_final_block = i == len(self.up_blocks) - 1
390
+
391
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
392
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
393
+
394
+ # if we have not reached the final block and need to forward the
395
+ # upsample size, we do it here
396
+ if not is_final_block and forward_upsample_size:
397
+ upsample_size = down_block_res_samples[-1].shape[2:]
398
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
399
+ sample = upsample_block(
400
+ hidden_states=sample,
401
+ temb=emb,
402
+ res_hidden_states_tuple=res_samples,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ upsample_size=upsample_size,
405
+ attention_mask=attention_mask,
406
+ inter_frame=inter_frame,
407
+ **kwargs,
408
+ )
409
+ else:
410
+ sample = upsample_block(
411
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
412
+ )
413
+ # post-process
414
+ sample = self.conv_norm_out(sample)
415
+ sample = self.conv_act(sample)
416
+ sample = self.conv_out(sample)
417
+
418
+ if not return_dict:
419
+ return (sample,)
420
+
421
+ return UNet3DConditionOutput(sample=sample)
422
+
423
+ @classmethod
424
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, dtype=torch.float32):
425
+ if subfolder is not None:
426
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
427
+
428
+ config_file = os.path.join(pretrained_model_path, 'config.json')
429
+ if not os.path.isfile(config_file):
430
+ raise RuntimeError(f"{config_file} does not exist")
431
+ with open(config_file, "r") as f:
432
+ config = json.load(f)
433
+ config["_class_name"] = cls.__name__
434
+ config["down_block_types"] = [
435
+ "CrossAttnDownBlock3D",
436
+ "CrossAttnDownBlock3D",
437
+ "CrossAttnDownBlock3D",
438
+ "DownBlock3D"
439
+ ]
440
+ config["up_block_types"] = [
441
+ "UpBlock3D",
442
+ "CrossAttnUpBlock3D",
443
+ "CrossAttnUpBlock3D",
444
+ "CrossAttnUpBlock3D"
445
+ ]
446
+
447
+ from diffusers.utils import WEIGHTS_NAME
448
+ model = cls.from_config(config)
449
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
450
+ if dtype == torch.float16:
451
+ model_file = model_file.replace(".bin", ".fp16.bin")
452
+ model = model.to(dtype=dtype)
453
+ if not os.path.isfile(model_file):
454
+ raise RuntimeError(f"{model_file} does not exist")
455
+ state_dict = torch.load(model_file, map_location="cpu")
456
+ # for k, v in model.state_dict().items():
457
+ # if '_temp.' in k:
458
+ # state_dict.update({k: v})
459
+ model.load_state_dict(state_dict, strict=False)
460
+
461
+ return model
models/unet_blocks.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ ):
29
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30
+ if down_block_type == "DownBlock3D":
31
+ return DownBlock3D(
32
+ num_layers=num_layers,
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ temb_channels=temb_channels,
36
+ add_downsample=add_downsample,
37
+ resnet_eps=resnet_eps,
38
+ resnet_act_fn=resnet_act_fn,
39
+ resnet_groups=resnet_groups,
40
+ downsample_padding=downsample_padding,
41
+ resnet_time_scale_shift=resnet_time_scale_shift,
42
+ )
43
+ elif down_block_type == "CrossAttnDownBlock3D":
44
+ if cross_attention_dim is None:
45
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46
+ return CrossAttnDownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ cross_attention_dim=cross_attention_dim,
57
+ attn_num_head_channels=attn_num_head_channels,
58
+ dual_cross_attention=dual_cross_attention,
59
+ use_linear_projection=use_linear_projection,
60
+ only_cross_attention=only_cross_attention,
61
+ upcast_attention=upcast_attention,
62
+ resnet_time_scale_shift=resnet_time_scale_shift,
63
+ )
64
+ raise ValueError(f"{down_block_type} does not exist.")
65
+
66
+
67
+ def get_up_block(
68
+ up_block_type,
69
+ num_layers,
70
+ in_channels,
71
+ out_channels,
72
+ prev_output_channel,
73
+ temb_channels,
74
+ add_upsample,
75
+ resnet_eps,
76
+ resnet_act_fn,
77
+ attn_num_head_channels,
78
+ resnet_groups=None,
79
+ cross_attention_dim=None,
80
+ dual_cross_attention=False,
81
+ use_linear_projection=False,
82
+ only_cross_attention=False,
83
+ upcast_attention=False,
84
+ resnet_time_scale_shift="default",
85
+ ):
86
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87
+ if up_block_type == "UpBlock3D":
88
+ return UpBlock3D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ prev_output_channel=prev_output_channel,
93
+ temb_channels=temb_channels,
94
+ add_upsample=add_upsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ )
100
+ elif up_block_type == "CrossAttnUpBlock3D":
101
+ if cross_attention_dim is None:
102
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103
+ return CrossAttnUpBlock3D(
104
+ num_layers=num_layers,
105
+ in_channels=in_channels,
106
+ out_channels=out_channels,
107
+ prev_output_channel=prev_output_channel,
108
+ temb_channels=temb_channels,
109
+ add_upsample=add_upsample,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ cross_attention_dim=cross_attention_dim,
114
+ attn_num_head_channels=attn_num_head_channels,
115
+ dual_cross_attention=dual_cross_attention,
116
+ use_linear_projection=use_linear_projection,
117
+ only_cross_attention=only_cross_attention,
118
+ upcast_attention=upcast_attention,
119
+ resnet_time_scale_shift=resnet_time_scale_shift,
120
+ )
121
+ raise ValueError(f"{up_block_type} does not exist.")
122
+
123
+
124
+ class UNetMidBlock3DCrossAttn(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_channels: int,
128
+ temb_channels: int,
129
+ dropout: float = 0.0,
130
+ num_layers: int = 1,
131
+ resnet_eps: float = 1e-6,
132
+ resnet_time_scale_shift: str = "default",
133
+ resnet_act_fn: str = "swish",
134
+ resnet_groups: int = 32,
135
+ resnet_pre_norm: bool = True,
136
+ attn_num_head_channels=1,
137
+ output_scale_factor=1.0,
138
+ cross_attention_dim=1280,
139
+ dual_cross_attention=False,
140
+ use_linear_projection=False,
141
+ upcast_attention=False,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.has_cross_attention = True
146
+ self.attn_num_head_channels = attn_num_head_channels
147
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148
+
149
+ # there is always at least one resnet
150
+ resnets = [
151
+ ResnetBlock3D(
152
+ in_channels=in_channels,
153
+ out_channels=in_channels,
154
+ temb_channels=temb_channels,
155
+ eps=resnet_eps,
156
+ groups=resnet_groups,
157
+ dropout=dropout,
158
+ time_embedding_norm=resnet_time_scale_shift,
159
+ non_linearity=resnet_act_fn,
160
+ output_scale_factor=output_scale_factor,
161
+ pre_norm=resnet_pre_norm,
162
+ )
163
+ ]
164
+ attentions = []
165
+
166
+ for _ in range(num_layers):
167
+ if dual_cross_attention:
168
+ raise NotImplementedError
169
+ attentions.append(
170
+ Transformer3DModel(
171
+ attn_num_head_channels,
172
+ in_channels // attn_num_head_channels,
173
+ in_channels=in_channels,
174
+ num_layers=1,
175
+ cross_attention_dim=cross_attention_dim,
176
+ norm_num_groups=resnet_groups,
177
+ use_linear_projection=use_linear_projection,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ )
181
+ resnets.append(
182
+ ResnetBlock3D(
183
+ in_channels=in_channels,
184
+ out_channels=in_channels,
185
+ temb_channels=temb_channels,
186
+ eps=resnet_eps,
187
+ groups=resnet_groups,
188
+ dropout=dropout,
189
+ time_embedding_norm=resnet_time_scale_shift,
190
+ non_linearity=resnet_act_fn,
191
+ output_scale_factor=output_scale_factor,
192
+ pre_norm=resnet_pre_norm,
193
+ )
194
+ )
195
+
196
+ self.attentions = nn.ModuleList(attentions)
197
+ self.resnets = nn.ModuleList(resnets)
198
+
199
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs):
200
+ hidden_states = self.resnets[0](hidden_states, temb)
201
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
202
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
203
+ hidden_states = resnet(hidden_states, temb)
204
+
205
+ return hidden_states
206
+
207
+
208
+ class CrossAttnDownBlock3D(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels: int,
212
+ out_channels: int,
213
+ temb_channels: int,
214
+ dropout: float = 0.0,
215
+ num_layers: int = 1,
216
+ resnet_eps: float = 1e-6,
217
+ resnet_time_scale_shift: str = "default",
218
+ resnet_act_fn: str = "swish",
219
+ resnet_groups: int = 32,
220
+ resnet_pre_norm: bool = True,
221
+ attn_num_head_channels=1,
222
+ cross_attention_dim=1280,
223
+ output_scale_factor=1.0,
224
+ downsample_padding=1,
225
+ add_downsample=True,
226
+ dual_cross_attention=False,
227
+ use_linear_projection=False,
228
+ only_cross_attention=False,
229
+ upcast_attention=False,
230
+ ):
231
+ super().__init__()
232
+ resnets = []
233
+ attentions = []
234
+
235
+ self.has_cross_attention = True
236
+ self.attn_num_head_channels = attn_num_head_channels
237
+
238
+ for i in range(num_layers):
239
+ in_channels = in_channels if i == 0 else out_channels
240
+ resnets.append(
241
+ ResnetBlock3D(
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ eps=resnet_eps,
246
+ groups=resnet_groups,
247
+ dropout=dropout,
248
+ time_embedding_norm=resnet_time_scale_shift,
249
+ non_linearity=resnet_act_fn,
250
+ output_scale_factor=output_scale_factor,
251
+ pre_norm=resnet_pre_norm,
252
+ )
253
+ )
254
+ if dual_cross_attention:
255
+ raise NotImplementedError
256
+ attentions.append(
257
+ Transformer3DModel(
258
+ attn_num_head_channels,
259
+ out_channels // attn_num_head_channels,
260
+ in_channels=out_channels,
261
+ num_layers=1,
262
+ cross_attention_dim=cross_attention_dim,
263
+ norm_num_groups=resnet_groups,
264
+ use_linear_projection=use_linear_projection,
265
+ only_cross_attention=only_cross_attention,
266
+ upcast_attention=upcast_attention,
267
+ )
268
+ )
269
+ self.attentions = nn.ModuleList(attentions)
270
+ self.resnets = nn.ModuleList(resnets)
271
+
272
+ if add_downsample:
273
+ self.downsamplers = nn.ModuleList(
274
+ [
275
+ Downsample3D(
276
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277
+ )
278
+ ]
279
+ )
280
+ else:
281
+ self.downsamplers = None
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs):
286
+ output_states = ()
287
+
288
+ for resnet, attn in zip(self.resnets, self.attentions):
289
+ if self.training and self.gradient_checkpointing:
290
+
291
+ def create_custom_forward(module, return_dict=None, inter_frame=None):
292
+ def custom_forward(*inputs):
293
+ if return_dict is not None:
294
+ return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
295
+ else:
296
+ return module(*inputs)
297
+
298
+ return custom_forward
299
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
300
+ hidden_states = torch.utils.checkpoint.checkpoint(
301
+ create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
302
+ hidden_states,
303
+ encoder_hidden_states,
304
+ )[0]
305
+ else:
306
+ hidden_states = resnet(hidden_states, temb)
307
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
308
+
309
+ output_states += (hidden_states,)
310
+
311
+ if self.downsamplers is not None:
312
+ for downsampler in self.downsamplers:
313
+ hidden_states = downsampler(hidden_states)
314
+
315
+ output_states += (hidden_states,)
316
+
317
+ return hidden_states, output_states
318
+
319
+
320
+ class DownBlock3D(nn.Module):
321
+ def __init__(
322
+ self,
323
+ in_channels: int,
324
+ out_channels: int,
325
+ temb_channels: int,
326
+ dropout: float = 0.0,
327
+ num_layers: int = 1,
328
+ resnet_eps: float = 1e-6,
329
+ resnet_time_scale_shift: str = "default",
330
+ resnet_act_fn: str = "swish",
331
+ resnet_groups: int = 32,
332
+ resnet_pre_norm: bool = True,
333
+ output_scale_factor=1.0,
334
+ add_downsample=True,
335
+ downsample_padding=1,
336
+ ):
337
+ super().__init__()
338
+ resnets = []
339
+
340
+ for i in range(num_layers):
341
+ in_channels = in_channels if i == 0 else out_channels
342
+ resnets.append(
343
+ ResnetBlock3D(
344
+ in_channels=in_channels,
345
+ out_channels=out_channels,
346
+ temb_channels=temb_channels,
347
+ eps=resnet_eps,
348
+ groups=resnet_groups,
349
+ dropout=dropout,
350
+ time_embedding_norm=resnet_time_scale_shift,
351
+ non_linearity=resnet_act_fn,
352
+ output_scale_factor=output_scale_factor,
353
+ pre_norm=resnet_pre_norm,
354
+ )
355
+ )
356
+
357
+ self.resnets = nn.ModuleList(resnets)
358
+
359
+ if add_downsample:
360
+ self.downsamplers = nn.ModuleList(
361
+ [
362
+ Downsample3D(
363
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
364
+ )
365
+ ]
366
+ )
367
+ else:
368
+ self.downsamplers = None
369
+
370
+ self.gradient_checkpointing = False
371
+
372
+ def forward(self, hidden_states, temb=None):
373
+ output_states = ()
374
+
375
+ for resnet in self.resnets:
376
+ if self.training and self.gradient_checkpointing:
377
+
378
+ def create_custom_forward(module):
379
+ def custom_forward(*inputs):
380
+ return module(*inputs)
381
+
382
+ return custom_forward
383
+
384
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
385
+ else:
386
+ hidden_states = resnet(hidden_states, temb)
387
+
388
+ output_states += (hidden_states,)
389
+
390
+ if self.downsamplers is not None:
391
+ for downsampler in self.downsamplers:
392
+ hidden_states = downsampler(hidden_states)
393
+
394
+ output_states += (hidden_states,)
395
+
396
+ return hidden_states, output_states
397
+
398
+
399
+ class CrossAttnUpBlock3D(nn.Module):
400
+ def __init__(
401
+ self,
402
+ in_channels: int,
403
+ out_channels: int,
404
+ prev_output_channel: int,
405
+ temb_channels: int,
406
+ dropout: float = 0.0,
407
+ num_layers: int = 1,
408
+ resnet_eps: float = 1e-6,
409
+ resnet_time_scale_shift: str = "default",
410
+ resnet_act_fn: str = "swish",
411
+ resnet_groups: int = 32,
412
+ resnet_pre_norm: bool = True,
413
+ attn_num_head_channels=1,
414
+ cross_attention_dim=1280,
415
+ output_scale_factor=1.0,
416
+ add_upsample=True,
417
+ dual_cross_attention=False,
418
+ use_linear_projection=False,
419
+ only_cross_attention=False,
420
+ upcast_attention=False,
421
+ ):
422
+ super().__init__()
423
+ resnets = []
424
+ attentions = []
425
+
426
+ self.has_cross_attention = True
427
+ self.attn_num_head_channels = attn_num_head_channels
428
+
429
+ for i in range(num_layers):
430
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
431
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
432
+
433
+ resnets.append(
434
+ ResnetBlock3D(
435
+ in_channels=resnet_in_channels + res_skip_channels,
436
+ out_channels=out_channels,
437
+ temb_channels=temb_channels,
438
+ eps=resnet_eps,
439
+ groups=resnet_groups,
440
+ dropout=dropout,
441
+ time_embedding_norm=resnet_time_scale_shift,
442
+ non_linearity=resnet_act_fn,
443
+ output_scale_factor=output_scale_factor,
444
+ pre_norm=resnet_pre_norm,
445
+ )
446
+ )
447
+ if dual_cross_attention:
448
+ raise NotImplementedError
449
+ attentions.append(
450
+ Transformer3DModel(
451
+ attn_num_head_channels,
452
+ out_channels // attn_num_head_channels,
453
+ in_channels=out_channels,
454
+ num_layers=1,
455
+ cross_attention_dim=cross_attention_dim,
456
+ norm_num_groups=resnet_groups,
457
+ use_linear_projection=use_linear_projection,
458
+ only_cross_attention=only_cross_attention,
459
+ upcast_attention=upcast_attention,
460
+ )
461
+ )
462
+
463
+ self.attentions = nn.ModuleList(attentions)
464
+ self.resnets = nn.ModuleList(resnets)
465
+
466
+ if add_upsample:
467
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
468
+ else:
469
+ self.upsamplers = None
470
+
471
+ self.gradient_checkpointing = False
472
+
473
+ def forward(
474
+ self,
475
+ hidden_states,
476
+ res_hidden_states_tuple,
477
+ temb=None,
478
+ encoder_hidden_states=None,
479
+ upsample_size=None,
480
+ attention_mask=None,
481
+ inter_frame=False,
482
+ **kwargs,
483
+ ):
484
+ for resnet, attn in zip(self.resnets, self.attentions):
485
+ # pop res hidden states
486
+ res_hidden_states = res_hidden_states_tuple[-1]
487
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
488
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
489
+
490
+ if self.training and self.gradient_checkpointing:
491
+
492
+ def create_custom_forward(module, return_dict=None, inter_frame=None):
493
+ def custom_forward(*inputs):
494
+ if return_dict is not None:
495
+ return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
496
+ else:
497
+ return module(*inputs)
498
+
499
+ return custom_forward
500
+
501
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
502
+ hidden_states = torch.utils.checkpoint.checkpoint(
503
+ create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
504
+ hidden_states,
505
+ encoder_hidden_states,
506
+ )[0]
507
+ else:
508
+ hidden_states = resnet(hidden_states, temb)
509
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
510
+
511
+ if self.upsamplers is not None:
512
+ for upsampler in self.upsamplers:
513
+ hidden_states = upsampler(hidden_states, upsample_size)
514
+
515
+ return hidden_states
516
+
517
+
518
+ class UpBlock3D(nn.Module):
519
+ def __init__(
520
+ self,
521
+ in_channels: int,
522
+ prev_output_channel: int,
523
+ out_channels: int,
524
+ temb_channels: int,
525
+ dropout: float = 0.0,
526
+ num_layers: int = 1,
527
+ resnet_eps: float = 1e-6,
528
+ resnet_time_scale_shift: str = "default",
529
+ resnet_act_fn: str = "swish",
530
+ resnet_groups: int = 32,
531
+ resnet_pre_norm: bool = True,
532
+ output_scale_factor=1.0,
533
+ add_upsample=True,
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+
538
+ for i in range(num_layers):
539
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
540
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
541
+
542
+ resnets.append(
543
+ ResnetBlock3D(
544
+ in_channels=resnet_in_channels + res_skip_channels,
545
+ out_channels=out_channels,
546
+ temb_channels=temb_channels,
547
+ eps=resnet_eps,
548
+ groups=resnet_groups,
549
+ dropout=dropout,
550
+ time_embedding_norm=resnet_time_scale_shift,
551
+ non_linearity=resnet_act_fn,
552
+ output_scale_factor=output_scale_factor,
553
+ pre_norm=resnet_pre_norm,
554
+ )
555
+ )
556
+
557
+ self.resnets = nn.ModuleList(resnets)
558
+
559
+ if add_upsample:
560
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
561
+ else:
562
+ self.upsamplers = None
563
+
564
+ self.gradient_checkpointing = False
565
+
566
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
567
+ for resnet in self.resnets:
568
+ # pop res hidden states
569
+ res_hidden_states = res_hidden_states_tuple[-1]
570
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
571
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
572
+
573
+ if self.training and self.gradient_checkpointing:
574
+
575
+ def create_custom_forward(module):
576
+ def custom_forward(*inputs):
577
+ return module(*inputs)
578
+
579
+ return custom_forward
580
+
581
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
582
+ else:
583
+ hidden_states = resnet(hidden_states, temb)
584
+
585
+ if self.upsamplers is not None:
586
+ for upsampler in self.upsamplers:
587
+ hidden_states = upsampler(hidden_states, upsample_size)
588
+
589
+ return hidden_states
models/util.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import decord
6
+ decord.bridge.set_bridge('torch')
7
+ import torch
8
+ import torchvision
9
+ import PIL
10
+ from typing import List
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+ import torchvision.transforms.functional as F
14
+ import random
15
+
16
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
17
+ videos = rearrange(videos, "b c t h w -> t b c h w")
18
+ outputs = []
19
+ for x in videos:
20
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
21
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
22
+ if rescale:
23
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
24
+ x = (x * 255).numpy().astype(np.uint8)
25
+ outputs.append(x)
26
+
27
+ os.makedirs(os.path.dirname(path), exist_ok=True)
28
+ imageio.mimsave(path, outputs, fps=fps)
29
+
30
+ def save_videos_grid_pil(videos: List[PIL.Image.Image], path: str, rescale=False, n_rows=4, fps=8):
31
+ videos = rearrange(videos, "b c t h w -> t b c h w")
32
+ outputs = []
33
+ for x in videos:
34
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
35
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
36
+ if rescale:
37
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
38
+ x = (x * 255).numpy().astype(np.uint8)
39
+ outputs.append(x)
40
+
41
+ os.makedirs(os.path.dirname(path), exist_ok=True)
42
+ imageio.mimsave(path, outputs, fps=fps)
43
+
44
+ def read_video(video_path, video_length, width=512, height=512, frame_rate=None):
45
+ vr = decord.VideoReader(video_path, width=width, height=height)
46
+ if frame_rate is None:
47
+ frame_rate = max(1, len(vr) // video_length)
48
+ sample_index = list(range(0, len(vr), frame_rate))[:video_length]
49
+ video = vr.get_batch(sample_index)
50
+ video = rearrange(video, "f h w c -> f c h w")
51
+ video = (video / 127.5 - 1.0)
52
+ return video
53
+
54
+
55
+ # DDIM Inversion
56
+ @torch.no_grad()
57
+ def init_prompt(prompt, pipeline):
58
+ uncond_input = pipeline.tokenizer(
59
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
60
+ return_tensors="pt"
61
+ )
62
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
63
+ text_input = pipeline.tokenizer(
64
+ [prompt],
65
+ padding="max_length",
66
+ max_length=pipeline.tokenizer.model_max_length,
67
+ truncation=True,
68
+ return_tensors="pt",
69
+ )
70
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
71
+ context = torch.cat([uncond_embeddings, text_embeddings])
72
+
73
+ return context
74
+
75
+
76
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
77
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
78
+ timestep, next_timestep = min(
79
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
80
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
81
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
82
+ beta_prod_t = 1 - alpha_prod_t
83
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
84
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
85
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
86
+ return next_sample
87
+
88
+
89
+ def get_noise_pred_single(latents, t, context, unet):
90
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
91
+ return noise_pred
92
+
93
+
94
+ @torch.no_grad()
95
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
96
+ context = init_prompt(prompt, pipeline)
97
+ uncond_embeddings, cond_embeddings = context.chunk(2)
98
+ all_latent = [latent]
99
+ latent = latent.clone().detach()
100
+ for i in tqdm(range(num_inv_steps)):
101
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
102
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
103
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
104
+ all_latent.append(latent)
105
+ return all_latent
106
+
107
+
108
+ @torch.no_grad()
109
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
110
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
111
+ return ddim_latents
112
+
113
+
114
+ """optical flow and trajectories sampling"""
115
+ def preprocess(img1_batch, img2_batch, transforms):
116
+ img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False)
117
+ img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False)
118
+ return transforms(img1_batch, img2_batch)
119
+
120
+ def keys_with_same_value(dictionary):
121
+ result = {}
122
+ for key, value in dictionary.items():
123
+ if value not in result:
124
+ result[value] = [key]
125
+ else:
126
+ result[value].append(key)
127
+
128
+ conflict_points = {}
129
+ for k in result.keys():
130
+ if len(result[k]) > 1:
131
+ conflict_points[k] = result[k]
132
+ return conflict_points
133
+
134
+ def find_duplicates(input_list):
135
+ seen = set()
136
+ duplicates = set()
137
+
138
+ for item in input_list:
139
+ if item in seen:
140
+ duplicates.add(item)
141
+ else:
142
+ seen.add(item)
143
+
144
+ return list(duplicates)
145
+
146
+ def neighbors_index(point, window_size, H, W):
147
+ """return the spatial neighbor indices"""
148
+ t, x, y = point
149
+ neighbors = []
150
+ for i in range(-window_size, window_size + 1):
151
+ for j in range(-window_size, window_size + 1):
152
+ if i == 0 and j == 0:
153
+ continue
154
+ if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W:
155
+ continue
156
+ neighbors.append((t, x + i, y + j))
157
+ return neighbors
158
+
159
+
160
+ @torch.no_grad()
161
+ def sample_trajectories(frames, device):
162
+ from torchvision.models.optical_flow import Raft_Large_Weights
163
+ from torchvision.models.optical_flow import raft_large
164
+
165
+ weights = Raft_Large_Weights.DEFAULT
166
+ transforms = weights.transforms()
167
+
168
+ # frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW")
169
+
170
+ clips = list(range(len(frames)))
171
+
172
+ model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
173
+ model = model.eval()
174
+
175
+ finished_trajectories = []
176
+
177
+ current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms)
178
+ list_of_flows = model(current_frames.to(device), next_frames.to(device))
179
+ predicted_flows = list_of_flows[-1]
180
+
181
+ predicted_flows = predicted_flows/512
182
+
183
+ resolutions = [64, 32, 16, 8]
184
+ res = {}
185
+ window_sizes = {64: 2,
186
+ 32: 1,
187
+ 16: 1,
188
+ 8: 1}
189
+
190
+ for resolution in resolutions:
191
+ print("="*30)
192
+ trajectories = {}
193
+ predicted_flow_resolu = torch.round(resolution*torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution/512, resolution/512)))
194
+
195
+ T = predicted_flow_resolu.shape[0]+1
196
+ H = predicted_flow_resolu.shape[2]
197
+ W = predicted_flow_resolu.shape[3]
198
+
199
+ is_activated = torch.zeros([T, H, W], dtype=torch.bool)
200
+
201
+ for t in range(T-1):
202
+ flow = predicted_flow_resolu[t]
203
+ for h in range(H):
204
+ for w in range(W):
205
+
206
+ if not is_activated[t, h, w]:
207
+ is_activated[t, h, w] = True
208
+ # this point has not been traversed, start new trajectory
209
+ x = h + int(flow[1, h, w])
210
+ y = w + int(flow[0, h, w])
211
+ if x >= 0 and x < H and y >= 0 and y < W:
212
+ # trajectories.append([(t, h, w), (t+1, x, y)])
213
+ trajectories[(t, h, w)]= (t+1, x, y)
214
+
215
+ conflict_points = keys_with_same_value(trajectories)
216
+ for k in conflict_points:
217
+ index_to_pop = random.randint(0, len(conflict_points[k]) - 1)
218
+ conflict_points[k].pop(index_to_pop)
219
+ for point in conflict_points[k]:
220
+ if point[0] != T-1:
221
+ trajectories[point]= (-1, -1, -1) # stupid padding with (-1, -1, -1)
222
+
223
+ active_traj = []
224
+ all_traj = []
225
+ for t in range(T):
226
+ pixel_set = {(t, x//H, x%H):0 for x in range(H*W)}
227
+ new_active_traj = []
228
+ for traj in active_traj:
229
+ if traj[-1] in trajectories:
230
+ v = trajectories[traj[-1]]
231
+ new_active_traj.append(traj + [v])
232
+ pixel_set[v] = 1
233
+ else:
234
+ all_traj.append(traj)
235
+ active_traj = new_active_traj
236
+ active_traj+=[[pixel] for pixel in pixel_set if pixel_set[pixel] == 0]
237
+ all_traj += active_traj
238
+
239
+ useful_traj = [i for i in all_traj if len(i)>1]
240
+ for idx in range(len(useful_traj)):
241
+ if useful_traj[idx][-1] == (-1, -1, -1):
242
+ useful_traj[idx] = useful_traj[idx][:-1]
243
+ print("how many points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj]))
244
+ print("how many points in the video for resolution{}?".format(resolution), T*H*W)
245
+
246
+ # validate if there are no duplicates in the trajectories
247
+ trajs = []
248
+ for traj in useful_traj:
249
+ trajs = trajs + traj
250
+ assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories."
251
+
252
+ # check if non-appearing points + appearing points = all the points in the video
253
+ all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)])
254
+ left_points = all_points- set(trajs)
255
+ print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points))
256
+ for p in list(left_points):
257
+ useful_traj.append([p])
258
+ print("how many points in all trajectories for resolution{} after pending?".format(resolution), sum([len(i) for i in useful_traj]))
259
+
260
+
261
+ longest_length = max([len(i) for i in useful_traj])
262
+ sequence_length = (window_sizes[resolution]*2+1)**2 + longest_length - 1
263
+
264
+ seqs = []
265
+ masks = []
266
+
267
+ # create a dictionary to facilitate checking the trajectories to which each point belongs.
268
+ point_to_traj = {}
269
+ for traj in useful_traj:
270
+ for p in traj:
271
+ point_to_traj[p] = traj
272
+
273
+ for t in range(T):
274
+ for x in range(H):
275
+ for y in range(W):
276
+ neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W)
277
+ sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))]
278
+ sequence_mask = torch.zeros(sequence_length, dtype=torch.bool)
279
+ sequence_mask[:len(neighbours)+1] = True
280
+
281
+ traj = point_to_traj[(t,x,y)].copy()
282
+ traj.remove((t,x,y))
283
+ sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))]
284
+ sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True
285
+
286
+ seqs.append(sequence)
287
+ masks.append(sequence_mask)
288
+
289
+ seqs = torch.tensor(seqs)
290
+ masks = torch.stack(masks)
291
+ res["traj{}".format(resolution)] = seqs
292
+ res["mask{}".format(resolution)] = masks
293
+ return res
294
+
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.1 # if you encounter CUDA verision mismatches error, try to install torch manually to specify the version that matches your cuda version
2
+ xformers==0.0.23
3
+ accelerate==0.24.1
4
+ diffusers==0.19.0
5
+ transformers==4.35.0
6
+ imageio==2.34.2
7
+ numpy==1.23.5
8
+ imageio-ffmpeg==0.5.1
9
+ fastapi==0.111.0
10
+ einops
11
+ decord
12
+ av
13
+ # dlib==19.24.2
14
+ # tensorboard==2.12.0
15
+ # PyYAML
16
+ # pyfacer
17
+ # timm
18
+ # huggingface-hub
19
+ # gdown
20
+ # natsort
21
+ # imutils
22
+ # batch-face
23
+ # datasets
24
+ # albumentations
25
+ # spiga
26
+ # omegaconf