bhuvanmdev committed
Commit: 6dc1111
1 Parent(s): 3a65b35

Update lgm/lgm.py

Files changed (1):
  1. lgm/lgm.py +815 -808
lgm/lgm.py CHANGED
@@ -1,808 +1,815 @@
import os
import warnings
from functools import partial
from typing import Literal, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from diff_gaussian_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from diffusers import ConfigMixin, ModelMixin
from torch import Tensor, nn


def look_at(campos):
    forward_vector = -campos / np.linalg.norm(campos, axis=-1)
    up_vector = np.array([0, 1, 0], dtype=np.float32)
    right_vector = np.cross(up_vector, forward_vector)
    up_vector = np.cross(forward_vector, right_vector)
    R = np.stack([right_vector, up_vector, forward_vector], axis=-1)
    return R


def orbit_camera(elevation, azimuth, radius=1):
    elevation = np.deg2rad(elevation)
    azimuth = np.deg2rad(azimuth)
    x = radius * np.cos(elevation) * np.sin(azimuth)
    y = -radius * np.sin(elevation)
    z = radius * np.cos(elevation) * np.cos(azimuth)
    campos = np.array([x, y, z])
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = look_at(campos)
    T[:3, 3] = campos
    return T


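# A quick orientation note for the two helpers above: orbit_camera returns a
# 4x4 camera-to-world pose, e.g. orbit_camera(0, 0, radius=1.5) puts the
# camera at T[:3, 3] == [0, 0, 1.5], with T[:3, :3] built by look_at so the
# camera is oriented toward the origin.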
def get_rays(pose, h, w, fovy, opengl=True):
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    cx = w * 0.5
    cy = h * 0.5

    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )

    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)

    rays_o = rays_o.view(h, w, 3)
    rays_d = F.normalize(rays_d, dim=-1).view(h, w, 3)

    return rays_o, rays_d


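# get_rays yields per-pixel origins and unit directions, each of shape
# (h, w, 3); LGM.prepare_default_rays below packs them into 6-channel Plücker
# embeddings via torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], -1).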
class GaussianRenderer:
    def __init__(self, fovy, output_size):
        self.output_size = output_size

        self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

        zfar = 2.5
        znear = 0.1
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (zfar + znear) / (zfar - znear)
        self.proj_matrix[3, 2] = -(zfar * znear) / (zfar - znear)
        self.proj_matrix[2, 3] = 1

    def render(
        self,
        gaussians,
        cam_view,
        cam_view_proj,
        cam_pos,
        bg_color=None,
        scale_modifier=1,
    ):
        device = gaussians.device
        B, V = cam_view.shape[:2]

        images = []
        alphas = []
        for b in range(B):

            means3D = gaussians[b, :, 0:3].contiguous().float()
            opacity = gaussians[b, :, 3:4].contiguous().float()
            scales = gaussians[b, :, 4:7].contiguous().float()
            rotations = gaussians[b, :, 7:11].contiguous().float()
            rgbs = gaussians[b, :, 11:].contiguous().float()

            for v in range(V):
                view_matrix = cam_view[b, v].float()
                view_proj_matrix = cam_view_proj[b, v].float()
                campos = cam_pos[b, v].float()

                raster_settings = GaussianRasterizationSettings(
                    image_height=self.output_size,
                    image_width=self.output_size,
                    tanfovx=self.tan_half_fov,
                    tanfovy=self.tan_half_fov,
                    bg=self.bg_color if bg_color is None else bg_color,
                    scale_modifier=scale_modifier,
                    viewmatrix=view_matrix,
                    projmatrix=view_proj_matrix,
                    sh_degree=0,
                    campos=campos,
                    prefiltered=False,
                    debug=False,
                )

                rasterizer = GaussianRasterizer(raster_settings=raster_settings)

                rendered_image, _, _, rendered_alpha = rasterizer(
                    means3D=means3D,
                    means2D=torch.zeros_like(
                        means3D, dtype=torch.float32, device=device
                    ),
                    shs=None,
                    colors_precomp=rgbs,
                    opacities=opacity,
                    scales=scales,
                    rotations=rotations,
                    cov3D_precomp=None,
                )

                rendered_image = rendered_image.clamp(0, 1)

                images.append(rendered_image)
                alphas.append(rendered_alpha)

        images = torch.stack(images, dim=0).view(
            B, V, 3, self.output_size, self.output_size
        )
        alphas = torch.stack(alphas, dim=0).view(
            B, V, 1, self.output_size, self.output_size
        )

        return {"image": images, "alpha": alphas}

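    # Channel layout shared by render and save_ply: each splat is a 14-vector,
    # xyz (0:3) + opacity (3:4) + scale (4:7) + quaternion (7:11) + RGB (11:14).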
    def save_ply(self, gaussians, path):
        assert gaussians.shape[0] == 1, "only support batch size 1"

        from plyfile import PlyData, PlyElement

        means3D = gaussians[0, :, 0:3].contiguous().float()
        opacity = gaussians[0, :, 3:4].contiguous().float()
        scales = gaussians[0, :, 4:7].contiguous().float()
        rotations = gaussians[0, :, 7:11].contiguous().float()
        shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float()

        mask = opacity.squeeze(-1) >= 0.005
        means3D = means3D[mask]
        opacity = opacity[mask]
        scales = scales[mask]
        rotations = rotations[mask]
        shs = shs[mask]

        opacity = opacity.clamp(1e-6, 1 - 1e-6)
        opacity = torch.log(opacity / (1 - opacity))
        scales = torch.log(scales + 1e-8)
        shs = (shs - 0.5) / 0.28209479177387814

        xyzs = means3D.detach().cpu().numpy()
        f_dc = (
            shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        )
        opacities = opacity.detach().cpu().numpy()
        scales = scales.detach().cpu().numpy()
        rotations = rotations.detach().cpu().numpy()

        h = ["x", "y", "z"]
        for i in range(f_dc.shape[1]):
            h.append("f_dc_{}".format(i))
        h.append("opacity")
        for i in range(scales.shape[1]):
            h.append("scale_{}".format(i))
        for i in range(rotations.shape[1]):
            h.append("rot_{}".format(i))

        dtype_full = [(attribute, "f4") for attribute in h]

        elements = np.empty(xyzs.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")

        PlyData([el]).write(path)


class LGM(ModelMixin, ConfigMixin):
    def __init__(self):
        super().__init__()

        self.input_size = 256
        self.splat_size = 128
        self.output_size = 512
        self.radius = 1.5
        self.fovy = 49.1

        self.unet = UNet(
            9,
            14,
            down_channels=(64, 128, 256, 512, 1024, 1024),
            down_attention=(False, False, False, True, True, True),
            mid_attention=True,
            up_channels=(1024, 1024, 512, 256, 128),
            up_attention=(True, True, True, False, False),
        )

        self.conv = nn.Conv2d(14, 14, kernel_size=1)
        self.gs = GaussianRenderer(self.fovy, self.output_size)

        self.pos_act = lambda x: x.clamp(-1, 1)
        self.scale_act = lambda x: 0.1 * F.softplus(x)
        self.opacity_act = lambda x: torch.sigmoid(x)
        self.rot_act = F.normalize
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5

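    # The UNet sees 9 channels per view (3 RGB + 6 Plücker ray channels) and,
    # after the 1x1 conv, emits the 14 Gaussian channels decoded by the
    # activations above.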
    def prepare_default_rays(self, device, elevation=0):
-        cam_poses = np.stack(
-            [
-                orbit_camera(elevation, 0, radius=self.radius),
-                orbit_camera(elevation, 90, radius=self.radius),
-                orbit_camera(elevation, 180, radius=self.radius),
-                orbit_camera(elevation, 270, radius=self.radius),
-            ],
-            axis=0,
-        )
+        # cam_poses = np.stack(
+        #     [
+        #         orbit_camera(elevation, 0, radius=self.radius),
+        #         orbit_camera(elevation, 90, radius=self.radius),
+        #         orbit_camera(elevation, 180, radius=self.radius),
+        #         orbit_camera(elevation, 270, radius=self.radius),
+        #     ],
+        #     axis=0,
+        # )
+        angles = np.linspace(0, 360, self.views, endpoint=False)
+        cam_poses = np.stack(
+            [
+                orbit_camera(elevation, angle, radius=self.radius) for angle in angles
+            ],
+            axis=0
+        )
        cam_poses = torch.from_numpy(cam_poses)

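        # NOTE: `self.views` is not set in __init__ above, so it must be
        # assigned before calling this method; model.views = 4 reproduces the
        # four-view orbit that the commented-out block hard-coded.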
        rays_embeddings = []
        for i in range(cam_poses.shape[0]):
            rays_o, rays_d = get_rays(
                cam_poses[i], self.input_size, self.input_size, self.fovy
            )
            rays_plucker = torch.cat(
                [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
            )
            rays_embeddings.append(rays_plucker)

        rays_embeddings = (
            torch.stack(rays_embeddings, dim=0)
            .permute(0, 3, 1, 2)
            .contiguous()
            .to(device)
        )

        return rays_embeddings

    def forward(self, images):
        B, V, C, H, W = images.shape
        images = images.view(B * V, C, H, W)

        x = self.unet(images)
        x = self.conv(x)

        x = x.reshape(B, 4, 14, self.splat_size, self.splat_size)

        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(x[..., 0:3])
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        q = torch.tensor([0, 0, 1, 0], dtype=pos.dtype, device=pos.device)
        R = torch.tensor(
            [
                [-1, 0, 0],
                [0, -1, 0],
                [0, 0, 1],
            ],
            dtype=pos.dtype,
            device=pos.device,
        )

        pos = torch.matmul(pos, R.T)

        def multiply_quat(q1, q2):
            w1, x1, y1, z1 = q1.unbind(-1)
            w2, x2, y2, z2 = q2.unbind(-1)
            w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
            x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
            y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
            z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
            return torch.stack([w, x, y, z], dim=-1)

        for i in range(B):
            rotation[i, :] = multiply_quat(q, rotation[i, :])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)

        return gaussians


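# Shape sketch for the full pipeline (assuming the defaults above and a CUDA
# build of diff_gaussian_rasterization): images of shape (B, 4, 9, 256, 256)
# go through forward() to gaussians of shape (B, 4 * 128 * 128, 14), which
# GaussianRenderer.render rasterizes into (B, V, 3, 512, 512) images.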
# =============================================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
# =============================================================================
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")

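# Setting the XFORMERS_DISABLED environment variable before import forces the
# plain-PyTorch attention path below; the MemEff* subclasses fall back to
# their parent implementations whenever xFormers is unavailable.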

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        B, N, _ = q.shape
        M = k.shape[1]

        q = self.scale * self.to_q(q).reshape(
            B, N, self.num_heads, self.dim // self.num_heads
        ).permute(0, 2, 1, 3)
        k = (
            self.to_k(k)
            .reshape(B, M, self.num_heads, self.dim // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.to_v(v)
            .reshape(B, M, self.num_heads, self.dim // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffCrossAttention(CrossAttention):
    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(q, k, v)

        B, N, _ = q.shape
        M = k.shape[1]

        q = self.scale * self.to_q(q).reshape(
            B, N, self.num_heads, self.dim // self.num_heads
        )
        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads)
        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# =============================================================================
# End of xFormers


class MVAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4,
    ):
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(
            num_groups=groups, num_channels=dim, eps=eps, affine=True
        )
        self.attn = MemEffAttention(
            dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop
        )

    def forward(self, x):
        BV, C, H, W = x.shape
        B = BV // self.num_frames

        res = x
        x = self.norm(x)

        x = (
            x.reshape(B, self.num_frames, C, H, W)
            .permute(0, 1, 3, 4, 2)
            .reshape(B, -1, C)
        )
        x = self.attn(x)
        x = (
            x.reshape(B, self.num_frames, H, W, C)
            .permute(0, 1, 4, 2, 3)
            .reshape(BV, C, H, W)
        )

        if self.residual:
            x = (x + res) * self.skip_scale
        return x


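# MVAttention above flattens the num_frames views of a sample into one
# (B, num_frames * H * W, C) token sequence, so self-attention mixes
# information across views before reshaping back to (B * num_frames, C, H, W).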
class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal["default", "up", "down"] = "default",
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(
            num_groups=groups, num_channels=in_channels, eps=eps, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.norm2 = nn.GroupNorm(
            num_groups=groups, num_channels=out_channels, eps=eps, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.act = F.silu

        self.resample = None
        if resample == "up":
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == "down":
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)

        self.shortcut = nn.Identity()
        if self.in_channels != self.out_channels:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, bias=True
            )

    def forward(self, x):
        res = x
        x = self.norm1(x)
        x = self.act(x)
        if self.resample:
            res = self.resample(res)
            x = self.resample(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)
        x = (x + self.shortcut(res)) * self.skip_scale
        return x


class DownBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(out_channels, attention_heads, skip_scale=skip_scale)
                )
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.downsample = None
        if downsample:
            self.downsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=2, padding=1
            )

    def forward(self, x):
        xs = []
        for attn, net in zip(self.attns, self.nets):
            x = net(x)
            if attn:
                x = attn(x)
            xs.append(x)
        if self.downsample:
            x = self.downsample(x)
            xs.append(x)
        return x, xs
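
# Each DownBlock returns its per-layer features `xs` alongside the output;
# UNet.forward collects them into `xss`, and UpBlock consumes them in reverse
# as skip connections (see res_x = xs[-1] in UpBlock.forward below).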


class MidBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
        for _ in range(num_layers):
            nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(in_channels, attention_heads, skip_scale=skip_scale)
                )
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

    def forward(self, x):
        x = self.nets[0](x)
        for attn, net in zip(self.attns, self.nets[1:]):
            if attn:
                x = attn(x)
            x = net(x)
        return x


class UpBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            cskip = prev_out_channels if (i == num_layers - 1) else out_channels

            nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(out_channels, attention_heads, skip_scale=skip_scale)
                )
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.upsample = None
        if upsample:
            self.upsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x, xs):
        for attn, net in zip(self.attns, self.nets):
            res_x = xs[-1]
            xs = xs[:-1]
            x = torch.cat([x, res_x], dim=1)
            x = net(x)
            if attn:
                x = attn(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
            x = self.upsample(x)
        return x


class UNet(nn.Module):
    def __init__(
        self,
        in_channels: int = 9,
        out_channels: int = 14,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
        up_attention: Tuple[bool, ...] = (True, True, True, False, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(
            in_channels, down_channels[0], kernel_size=3, stride=1, padding=1
        )

        down_blocks = []
        cout = down_channels[0]
        for i in range(len(down_channels)):
            cin = cout
            cout = down_channels[i]

            down_blocks.append(
                DownBlock(
                    cin,
                    cout,
                    num_layers=layers_per_block,
                    downsample=(i != len(down_channels) - 1),
                    attention=down_attention[i],
                    skip_scale=skip_scale,
                )
            )
        self.down_blocks = nn.ModuleList(down_blocks)

        self.mid_block = MidBlock(
            down_channels[-1], attention=mid_attention, skip_scale=skip_scale
        )

        up_blocks = []
        cout = up_channels[0]
        for i in range(len(up_channels)):
            cin = cout
            cout = up_channels[i]
            cskip = down_channels[max(-2 - i, -len(down_channels))]

            up_blocks.append(
                UpBlock(
                    cin,
                    cskip,
                    cout,
                    num_layers=layers_per_block + 1,
                    upsample=(i != len(up_channels) - 1),
                    attention=up_attention[i],
                    skip_scale=skip_scale,
                )
            )
        self.up_blocks = nn.ModuleList(up_blocks)
        self.norm_out = nn.GroupNorm(
            num_channels=up_channels[-1], num_groups=32, eps=1e-5
        )
        self.conv_out = nn.Conv2d(
            up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        x = self.conv_in(x)
        xss = [x]
        for block in self.down_blocks:
            x, xs = block(x)
            xss.extend(xs)
        x = self.mid_block(x)
        for block in self.up_blocks:
            xs = xss[-len(block.nets) :]
            xss = xss[: -len(block.nets)]
            x = block(x, xs)
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
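
# Resolution bookkeeping for the default config: five downsampling DownBlocks
# halve a 256x256 input to 8x8, the MidBlock keeps 8x8, and four upsampling
# UpBlocks return to 128x128, matching LGM.splat_size so each output pixel
# parameterizes one Gaussian.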