amos1088 committed
Commit 5a35e98 · 1 Parent(s): 7957b28
app.py CHANGED
@@ -4,11 +4,14 @@ import torch
4
  import gradio as gr
5
  import spaces
6
  from huggingface_hub import login
7
  from diffusers.utils import load_image
8
 
9
- from models.transformer_sd3 import SD3Transformer2DModel
10
- from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
11
-
12
  # ----------------------------
13
  # Step 1: Download IP Adapter if not exists
14
  # ----------------------------
@@ -32,16 +35,25 @@ if not token:
32
  raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
33
  login(token=token)
34
 
35
- model_path = 'stabilityai/stable-diffusion-3.5-large'
36
  ip_adapter_path = './ip-adapter.bin'
37
  image_encoder_path = "google/siglip-so400m-patch14-384"
38
-
39
- transformer = SD3Transformer2DModel.from_pretrained(
40
- model_path, subfolder="transformer", torch_dtype=torch.bfloat16
41
- )
42
-
43
- pipe = StableDiffusion3Pipeline.from_pretrained(
44
- model_path, transformer=transformer, torch_dtype=torch.bfloat16
45
  ).to("cuda")
46
 
47
  pipe.init_ipadapter(
@@ -51,25 +63,28 @@ pipe.init_ipadapter(
51
  )
52
 
53
 
 
54
  # ----------------------------
55
  # Step 6: Gradio Function
56
  # ----------------------------
57
  @spaces.GPU
58
  def gui_generation(prompt,negative_prompt, ref_img, guidance_scale, ipadapter_scale):
59
-
60
-
61
  ref_img = load_image(ref_img.name).convert('RGB')
62
 
63
- # please note that SD3.5 Large is sensitive to highres generation like 1536x1536
64
  image = pipe(
65
  width=1024,
66
  height=1024,
67
  prompt=prompt,
68
  negative_prompt=negative_prompt,
69
- num_inference_steps=24,
70
  guidance_scale=guidance_scale,
71
- generator=torch.Generator("cuda").manual_seed(42),
72
  clip_image=ref_img,
73
  ipadapter_scale=ipadapter_scale,
74
  ).images[0]
75
 
 
4
  import gradio as gr
5
  import spaces
6
  from huggingface_hub import login
7
+ # from diffusers.utils import load_image
8
+ #
9
+ # from models.transformer_sd3 import SD3Transformer2DModel
10
+ # from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
11
+ import torch
12
+ from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
13
  from diffusers.utils import load_image
14
 
15
  # ----------------------------
16
  # Step 1: Download IP Adapter if not exists
17
  # ----------------------------
 
35
  raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
36
  login(token=token)
37
 
38
+ # model_path = 'stabilityai/stable-diffusion-3.5-large'
39
  ip_adapter_path = './ip-adapter.bin'
40
  image_encoder_path = "google/siglip-so400m-patch14-384"
41
+ #
42
+ # transformer = SD3Transformer2DModel.from_pretrained(
43
+ # model_path, subfolder="transformer", torch_dtype=torch.bfloat16
44
+ # )
45
+ #
46
+ # pipe = StableDiffusion3Pipeline.from_pretrained(
47
+ # model_path, transformer=transformer, torch_dtype=torch.bfloat16
48
+ # ).to("cuda")
49
+
50
+
51
+
52
+ controlnet = SD3ControlNetModel.from_pretrained("stabilityai/stable-diffusion-3.5-large-controlnet-depth", torch_dtype=torch.float16)
53
+ pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
54
+ "stabilityai/stable-diffusion-3.5-large",
55
+ controlnet=controlnet,
56
+ torch_dtype=torch.float16,
57
  ).to("cuda")
58
 
59
  pipe.init_ipadapter(
 
63
  )
64
 
65
 
66
+
67
  # ----------------------------
68
  # Step 6: Gradio Function
69
  # ----------------------------
70
  @spaces.GPU
71
  def gui_generation(prompt,negative_prompt, ref_img, guidance_scale, ipadapter_scale):
72
  ref_img = load_image(ref_img.name).convert('RGB')
73
 
74
+ control_image = load_image(
75
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png")
76
+ generator = torch.Generator(device="cpu").manual_seed(0)
77
  image = pipe(
78
  width=1024,
79
  height=1024,
80
  prompt=prompt,
81
  negative_prompt=negative_prompt,
82
+ control_image=control_image,
83
  guidance_scale=guidance_scale,
 
84
  clip_image=ref_img,
85
+ num_inference_steps=40,
86
+ generator=generator,
87
+ max_sequence_length=77,
88
  ipadapter_scale=ipadapter_scale,
89
  ).images[0]
90
 
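Note: the Gradio wiring that calls gui_generation sits outside the hunks shown above. Below is a minimal sketch of how it could be hooked up; the component choices (a gr.File for the reference image so that ref_img.name is a file path, sliders for the two scales, an image output assuming the function returns the generated image) are assumptions, not part of this commit.

import gradio as gr

# Hypothetical wiring for gui_generation as defined above (assumed, not from the commit).
demo = gr.Interface(
    fn=gui_generation,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.File(label="Reference image"),          # gui_generation reads ref_img.name
        gr.Slider(1.0, 15.0, value=7.0, label="Guidance scale"),
        gr.Slider(0.0, 1.0, value=0.6, label="IP-Adapter scale"),
    ],
    outputs=gr.Image(label="Generated image"),
)

if __name__ == "__main__":
    demo.launch()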
depthfm/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
4
+ from dfm import DepthFM
5
+ from unet import UNetModel
depthfm/dfm.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ import einops
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+ from functools import partial
7
+ from torchdiffeq import odeint
8
+
9
+ from unet import UNetModel
10
+ from diffusers import AutoencoderKL
11
+
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ class DepthFM(nn.Module):
18
+ def __init__(self, ckpt_path: str):
19
+ super().__init__()
20
+ vae_id = "runwayml/stable-diffusion-v1-5"
21
+ self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae")
22
+ self.scale_factor = 0.18215
23
+
24
+ # set with checkpoint
25
+ ckpt = torch.load(ckpt_path, map_location="cpu")
26
+ self.noising_step = ckpt['noising_step']
27
+ self.empty_text_embed = ckpt['empty_text_embedding']
28
+ self.model = UNetModel(**ckpt['ldm_hparams'])
29
+ self.model.load_state_dict(ckpt['state_dict'])
30
+
31
+ def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
32
+ if t.numel() == 1:
33
+ t = t.expand(x.size(0))
34
+ return self.model(x=x, t=t, **kwargs)
35
+
36
+ def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
37
+ """
38
+ ODE solving from z0 (ims) to z1 (depth).
39
+ """
40
+ ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps))
41
+
42
+ # t specifies which intermediate times should the solver return
43
+ # e.g. t = [0, 0.5, 1] means return the solution at t=0, t=0.5 and t=1
44
+ # but it also specifies the number of steps for fixed step size methods
45
+ t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)
46
+ # t = torch.tensor([0., 1.], device=z.device, dtype=z.dtype)
47
+
48
+ # allow conditioning information for model
49
+ ode_fn = partial(self.ode_fn, **kwargs)
50
+
51
+ ode_results = odeint(ode_fn, z, t, **ode_kwargs)
52
+
53
+ if n_intermediates > 0:
54
+ return ode_results
55
+ return ode_results[-1]
56
+
57
+ def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
58
+ """
59
+ Args:
60
+ ims: Tensor of shape (b, 3, h, w) in range [-1, 1]
61
+ Returns:
62
+ depth: Tensor of shape (b, 1, h, w) in range [0, 1]
63
+ """
64
+ if ensemble_size > 1:
65
+ assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
66
+ ims = ims.repeat(ensemble_size, 1, 1, 1)
67
+
68
+ bs, dev = ims.shape[0], ims.device
69
+
70
+ ims_z = self.encode(ims, sample_posterior=False)
71
+
72
+ conditioning = torch.tensor(self.empty_text_embed).to(dev).repeat(bs, 1, 1)
73
+ context = ims_z
74
+
75
+ x_source = ims_z
76
+
77
+ if self.noising_step > 0:
78
+ x_source = q_sample(x_source, self.noising_step)
79
+
80
+ # solve ODE
81
+ depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)
82
+
83
+ depth = self.decode(depth_z)
84
+ depth = depth.mean(dim=1, keepdim=True)
85
+
86
+ if ensemble_size > 1:
87
+ depth = depth.mean(dim=0, keepdim=True)
88
+
89
+ # normalize depth maps to range [0, 1]
90
+ depth = per_sample_min_max_normalization(depth.exp())
91
+
92
+ return depth
93
+
94
+ @torch.no_grad()
95
+ def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
96
+ """ Inference method for DepthFM. """
97
+ return self.forward(ims, num_steps, ensemble_size)
98
+
99
+ @torch.no_grad()
100
+ def encode(self, x: Tensor, sample_posterior: bool = True):
101
+ posterior = self.vae.encode(x)
102
+ if sample_posterior:
103
+ z = posterior.latent_dist.sample()
104
+ else:
105
+ z = posterior.latent_dist.mode()
106
+ # normalize latent code
107
+ z = z * self.scale_factor
108
+ return z
109
+
110
+ @torch.no_grad()
111
+ def decode(self, z: Tensor):
112
+ z = 1.0 / self.scale_factor * z
113
+ return self.vae.decode(z).sample
114
+
115
+
116
+ def sigmoid(x):
117
+ return 1 / (1 + np.exp(-x))
118
+
119
+
120
+ def cosine_log_snr(t, eps=0.00001):
121
+ """
122
+ Returns log Signal-to-Noise ratio for time step t and image size 64
123
+ eps: avoid division by zero
124
+ """
125
+ return -2 * np.log(np.tan((np.pi * t) / 2) + eps)
126
+
127
+
128
+ def cosine_alpha_bar(t):
129
+ return sigmoid(cosine_log_snr(t))
130
+
131
+
132
+ def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
133
+ """
134
+ Diffuse the data for a given number of diffusion steps. In other
135
+ words sample from q(x_t | x_0).
136
+ """
137
+ dev = x_start.device
138
+ dtype = x_start.dtype
139
+
140
+ if noise is None:
141
+ noise = torch.randn_like(x_start)
142
+
143
+ alpha_bar_t = cosine_alpha_bar(t / n_diffusion_timesteps)
144
+ alpha_bar_t = torch.tensor(alpha_bar_t).to(dev).to(dtype)
145
+
146
+ return torch.sqrt(alpha_bar_t) * x_start + torch.sqrt(1 - alpha_bar_t) * noise
147
+
148
+
149
+ def per_sample_min_max_normalization(x):
150
+ """ Normalize each sample in a batch independently
151
+ with min-max normalization to [0, 1] """
152
+ bs, *shape = x.shape
153
+ x_ = einops.rearrange(x, "b ... -> b (...)")
154
+ min_val = einops.reduce(x_, "b ... -> b", "min")[..., None]
155
+ max_val = einops.reduce(x_, "b ... -> b", "max")[..., None]
156
+ x_ = (x_ - min_val) / (max_val - min_val)
157
+ return x_.reshape(bs, *shape)
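For context, a minimal inference sketch for the DepthFM class above (not part of this commit): the checkpoint path is a placeholder, the input is rescaled to [-1, 1] as the forward() docstring expects, and the resulting depth map is the kind of image that could serve as the control_image in app.py, although this commit does not wire the two together.

import numpy as np
import torch
from PIL import Image
from depthfm import DepthFM   # the package __init__ handles the sys.path setup

model = DepthFM("./checkpoints/depthfm-v1.ckpt")   # placeholder checkpoint path
model.eval()

# Resize to a multiple of 64 and rescale to [-1, 1] before encoding with the VAE.
img = Image.open("reference.jpg").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0   # (H, W, 3) in [-1, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                          # (1, 3, H, W)

depth = model.predict_depth(x, num_steps=2, ensemble_size=1)  # (1, 1, H, W) in [0, 1]
Image.fromarray((depth[0, 0].numpy() * 255).astype(np.uint8)).save("depth.png")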
depthfm/unet/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
4
+ from openaimodel import UNetModel
depthfm/unet/attention.py ADDED
@@ -0,0 +1,374 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange
5
+ from inspect import isfunction
6
+ import torch.nn.functional as F
7
+ from typing import Optional, Any
8
+
9
+ from util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except ImportError:
17
+ print("WARNING: xformers is not available, inference might be slow.")
18
+ XFORMERS_IS_AVAILABLE = False
19
+
20
+ # CrossAttn precision handling
21
+ import os
22
+
23
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
24
+
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+
30
+ def uniq(arr):
31
+ return {el: True for el in arr}.keys()
32
+
33
+
34
+ def default(val, d):
35
+ if exists(val):
36
+ return val
37
+ return d() if isfunction(d) else d
38
+
39
+
40
+ def max_neg_value(t):
41
+ return -torch.finfo(t.dtype).max
42
+
43
+
44
+ def init_(tensor):
45
+ dim = tensor.shape[-1]
46
+ std = 1 / math.sqrt(dim)
47
+ tensor.uniform_(-std, std)
48
+ return tensor
49
+
50
+
51
+ # feedforward
52
+ class GEGLU(nn.Module):
53
+ def __init__(self, dim_in, dim_out):
54
+ super().__init__()
55
+ self.proj = nn.Linear(dim_in, dim_out * 2)
56
+
57
+ def forward(self, x):
58
+ x, gate = self.proj(x).chunk(2, dim=-1)
59
+ return x * F.gelu(gate)
60
+
61
+
62
+ class FeedForward(nn.Module):
63
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
64
+ super().__init__()
65
+ inner_dim = int(dim * mult)
66
+ dim_out = default(dim_out, dim)
67
+ project_in = (
68
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
69
+ if not glu
70
+ else GEGLU(dim, inner_dim)
71
+ )
72
+
73
+ self.net = nn.Sequential(
74
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
75
+ )
76
+
77
+ def forward(self, x):
78
+ return self.net(x)
79
+
80
+
81
+ def zero_module(module):
82
+ """
83
+ Zero out the parameters of a module and return it.
84
+ """
85
+ for p in module.parameters():
86
+ p.detach().zero_()
87
+ return module
88
+
89
+
90
+ def Normalize(in_channels):
91
+ return torch.nn.GroupNorm(
92
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
93
+ )
94
+
95
+
96
+ class SpatialSelfAttention(nn.Module):
97
+ def __init__(self, in_channels):
98
+ super().__init__()
99
+ self.in_channels = in_channels
100
+
101
+ self.norm = Normalize(in_channels)
102
+ self.q = torch.nn.Conv2d(
103
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
104
+ )
105
+ self.k = torch.nn.Conv2d(
106
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
107
+ )
108
+ self.v = torch.nn.Conv2d(
109
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
110
+ )
111
+ self.proj_out = torch.nn.Conv2d(
112
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
113
+ )
114
+
115
+ def forward(self, x):
116
+ h_ = x
117
+ h_ = self.norm(h_)
118
+ q = self.q(h_)
119
+ k = self.k(h_)
120
+ v = self.v(h_)
121
+
122
+ # compute attention
123
+ b, c, h, w = q.shape
124
+ q = rearrange(q, "b c h w -> b (h w) c")
125
+ k = rearrange(k, "b c h w -> b c (h w)")
126
+ w_ = torch.einsum("bij,bjk->bik", q, k)
127
+
128
+ w_ = w_ * (int(c) ** (-0.5))
129
+ w_ = torch.nn.functional.softmax(w_, dim=2)
130
+
131
+ # attend to values
132
+ v = rearrange(v, "b c h w -> b c (h w)")
133
+ w_ = rearrange(w_, "b i j -> b j i")
134
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
135
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
136
+ h_ = self.proj_out(h_)
137
+
138
+ return x + h_
139
+
140
+
141
+ class CrossAttention(nn.Module):
142
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
143
+ super().__init__()
144
+ inner_dim = dim_head * heads
145
+ context_dim = default(context_dim, query_dim)
146
+
147
+ self.dim_head = dim_head
148
+
149
+ self.scale = dim_head**-0.5
150
+ self.heads = heads
151
+
152
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
153
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
154
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
155
+
156
+ self.to_out = nn.Sequential(
157
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
158
+ )
159
+
160
+ def forward(self, x, context=None, mask=None, rescale_attention=True):
161
+
162
+ is_self_attention = context is None
163
+
164
+ n_tokens = x.shape[1]
165
+
166
+ h = self.heads
167
+
168
+ q = self.to_q(x)
169
+ context = default(context, x)
170
+ k = self.to_k(context)
171
+ v = self.to_v(context)
172
+
173
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
174
+
175
+ if rescale_attention:
176
+ out = F.scaled_dot_product_attention(q, k, v, scale=(math.log(n_tokens) / math.log(n_tokens*4) / self.dim_head)**0.5 if is_self_attention else None)
177
+ else:
178
+ out = F.scaled_dot_product_attention(q, k, v)
179
+
180
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
181
+ return self.to_out(out)
182
+
183
+
184
+ class MemoryEfficientCrossAttention(nn.Module):
185
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
186
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
187
+ super().__init__()
188
+ # print(
189
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
190
+ # f"{heads} heads."
191
+ # )
192
+ inner_dim = dim_head * heads
193
+ context_dim = default(context_dim, query_dim)
194
+
195
+ self.heads = heads
196
+ self.dim_head = dim_head
197
+
198
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
199
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
200
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
201
+
202
+ self.to_out = nn.Sequential(
203
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
204
+ )
205
+ self.attention_op: Optional[Any] = None
206
+
207
+ def forward(self, x, context=None, mask=None):
208
+ q = self.to_q(x)
209
+ context = default(context, x)
210
+ k = self.to_k(context)
211
+ v = self.to_v(context)
212
+
213
+ b, _, _ = q.shape
214
+ q, k, v = map(
215
+ lambda t: t.unsqueeze(3)
216
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
217
+ .permute(0, 2, 1, 3)
218
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
219
+ .contiguous(),
220
+ (q, k, v),
221
+ )
222
+
223
+ # actually compute the attention, what we cannot get enough of
224
+ out = xformers.ops.memory_efficient_attention(
225
+ q, k, v, attn_bias=None, op=self.attention_op
226
+ )
227
+
228
+ if exists(mask):
229
+ raise NotImplementedError
230
+ out = (
231
+ out.unsqueeze(0)
232
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
233
+ .permute(0, 2, 1, 3)
234
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
235
+ )
236
+ return self.to_out(out)
237
+
238
+
239
+ class BasicTransformerBlock(nn.Module):
240
+ ATTENTION_MODES = {
241
+ "softmax": CrossAttention, # vanilla attention
242
+ "softmax-xformers": MemoryEfficientCrossAttention,
243
+ }
244
+
245
+ def __init__(
246
+ self,
247
+ dim,
248
+ n_heads,
249
+ d_head,
250
+ dropout=0.0,
251
+ context_dim=None,
252
+ gated_ff=True,
253
+ checkpoint=True,
254
+ disable_self_attn=False,
255
+ ):
256
+ super().__init__()
257
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
258
+ assert attn_mode in self.ATTENTION_MODES
259
+ attn_cls = self.ATTENTION_MODES[attn_mode]
260
+ self.disable_self_attn = disable_self_attn
261
+ self.attn1 = attn_cls(
262
+ query_dim=dim,
263
+ heads=n_heads,
264
+ dim_head=d_head,
265
+ dropout=dropout,
266
+ context_dim=context_dim if self.disable_self_attn else None,
267
+ ) # is a self-attention if not self.disable_self_attn
268
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
269
+ self.attn2 = attn_cls(
270
+ query_dim=dim,
271
+ context_dim=context_dim,
272
+ heads=n_heads,
273
+ dim_head=d_head,
274
+ dropout=dropout,
275
+ ) # is self-attn if context is none
276
+ self.norm1 = nn.LayerNorm(dim)
277
+ self.norm2 = nn.LayerNorm(dim)
278
+ self.norm3 = nn.LayerNorm(dim)
279
+ self.checkpoint = checkpoint
280
+
281
+ def forward(self, x, context=None):
282
+ return checkpoint(
283
+ self._forward, (x, context), self.parameters(), self.checkpoint
284
+ )
285
+
286
+ def _forward(self, x, context=None):
287
+ x = (
288
+ self.attn1(
289
+ self.norm1(x), context=context if self.disable_self_attn else None
290
+ )
291
+ + x
292
+ )
293
+ x = self.attn2(self.norm2(x), context=context) + x
294
+ x = self.ff(self.norm3(x)) + x
295
+ return x
296
+
297
+
298
+ class SpatialTransformer(nn.Module):
299
+ """
300
+ Transformer block for image-like data.
301
+ First, project the input (aka embedding)
302
+ and reshape to b, t, d.
303
+ Then apply standard transformer action.
304
+ Finally, reshape to image
305
+ NEW: use_linear for more efficiency instead of the 1x1 convs
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ in_channels,
311
+ n_heads,
312
+ d_head,
313
+ depth=1,
314
+ dropout=0.0,
315
+ context_dim=None,
316
+ disable_self_attn=False,
317
+ use_linear=False,
318
+ use_checkpoint=True,
319
+ ):
320
+ super().__init__()
321
+ if exists(context_dim) and not isinstance(context_dim, list):
322
+ context_dim = [context_dim]
323
+ self.in_channels = in_channels
324
+ inner_dim = n_heads * d_head
325
+ self.norm = Normalize(in_channels)
326
+ if not use_linear:
327
+ self.proj_in = nn.Conv2d(
328
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
329
+ )
330
+ else:
331
+ self.proj_in = nn.Linear(in_channels, inner_dim)
332
+
333
+ self.transformer_blocks = nn.ModuleList(
334
+ [
335
+ BasicTransformerBlock(
336
+ inner_dim,
337
+ n_heads,
338
+ d_head,
339
+ dropout=dropout,
340
+ context_dim=context_dim[d],
341
+ disable_self_attn=disable_self_attn,
342
+ checkpoint=use_checkpoint,
343
+ )
344
+ for d in range(depth)
345
+ ]
346
+ )
347
+ if not use_linear:
348
+ self.proj_out = zero_module(
349
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
350
+ )
351
+ else:
352
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
353
+ self.use_linear = use_linear
354
+
355
+ def forward(self, x, context=None):
356
+ # note: if no context is given, cross-attention defaults to self-attention
357
+ if not isinstance(context, list):
358
+ context = [context]
359
+ b, c, h, w = x.shape
360
+ x_in = x
361
+ x = self.norm(x)
362
+ if not self.use_linear:
363
+ x = self.proj_in(x)
364
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
365
+ if self.use_linear:
366
+ x = self.proj_in(x)
367
+ for i, block in enumerate(self.transformer_blocks):
368
+ x = block(x, context=context[i])
369
+ if self.use_linear:
370
+ x = self.proj_out(x)
371
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
372
+ if not self.use_linear:
373
+ x = self.proj_out(x)
374
+ return x + x_in
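A small shape check for SpatialTransformer (illustrative only, not part of the commit); it assumes the script is run from the repo root so that depthfm/unet can be added to sys.path the same way the package __init__ does, and it falls back to plain scaled-dot-product attention when xformers is not installed.

import sys
import torch

sys.path.append("depthfm/unet")               # mirrors the path hack in depthfm/unet/__init__.py
from attention import SpatialTransformer

block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=768)
x = torch.randn(1, 64, 16, 16)                # image-like features (b, c, h, w)
ctx = torch.randn(1, 77, 768)                 # cross-attention conditioning
y = block(x, context=ctx)                     # -> (1, 64, 16, 16); output is x plus the projected attention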
depthfm/unet/openaimodel.py ADDED
@@ -0,0 +1,894 @@
1
+ import math
2
+ import numpy as np
3
+ import torch as th
4
+ import torch.nn as nn
5
+ from abc import abstractmethod
6
+ import torch.nn.functional as F
7
+
8
+ from util import (
9
+ checkpoint,
10
+ conv_nd,
11
+ linear,
12
+ avg_pool_nd,
13
+ zero_module,
14
+ normalization,
15
+ timestep_embedding,
16
+ )
17
+ from attention import SpatialTransformer
18
+
19
+
20
+ def exists(x):
21
+ return x is not None
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ """
64
+ Any module where forward() takes timestep embeddings as a second argument.
65
+ """
66
+
67
+ @abstractmethod
68
+ def forward(self, x, emb):
69
+ """
70
+ Apply the module to `x` given `emb` timestep embeddings.
71
+ """
72
+
73
+
74
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
+ """
76
+ A sequential module that passes timestep embeddings to the children that
77
+ support it as an extra input.
78
+ """
79
+
80
+ def forward(self, x, emb, context=None):
81
+ for layer in self:
82
+ if isinstance(layer, TimestepBlock):
83
+ x = layer(x, emb)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
90
+
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ 'Learned 2x upsampling without padding'
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
413
+ class Timestep(nn.Module):
414
+ def __init__(self, dim):
415
+ super().__init__()
416
+ self.dim = dim
417
+
418
+ def forward(self, t):
419
+ return timestep_embedding(t, self.dim)
420
+
421
+
422
+ class UNetModel(nn.Module):
423
+ """
424
+ The full UNet model with attention and timestep embedding.
425
+ :param in_channels: channels in the input Tensor.
426
+ :param model_channels: base channel count for the model.
427
+ :param out_channels: channels in the output Tensor.
428
+ :param num_res_blocks: number of residual blocks per downsample.
429
+ :param attention_resolutions: a collection of downsample rates at which
430
+ attention will take place. May be a set, list, or tuple.
431
+ For example, if this contains 4, then at 4x downsampling, attention
432
+ will be used.
433
+ :param dropout: the dropout probability.
434
+ :param channel_mult: channel multiplier for each level of the UNet.
435
+ :param conv_resample: if True, use learned convolutions for upsampling and
436
+ downsampling.
437
+ :param dims: determines if the signal is 1D, 2D, or 3D.
438
+ :param num_classes: if specified (as an int), then this model will be
439
+ class-conditional with `num_classes` classes.
440
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
441
+ :param num_heads: the number of attention heads in each attention layer.
442
+ :param num_heads_channels: if specified, ignore num_heads and instead use
443
+ a fixed channel width per attention head.
444
+ :param num_heads_upsample: works with num_heads to set a different number
445
+ of heads for upsampling. Deprecated.
446
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
447
+ :param resblock_updown: use residual blocks for up/downsampling.
448
+ :param use_new_attention_order: use a different attention pattern for potentially
449
+ increased efficiency.
450
+ """
451
+
452
+ def __init__(
453
+ self,
454
+ image_size,
455
+ in_channels,
456
+ model_channels,
457
+ out_channels,
458
+ num_res_blocks,
459
+ attention_resolutions,
460
+ dropout=0,
461
+ channel_mult=(1, 2, 4, 8),
462
+ conv_resample=True,
463
+ dims=2,
464
+ num_classes=None,
465
+ use_checkpoint=False,
466
+ use_fp16=False,
467
+ use_bf16=False,
468
+ num_heads=-1,
469
+ num_head_channels=-1,
470
+ num_heads_upsample=-1,
471
+ use_scale_shift_norm=False,
472
+ resblock_updown=False,
473
+ use_new_attention_order=False,
474
+ use_spatial_transformer=False, # custom transformer support
475
+ transformer_depth=1, # custom transformer support
476
+ context_dim=None, # custom transformer support
477
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
478
+ legacy=True,
479
+ disable_self_attentions=None,
480
+ num_attention_blocks=None,
481
+ disable_middle_self_attn=False,
482
+ use_linear_in_transformer=False,
483
+ adm_in_channels=None,
484
+ load_from_ckpt=None,
485
+ ):
486
+ super().__init__()
487
+ if use_spatial_transformer:
488
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
489
+
490
+ if context_dim is not None:
491
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
492
+ from omegaconf.listconfig import ListConfig
493
+ if type(context_dim) == ListConfig:
494
+ context_dim = list(context_dim)
495
+
496
+ if num_heads_upsample == -1:
497
+ num_heads_upsample = num_heads
498
+
499
+ if num_heads == -1:
500
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
501
+
502
+ if num_head_channels == -1:
503
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
504
+
505
+ self.image_size = image_size
506
+ self.in_channels = in_channels
507
+ self.model_channels = model_channels
508
+ self.out_channels = out_channels
509
+ if isinstance(num_res_blocks, int):
510
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
511
+ else:
512
+ if len(num_res_blocks) != len(channel_mult):
513
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
514
+ "as a list/tuple (per-level) with the same length as channel_mult")
515
+ self.num_res_blocks = num_res_blocks
516
+ if disable_self_attentions is not None:
517
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
518
+ assert len(disable_self_attentions) == len(channel_mult)
519
+ if num_attention_blocks is not None:
520
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
521
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
522
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
523
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
524
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
525
+ f"attention will still not be set.")
526
+
527
+ self.attention_resolutions = attention_resolutions
528
+ self.dropout = dropout
529
+ self.channel_mult = channel_mult
530
+ self.conv_resample = conv_resample
531
+ self.num_classes = num_classes
532
+ self.use_checkpoint = use_checkpoint
533
+ self.dtype = th.float16 if use_fp16 else th.float32
534
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
535
+ self.num_heads = num_heads
536
+ self.num_head_channels = num_head_channels
537
+ self.num_heads_upsample = num_heads_upsample
538
+ self.predict_codebook_ids = n_embed is not None
539
+
540
+ time_embed_dim = model_channels * 4
541
+ self.time_embed = nn.Sequential(
542
+ linear(model_channels, time_embed_dim),
543
+ nn.SiLU(),
544
+ linear(time_embed_dim, time_embed_dim),
545
+ )
546
+
547
+ if self.num_classes is not None:
548
+ if isinstance(self.num_classes, int):
549
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
550
+ elif self.num_classes == "continuous":
551
+ print("setting up linear c_adm embedding layer")
552
+ self.label_emb = nn.Linear(1, time_embed_dim)
553
+ elif self.num_classes == "sequential":
554
+ assert adm_in_channels is not None
555
+ self.label_emb = nn.Sequential(
556
+ nn.Sequential(
557
+ linear(adm_in_channels, time_embed_dim),
558
+ nn.SiLU(),
559
+ linear(time_embed_dim, time_embed_dim),
560
+ )
561
+ )
562
+ else:
563
+ raise ValueError()
564
+
565
+ self.input_blocks = nn.ModuleList(
566
+ [
567
+ TimestepEmbedSequential(
568
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
569
+ )
570
+ ]
571
+ )
572
+ self._feature_size = model_channels
573
+ input_block_chans = [model_channels]
574
+ ch = model_channels
575
+ ds = 1
576
+ for level, mult in enumerate(channel_mult):
577
+ for nr in range(self.num_res_blocks[level]):
578
+ layers = [
579
+ ResBlock(
580
+ ch,
581
+ time_embed_dim,
582
+ dropout,
583
+ out_channels=mult * model_channels,
584
+ dims=dims,
585
+ use_checkpoint=use_checkpoint,
586
+ use_scale_shift_norm=use_scale_shift_norm,
587
+ )
588
+ ]
589
+ ch = mult * model_channels
590
+ if ds in attention_resolutions:
591
+ if num_head_channels == -1:
592
+ dim_head = ch // num_heads
593
+ else:
594
+ num_heads = ch // num_head_channels
595
+ dim_head = num_head_channels
596
+ if legacy:
597
+ #num_heads = 1
598
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
599
+ if exists(disable_self_attentions):
600
+ disabled_sa = disable_self_attentions[level]
601
+ else:
602
+ disabled_sa = False
603
+
604
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
605
+ layers.append(
606
+ AttentionBlock(
607
+ ch,
608
+ use_checkpoint=use_checkpoint,
609
+ num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
+ use_new_attention_order=use_new_attention_order,
612
+ ) if not use_spatial_transformer else SpatialTransformer(
613
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
614
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
615
+ use_checkpoint=use_checkpoint
616
+ )
617
+ )
618
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
619
+ self._feature_size += ch
620
+ input_block_chans.append(ch)
621
+ if level != len(channel_mult) - 1:
622
+ out_ch = ch
623
+ self.input_blocks.append(
624
+ TimestepEmbedSequential(
625
+ ResBlock(
626
+ ch,
627
+ time_embed_dim,
628
+ dropout,
629
+ out_channels=out_ch,
630
+ dims=dims,
631
+ use_checkpoint=use_checkpoint,
632
+ use_scale_shift_norm=use_scale_shift_norm,
633
+ down=True,
634
+ )
635
+ if resblock_updown
636
+ else Downsample(
637
+ ch, conv_resample, dims=dims, out_channels=out_ch
638
+ )
639
+ )
640
+ )
641
+ ch = out_ch
642
+ input_block_chans.append(ch)
643
+ ds *= 2
644
+ self._feature_size += ch
645
+
646
+ if num_head_channels == -1:
647
+ dim_head = ch // num_heads
648
+ else:
649
+ num_heads = ch // num_head_channels
650
+ dim_head = num_head_channels
651
+ if legacy:
652
+ #num_heads = 1
653
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
654
+ self.middle_block = TimestepEmbedSequential(
655
+ ResBlock(
656
+ ch,
657
+ time_embed_dim,
658
+ dropout,
659
+ dims=dims,
660
+ use_checkpoint=use_checkpoint,
661
+ use_scale_shift_norm=use_scale_shift_norm,
662
+ ),
663
+ AttentionBlock(
664
+ ch,
665
+ use_checkpoint=use_checkpoint,
666
+ num_heads=num_heads,
667
+ num_head_channels=dim_head,
668
+ use_new_attention_order=use_new_attention_order,
669
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
670
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
671
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
672
+ use_checkpoint=use_checkpoint
673
+ ),
674
+ ResBlock(
675
+ ch,
676
+ time_embed_dim,
677
+ dropout,
678
+ dims=dims,
679
+ use_checkpoint=use_checkpoint,
680
+ use_scale_shift_norm=use_scale_shift_norm,
681
+ ),
682
+ )
683
+ self._feature_size += ch
684
+
685
+ self.output_blocks = nn.ModuleList([])
686
+ for level, mult in list(enumerate(channel_mult))[::-1]:
687
+ for i in range(self.num_res_blocks[level] + 1):
688
+ ich = input_block_chans.pop()
689
+ layers = [
690
+ ResBlock(
691
+ ch + ich,
692
+ time_embed_dim,
693
+ dropout,
694
+ out_channels=model_channels * mult,
695
+ dims=dims,
696
+ use_checkpoint=use_checkpoint,
697
+ use_scale_shift_norm=use_scale_shift_norm,
698
+ )
699
+ ]
700
+ ch = model_channels * mult
701
+ if ds in attention_resolutions:
702
+ if num_head_channels == -1:
703
+ dim_head = ch // num_heads
704
+ else:
705
+ num_heads = ch // num_head_channels
706
+ dim_head = num_head_channels
707
+ if legacy:
708
+ #num_heads = 1
709
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
710
+ if exists(disable_self_attentions):
711
+ disabled_sa = disable_self_attentions[level]
712
+ else:
713
+ disabled_sa = False
714
+
715
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
716
+ layers.append(
717
+ AttentionBlock(
718
+ ch,
719
+ use_checkpoint=use_checkpoint,
720
+ num_heads=num_heads_upsample,
721
+ num_head_channels=dim_head,
722
+ use_new_attention_order=use_new_attention_order,
723
+ ) if not use_spatial_transformer else SpatialTransformer(
724
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
725
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
726
+ use_checkpoint=use_checkpoint
727
+ )
728
+ )
729
+ if level and i == self.num_res_blocks[level]:
730
+ out_ch = ch
731
+ layers.append(
732
+ ResBlock(
733
+ ch,
734
+ time_embed_dim,
735
+ dropout,
736
+ out_channels=out_ch,
737
+ dims=dims,
738
+ use_checkpoint=use_checkpoint,
739
+ use_scale_shift_norm=use_scale_shift_norm,
740
+ up=True,
741
+ )
742
+ if resblock_updown
743
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
744
+ )
745
+ ds //= 2
746
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
747
+ self._feature_size += ch
748
+
749
+ self.out = nn.Sequential(
750
+ normalization(ch),
751
+ nn.SiLU(),
752
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
753
+ )
754
+ if self.predict_codebook_ids:
755
+ self.id_predictor = nn.Sequential(
756
+ normalization(ch),
757
+ conv_nd(dims, model_channels, n_embed, 1),
758
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
759
+ )
760
+
761
+ if load_from_ckpt is not None:
762
+ self.load_from_ckpt(load_from_ckpt)
763
+
764
+ def load_from_ckpt(self, ckpt_path):
765
+ input_ch = self.state_dict()["input_blocks.0.0.weight"].shape[1]
766
+ assert input_ch >= 4 and input_ch // 4 * 4 == input_ch, "Input channels must be a multiple of 4 to load from SD ckpt"
767
+ output_ch = self.state_dict()["out.2.weight"].shape[0]
768
+ assert output_ch >= 4 and output_ch // 4 * 4 == output_ch, "Output channels must be a multiple of 4 to load from SD ckpt"
769
+ sd = th.load(ckpt_path)
770
+ sd_ = {}
771
+ for k,v in sd["state_dict"].items():
772
+ if k.startswith("model.diffusion_model"):
773
+ sd_[k.replace("model.diffusion_model.", "")] = v
774
+
775
+ if input_ch > 4:
776
+ # Scaling for input channels so that the gradients are not too large
777
+ scale = input_ch // 4
778
+ sd_["input_blocks.0.0.weight"] = sd_["input_blocks.0.0.weight"] / scale
779
+ sd_["input_blocks.0.0.weight"] = sd_["input_blocks.0.0.weight"].repeat(1, scale, 1, 1)
780
+
781
+ if output_ch > 4:
782
+ # No scaling for output channels
783
+ scale = output_ch // 4
784
+ sd_["out.2.weight"] = sd_["out.2.weight"].repeat(scale, 1, 1, 1)
785
+ sd_["out.2.bias"] = sd_["out.2.bias"].repeat(scale)
786
+
787
+ missing, unexpected = self.load_state_dict(sd_, strict=False)
788
+
789
+ if len(missing) > 0:
790
+ print(f"Load model weights - missing keys: {len(missing)}")
791
+ print(missing)
792
+ if len(unexpected) > 0:
793
+ print(f"Load model weights - unexpected keys: {len(unexpected)}")
794
+ print(unexpected)
795
+
796
+
797
+ def convert_to_fp16(self):
798
+ """
799
+ Convert the torso of the model to float16.
800
+ """
801
+ self.input_blocks.apply(convert_module_to_f16)
802
+ self.middle_block.apply(convert_module_to_f16)
803
+ self.output_blocks.apply(convert_module_to_f16)
804
+
805
+ def convert_to_fp32(self):
806
+ """
807
+ Convert the torso of the model to float32.
808
+ """
809
+ self.input_blocks.apply(convert_module_to_f32)
810
+ self.middle_block.apply(convert_module_to_f32)
811
+ self.output_blocks.apply(convert_module_to_f32)
812
+
813
+ def forward(self, x, t=None, context=None, context_ca=None, y=None,**kwargs):
814
+ """
815
+ Apply the model to an input batch.
816
+ :param x: an [N x C x ...] Tensor of inputs.
817
+ :param t: a 1-D batch of timesteps.
818
+ :param context: conditioning plugged in via crossattn
819
+ :param y: an [N] Tensor of labels, if class-conditional.
820
+ :return: an [N x C x ...] Tensor of outputs.
821
+ """
822
+ assert (y is not None) == (
823
+ self.num_classes is not None
824
+ ), "must specify y if and only if the model is class-conditional"
825
+ hs = []
826
+ t_emb = timestep_embedding(t, self.model_channels, repeat_only=False)
827
+ emb = self.time_embed(t_emb)
828
+
829
+ if self.num_classes is not None:
830
+ assert y.shape[0] == x.shape[0]
831
+ emb = emb + self.label_emb(y)
832
+
833
+ h = x.type(self.dtype)
834
+ if context is not None:
835
+ h = th.cat([h, context], dim=1)
836
+ for module in self.input_blocks:
837
+ h = module(h, emb, context_ca)
838
+ hs.append(h)
839
+ h = self.middle_block(h, emb, context_ca)
840
+ for module in self.output_blocks:
841
+ h = th.cat([h, hs.pop()], dim=1)
842
+ h = module(h, emb, context_ca)
843
+ h = h.type(x.dtype)
844
+ if self.predict_codebook_ids:
845
+ return self.id_predictor(h)
846
+ else:
847
+ return self.out(h)
848
+
849
+ def get_midblock_features(self, x, t=None, context=None, context_ca=None, y=None, **kwargs):
850
+ """
851
+ Apply the model to an input batch and return the features from the middle block.
852
+ :param x: an [N x C x ...] Tensor of inputs.
853
+ :param t: a 1-D batch of timesteps.
854
+ :param context: conditioning plugged in via crossattn
855
+ :param y: an [N] Tensor of labels, if class-conditional
856
+ """
857
+ assert (y is not None) == (
858
+ self.num_classes is not None
859
+ ), "must specify y if and only if the model is class-conditional"
860
+ hs = []
861
+ t_emb = timestep_embedding(t, self.model_channels, repeat_only=False)
862
+ emb = self.time_embed(t_emb)
863
+
864
+ if self.num_classes is not None:
865
+ assert y.shape[0] == x.shape[0]
866
+ emb = emb + self.label_emb(y)
867
+
868
+ h = x.type(self.dtype)
869
+ if context is not None:
870
+ h = th.cat([h, context], dim=1)
871
+ for module in self.input_blocks:
872
+ h = module(h, emb, context_ca)
873
+ hs.append(h)
874
+ h = self.middle_block(h, emb, context_ca)
875
+ return h
876
+
877
+ if __name__ == "__main__":
878
+ unet = UNetModel(
879
+ image_size=32,
880
+ in_channels=8,
881
+ model_channels=320,
882
+ out_channels=4,
883
+ num_res_blocks=2,
884
+ attention_resolutions=(4,2,1),
885
+ dropout=0.0,
886
+ channel_mult=(1, 2, 4, 4),
887
+ num_heads=8,
888
+ use_spatial_transformer=True,
889
+ context_dim=768,
890
+ transformer_depth=1,
891
+ legacy=False,
892
+ load_from_ckpt="/export/scratch/ra97ram/checkpoints/sd/v1-5-pruned.ckpt"
893
+ )
894
+ print(f"UNetModel has {sum(p.numel() for p in unet.parameters())} parameters")
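As a complement to the __main__ block above, here is a tiny forward-pass sketch (illustrative hyperparameters, not the DepthFM checkpoint's ldm_hparams) showing how dfm.py drives this UNet: the image latent is passed as context and channel-concatenated onto x, so in_channels is twice the latent channel count, while the text embedding enters through cross-attention as context_ca. It needs torch, einops and omegaconf (imported internally when context_dim is set).

import torch
from depthfm.unet import UNetModel   # the package __init__ handles the sys.path setup

unet = UNetModel(
    image_size=32,
    in_channels=8,              # 4 latent channels for x + 4 for the concatenated context
    model_channels=64,          # deliberately small; illustrative only
    out_channels=4,
    num_res_blocks=1,
    attention_resolutions=(4, 2),
    channel_mult=(1, 2, 4),
    num_heads=4,
    use_spatial_transformer=True,
    context_dim=768,            # width of the cross-attention (context_ca) embedding
    transformer_depth=1,
    legacy=False,
)

x = torch.randn(1, 4, 32, 32)           # noisy latent
context = torch.randn(1, 4, 32, 32)     # image latent, concatenated inside forward()
context_ca = torch.randn(1, 77, 768)    # text conditioning fed to cross-attention
t = torch.tensor([0.5])                 # flow-matching time in [0, 1]

out = unet(x, t=t, context=context, context_ca=context_ca)   # -> (1, 4, 32, 32)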
depthfm/unet/util.py ADDED
@@ -0,0 +1,175 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+
19
+ def extract_into_tensor(a, t, x_shape):
20
+ b, *_ = t.shape
21
+ out = a.gather(-1, t)
22
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
23
+
24
+
25
+ def checkpoint(func, inputs, params, flag):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ """
35
+ if flag:
36
+ args = tuple(inputs) + tuple(params)
37
+ return CheckpointFunction.apply(func, len(inputs), *args)
38
+ else:
39
+ return func(*inputs)
40
+
41
+
42
+ class CheckpointFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, run_function, length, *args):
45
+ ctx.run_function = run_function
46
+ ctx.input_tensors = list(args[:length])
47
+ ctx.input_params = list(args[length:])
48
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
49
+ "dtype": torch.get_autocast_gpu_dtype(),
50
+ "cache_enabled": torch.is_autocast_cache_enabled()}
51
+ with torch.no_grad():
52
+ output_tensors = ctx.run_function(*ctx.input_tensors)
53
+ return output_tensors
54
+
55
+ @staticmethod
56
+ def backward(ctx, *output_grads):
57
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
58
+ with torch.enable_grad(), \
59
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
60
+ # Fixes a bug where the first op in run_function modifies the
61
+ # Tensor storage in place, which is not allowed for detach()'d
62
+ # Tensors.
63
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
64
+ output_tensors = ctx.run_function(*shallow_copies)
65
+ input_grads = torch.autograd.grad(
66
+ output_tensors,
67
+ ctx.input_tensors + ctx.input_params,
68
+ output_grads,
69
+ allow_unused=True,
70
+ )
71
+ del ctx.input_tensors
72
+ del ctx.input_params
73
+ del output_tensors
74
+ return (None, None) + input_grads
75
+
76
+
77
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
78
+ """
79
+ Create sinusoidal timestep embeddings.
80
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
81
+ These may be fractional.
82
+ :param dim: the dimension of the output.
83
+ :param max_period: controls the minimum frequency of the embeddings.
84
+ :return: an [N x dim] Tensor of positional embeddings.
85
+ """
86
+ if not repeat_only:
87
+ half = dim // 2
88
+ freqs = torch.exp(
89
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
90
+ ).to(device=timesteps.device)
91
+ args = timesteps[:, None].float() * freqs[None]
92
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
93
+ if dim % 2:
94
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
95
+ else:
96
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
97
+ return embedding
98
+
99
+
100
+ def zero_module(module):
101
+ """
102
+ Zero out the parameters of a module and return it.
103
+ """
104
+ for p in module.parameters():
105
+ p.detach().zero_()
106
+ return module
107
+
108
+
109
+ def scale_module(module, scale):
110
+ """
111
+ Scale the parameters of a module and return it.
112
+ """
113
+ for p in module.parameters():
114
+ p.detach().mul_(scale)
115
+ return module
116
+
117
+
118
+ def mean_flat(tensor):
119
+ """
120
+ Take the mean over all non-batch dimensions.
121
+ """
122
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
123
+
124
+
125
+ def normalization(channels):
126
+ """
127
+ Make a standard normalization layer.
128
+ :param channels: number of input channels.
129
+ :return: an nn.Module for normalization.
130
+ """
131
+ return GroupNorm32(32, channels)
132
+
133
+
134
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
135
+ class SiLU(nn.Module):
136
+ def forward(self, x):
137
+ return x * torch.sigmoid(x)
138
+
139
+
140
+ class GroupNorm32(nn.GroupNorm):
141
+ def forward(self, x):
142
+ return super().forward(x.float()).type(x.dtype)
143
+
144
+
145
+ def conv_nd(dims, *args, **kwargs):
146
+ """
147
+ Create a 1D, 2D, or 3D convolution module.
148
+ """
149
+ if dims == 1:
150
+ return nn.Conv1d(*args, **kwargs)
151
+ elif dims == 2:
152
+ return nn.Conv2d(*args, **kwargs)
153
+ elif dims == 3:
154
+ return nn.Conv3d(*args, **kwargs)
155
+ raise ValueError(f"unsupported dimensions: {dims}")
156
+
157
+
158
+ def linear(*args, **kwargs):
159
+ """
160
+ Create a linear module.
161
+ """
162
+ return nn.Linear(*args, **kwargs)
163
+
164
+
165
+ def avg_pool_nd(dims, *args, **kwargs):
166
+ """
167
+ Create a 1D, 2D, or 3D average pooling module.
168
+ """
169
+ if dims == 1:
170
+ return nn.AvgPool1d(*args, **kwargs)
171
+ elif dims == 2:
172
+ return nn.AvgPool2d(*args, **kwargs)
173
+ elif dims == 3:
174
+ return nn.AvgPool3d(*args, **kwargs)
175
+ raise ValueError(f"unsupported dimensions: {dims}")
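Finally, a quick illustration (not from the commit) of the timestep_embedding helper above, which maps a batch of possibly fractional timesteps to fixed-size sinusoidal vectors; it assumes depthfm/unet has been added to sys.path so util is importable.

import torch
from util import timestep_embedding   # assumes depthfm/unet is on sys.path

t = torch.tensor([0.0, 0.5, 1.0])      # DepthFM uses continuous t in [0, 1]
emb = timestep_embedding(t, dim=320)   # -> torch.Size([3, 320])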