# PartPacker: vae/modules/transformer.py
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from vae.modules.attention import CrossAttention, SelfAttention
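

# simple feed-forward: a two-layer MLP (Linear -> GELU -> Linear) with hidden expansion factor `mult`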
class FeedForward(nn.Module):
def __init__(self, dim, mult=4):
super().__init__()
self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
return self.net(x)
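

# pre-norm transformer block: self-attention (or cross-attention when dim_context is given)
# followed by a feed-forward, each wrapped in a residual connection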
class AttentionBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
dim_context=None,
qknorm=False,
gradient_checkpointing=True,
qknorm_type="LayerNorm",
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.dim_context = dim_context
self.gradient_checkpointing = gradient_checkpointing
self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
if dim_context is not None:
self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
else:
self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
self.ff = FeedForward(dim)

    def forward(self, x, c=None, mask=None, mask_c=None):
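        # during training, optionally recompute activations in the backward pass (gradient checkpointing) to save memory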
if self.training and self.gradient_checkpointing:
return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
else:
return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c=None, mask=None, mask_c=None):
# x: [B, N, C], hidden states
# c: [B, M, C'], condition (assume normed and projected to C)
# mask: [B, N], mask for x
# mask_c: [B, M], mask for c
# return: [B, N, C], updated hidden states
if c is not None:
x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
else:
x = x + self.attn(self.norm_attn(x), mask=mask)
x = x + self.ff(self.norm_ff(x))
return x
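

# usage sketch (illustrative only, not from the original file; the dims and shapes below are assumptions):
#   block = AttentionBlock(dim=512, num_heads=8)                   # self-attention variant
#   xattn = AttentionBlock(dim=512, num_heads=8, dim_context=768)  # cross-attention variant
#   x = block(x)        # x: [B, N, 512] -> [B, N, 512]
#   x = xattn(x, c)     # c: [B, M, 768], output: [B, N, 512]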


# special attention block for the last cross-attn query layer
# 1. simple feed-forward (mult=1, no post ln)
# 2. no residual connection
# 3. no context ln
class FlashQueryLayer(nn.Module):
def __init__(
self,
dim,
num_heads,
dim_context,
qknorm=False,
gradient_checkpointing=True,
qknorm_type="LayerNorm",
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.dim_context = dim_context
self.gradient_checkpointing = gradient_checkpointing
self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
self.ff = FeedForward(dim, mult=1)

    def forward(self, x, c=None, mask=None, mask_c=None):
if self.training and self.gradient_checkpointing:
return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
else:
return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c, mask=None, mask_c=None):
# x: [B, N, C], hidden states
# c: [B, M, C'], condition (assume normed and projected to C)
# mask: [B, N], mask for x
# mask_c: [B, M], mask for c
# return: [B, N, C], updated hidden states
x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
x = self.ff(x)
return x
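

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original file). It only exercises the
# forward() interfaces defined above; the sizes (dim=64, 4 heads,
# dim_context=32, B=2, N=16, M=8, 4 query tokens) are illustrative
# assumptions, and it requires the sibling `vae.modules.attention` module to
# be importable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    B, N, M = 2, 16, 8
    self_block = AttentionBlock(dim=64, num_heads=4, gradient_checkpointing=False)
    cross_block = AttentionBlock(dim=64, num_heads=4, dim_context=32, gradient_checkpointing=False)
    query_layer = FlashQueryLayer(dim=64, num_heads=4, dim_context=64, gradient_checkpointing=False)

    x = torch.randn(B, N, 64)  # [B, N, C] hidden states
    c = torch.randn(B, M, 32)  # [B, M, C'] condition tokens

    x = self_block(x)  # self-attention + feed-forward with residuals, shape preserved
    x = cross_block(x, c)  # cross-attention against c, shape preserved
    q = torch.randn(B, 4, 64)  # query tokens for the final cross-attention read-out
    out = query_layer(q, x)  # no residual, simple feed-forward head

    print(x.shape, out.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 4, 64])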