""" | |
----------------------------------------------------------------------------- | |
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
NVIDIA CORPORATION and its licensors retain all intellectual property | |
and proprietary rights in and to this software, related documentation | |
and any modifications thereto. Any use, reproduction, disclosure or | |
distribution of this software and related documentation without an express | |
license agreement from NVIDIA CORPORATION is strictly prohibited. | |
----------------------------------------------------------------------------- | |
""" | |
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from vae.modules.attention import CrossAttention, SelfAttention

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
        return self.net(x)

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context=None,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing
        # pre-norm layers without learnable affine parameters
        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        if dim_context is not None:
            # cross-attention over an external context of width dim_context
            self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
            self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        else:
            self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
        self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim)

    def forward(self, x, c=None, mask=None, mask_c=None):
        # recompute activations in the backward pass to save memory during training
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c=None, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], optional condition; when given, cross-attention is used
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states
        if c is not None:
            x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
        else:
            x = x + self.attn(self.norm_attn(x), mask=mask)
        x = x + self.ff(self.norm_ff(x))
        return x

# Special attention block for the final cross-attention query layer:
# 1. simple feed-forward (mult=1) with no LayerNorm in front of it
# 2. no residual connections
# 3. no LayerNorm on the context
class FlashQueryLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing
        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        self.ff = FeedForward(dim, mult=1)

    def forward(self, x, c=None, mask=None, mask_c=None):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c, mask=None, mask_c=None):
        # x: [B, N, C], query hidden states
        # c: [B, M, C'], condition (required; assumed already normed, no context LayerNorm here)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states
        x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
        x = self.ff(x)
        return x
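

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original module). It assumes the
    # vae.modules.attention package is importable and that CrossAttention /
    # SelfAttention behave as they are used above (i.e. return tensors of width
    # `dim`); the sizes below are arbitrary and only illustrate expected shapes.
    import torch

    B, N, M, C, C_ctx = 2, 16, 8, 64, 32
    x = torch.randn(B, N, C)
    c = torch.randn(B, M, C_ctx)

    # self-attention block (no context); eval() skips the checkpointing path
    self_block = AttentionBlock(dim=C, num_heads=4).eval()
    assert self_block(x).shape == (B, N, C)

    # cross-attention block: the context is normed inside the block
    cross_block = AttentionBlock(dim=C, num_heads=4, dim_context=C_ctx).eval()
    assert cross_block(x, c).shape == (B, N, C)

    # final query layer: cross-attention only, no residual, light feed-forward
    query_layer = FlashQueryLayer(dim=C, num_heads=4, dim_context=C_ctx).eval()
    assert query_layer(x, c).shape == (B, N, C)

    print("shape checks passed")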