import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device


class Dino2AttentionOutput(torch.nn.Module):
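    """Output projection applied after self-attention; `layer_norm_eps` is accepted for signature parity but unused."""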
    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.dense(x)


class Dino2AttentionBlock(torch.nn.Module):
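    """Self-attention (BertAttention) followed by the dense output projection."""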
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        return self.output(self.attention(x, mask, optimized_attention))


class LayerScale(torch.nn.Module):
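    """Learnable per-channel scale (lambda1) applied to a residual branch."""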
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)


class SwiGLUFFN(torch.nn.Module):
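    """SwiGLU feed-forward: hidden width is 2/3 of 4 * dim, rounded up to a multiple of 8."""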
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        in_features = out_features = dim
        hidden_features = int(dim * 4)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.weights_in(x)
        x1, x2 = x.chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        return self.weights_out(x)


class Dino2Block(torch.nn.Module):
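    """Pre-norm transformer block with LayerScale applied to both residual branches."""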
    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention):
        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x


class Dino2Encoder(torch.nn.Module):
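    """Stack of Dino2Block layers; optionally also returns one intermediate layer's output."""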
    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])

    def forward(self, x, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        if intermediate_output is not None:
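            # negative indices count back from the last block (e.g. -1 selects the final layer)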
            if intermediate_output < 0:
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layer):
            x = l(x, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class Dino2PatchEmbeddings(torch.nn.Module):
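    """Patchify with a strided Conv2d and flatten to a (batch, num_patches, dim) token sequence."""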
    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device
        )

    def forward(self, pixel_values):
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class Dino2Embeddings(torch.nn.Module):
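    """Patch embeddings with a prepended CLS token and learned position embeddings."""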
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        patch_size = 14
        image_size = 518

        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))

    def forward(self, pixel_values):
        x = self.patch_embeddings(pixel_values)
        # TODO: mask_token?
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
        return x


class Dinov2Model(torch.nn.Module):
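    """DINOv2 vision transformer backbone.

    forward() returns the final hidden states, the selected intermediate hidden
    states (or None), the pooled CLS token, and a trailing None.
    """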
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]
        return x, i, pooled_output, None
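

if __name__ == "__main__":
    # Minimal shape smoke test (a sketch, not part of the ComfyUI loading path).
    # Assumes comfy.ops.disable_weight_init provides the Linear/Conv2d/LayerNorm
    # wrappers expected by `operations`; the config values below are illustrative
    # placeholders and the weights stay uninitialized, so only the output shapes
    # are meaningful. Run as a module from the repository root so the `comfy`
    # package imports resolve.
    import comfy.ops

    config = {
        "num_hidden_layers": 2,
        "hidden_size": 384,
        "num_attention_heads": 6,
        "layer_norm_eps": 1e-6,
    }
    model = Dinov2Model(config, dtype=torch.float32, device="cpu", operations=comfy.ops.disable_weight_init)
    pixels = torch.zeros(1, 3, 518, 518)  # 518 / 14 = 37 patches per side -> 37 * 37 + 1 CLS = 1370 tokens
    tokens, intermediate, pooled, _ = model(pixels, intermediate_output=-1)
    print(tokens.shape, pooled.shape)  # expected: (1, 1370, 384) and (1, 384)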