"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from vae.modules.attention import CrossAttention, SelfAttention

class FeedForward(nn.Module):
    # Position-wise MLP: dim -> dim * mult -> dim, with a GELU in between.
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
        return self.net(x)
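
# Usage sketch (illustrative only; dims and shapes below are arbitrary examples):
#   ff = FeedForward(dim=64)
#   y = ff(torch.randn(2, 16, 64))  # output keeps the input shape [2, 16, 64]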

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context=None,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        # Cross-attention when a context dim is given, plain self-attention otherwise.
        if dim_context is not None:
            self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
            self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        else:
            self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
        self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim)

    def forward(self, x, c=None, mask=None, mask_c=None):
        # Recompute activations during the backward pass to save memory when training.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c=None, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition (assume normed and projected to C)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states
        if c is not None:
            x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
        else:
            x = x + self.attn(self.norm_attn(x), mask=mask)
        x = x + self.ff(self.norm_ff(x))
        return x
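
# Usage sketch (illustrative only; dims are arbitrary, mask shapes follow the comments in _forward):
#   self_block  = AttentionBlock(dim=512, num_heads=8)                   # self-attention path
#   cross_block = AttentionBlock(dim=512, num_heads=8, dim_context=768)  # cross-attention path
#   x = self_block(x, mask=mask)                        # x: [B, N, 512], mask: [B, N]
#   x = cross_block(x, c, mask=mask, mask_c=mask_c)     # c: [B, M, 768], mask_c: [B, M]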

# Special attention block for the last cross-attention query layer:
# 1. simple feed-forward (mult=1, no post-FF LayerNorm)
# 2. no residual connections
# 3. no LayerNorm on the context
class FlashQueryLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        self.ff = FeedForward(dim, mult=1)

    def forward(self, x, c=None, mask=None, mask_c=None):
        # Same gradient checkpointing pattern as AttentionBlock.forward.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition (assume normed and projected to C)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states
        # Note: no residual connection and no context LayerNorm, by design (see class-level comment).
        x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
        x = self.ff(x)
        return x
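

if __name__ == "__main__":
    # Minimal smoke test: a sketch for illustration, not part of the library.
    # Assumes the `vae` package is on PYTHONPATH so the attention imports above
    # resolve; all sizes below are arbitrary examples.
    import torch

    B, N, M, C, C_CTX = 2, 16, 8, 64, 32
    x = torch.randn(B, N, C)
    c = torch.randn(B, M, C_CTX)

    # Checkpointing is disabled here only to keep the demo simple.
    self_block = AttentionBlock(dim=C, num_heads=4, gradient_checkpointing=False)
    print(self_block(x).shape)  # expected: torch.Size([2, 16, 64])

    cross_block = AttentionBlock(dim=C, num_heads=4, dim_context=C_CTX, gradient_checkpointing=False)
    print(cross_block(x, c).shape)  # expected: torch.Size([2, 16, 64])

    query_layer = FlashQueryLayer(dim=C, num_heads=4, dim_context=C_CTX, gradient_checkpointing=False)
    print(query_layer(x, c).shape)  # expected: torch.Size([2, 16, 64])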