"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from vae.modules.attention import CrossAttention, SelfAttention


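# position-wise feed-forward MLP used inside each transformer block (Linear -> GELU -> Linear)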
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
        return self.net(x)


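# pre-LN transformer block: self-attention (or cross-attention when dim_context is given)
# followed by a feed-forward, each wrapped in a residual connection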
class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context=None,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        if dim_context is not None:
            self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
            self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        else:
            self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)

        self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim)

    def forward(self, x, c=None, mask=None, mask_c=None):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c=None, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition; normalized by norm_context here and projected to C inside CrossAttention
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states

        if c is not None:
            x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
        else:
            x = x + self.attn(self.norm_attn(x), mask=mask)

        x = x + self.ff(self.norm_ff(x))

        return x


# Special attention block for the final cross-attention query layer.
# Differences from AttentionBlock:
# 1. simpler feed-forward (mult=1) with no LayerNorm before it
# 2. no residual connections
# 3. no LayerNorm on the context
class FlashQueryLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        self.ff = FeedForward(dim, mult=1)

    def forward(self, x, c, mask=None, mask_c=None):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition (assume normed and projected to C)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states

        x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
        x = self.ff(x)

        return x
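

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # It assumes the SelfAttention / CrossAttention interfaces used above are
    # available from vae.modules.attention; tensor shapes follow the comments
    # in _forward. Dimensions below are arbitrary example values.
    import torch

    B, N, M, C, C_ctx = 2, 16, 8, 64, 32
    x = torch.randn(B, N, C)      # hidden states [B, N, C]
    c = torch.randn(B, M, C_ctx)  # condition     [B, M, C']

    self_block = AttentionBlock(dim=C, num_heads=4, gradient_checkpointing=False)
    cross_block = AttentionBlock(dim=C, num_heads=4, dim_context=C_ctx, gradient_checkpointing=False)
    query_layer = FlashQueryLayer(dim=C, num_heads=4, dim_context=C_ctx, gradient_checkpointing=False)

    print(self_block(x).shape)      # -> [B, N, C]
    print(cross_block(x, c).shape)  # -> [B, N, C]
    print(query_layer(x, c).shape)  # -> [B, N, C]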