# long_vita custom_code: pixel-shuffle vision projector (ResamplerProjector).
import torch
import torch.nn as nn


class ResamplerProjector(nn.Module):
    """Projects vision-encoder patch features into the language model's hidden space.

    A pixel shuffle first folds each 2x2 patch neighborhood into the channel
    dimension (4x fewer tokens, 4x wider features), then LayerNorm and a
    two-layer MLP map the result to the language model's hidden size. A small
    smoke test at the bottom of the file sketches the expected shapes.
    """

    def __init__(self, config, vision_model_config):
        super().__init__()
        # Patches per image side, e.g. 448 // 14 = 32.
        self.hw = vision_model_config.image_size // vision_model_config.patch_size

        # A 0.5 downsample ratio folds 2x2 = 4 patch features together, so the
        # projector input is 4x the vision hidden size.
        self.vision_downsample_ratio = 0.5
        proj_input_size = vision_model_config.hidden_size * int(1 / self.vision_downsample_ratio) ** 2

        self.pre_proj_layernorm = nn.LayerNorm(proj_input_size)

        self.mlp = nn.Sequential(
            nn.Linear(proj_input_size, vision_model_config.hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(vision_model_config.hidden_size, config.hidden_size, bias=False),
        )
        self.mlp.apply(init_weights)
        self.pre_proj_layernorm.apply(init_weights)

    def forward(self, x, *args, **kwargs):
        # (batch, tokens, hidden) -> (batch, hw, hw, hidden) spatial grid.
        x = x.reshape(x.shape[0], self.hw, self.hw, -1)
        # Halve each spatial side; channels grow 4x.
        x = pixel_shuffle(x, scale_factor=self.vision_downsample_ratio)
        # Flatten back to a token sequence.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        x = self.pre_proj_layernorm(x)
        x = self.mlp(x)
        return x

def pixel_shuffle(x, scale_factor=0.5):
    """Space-to-depth shuffle: with scale_factor=0.5, each 2x2 spatial block
    is folded into the channels, (N, W, H, C) -> (N, W/2, H/2, 4C)."""
    n, w, h, c = x.size()
    # (N, W, H, C) -> (N, W, H * scale, C / scale)
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    # (N, W, H * scale, C / scale) -> (N, H * scale, W, C / scale)
    x = x.permute(0, 2, 1, 3).contiguous()
    # (N, H * scale, W, C / scale) -> (N, H * scale, W * scale, C / scale^2)
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor * scale_factor)))
    # Back to (N, W * scale, H * scale, C / scale^2) axis order.
    x = x.permute(0, 2, 1, 3).contiguous()
    return x
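
# Illustrative shape trace: pixel_shuffle(torch.randn(2, 32, 32, 1024)) with the
# default scale_factor=0.5 returns a (2, 16, 16, 4096) tensor: 4x fewer spatial
# positions, 4x more channels, element count preserved.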

def pixel_shuffle_v2(x, scale_stride=2):
    """Strided space-to-depth variant: pads H and W up to a multiple of
    scale_stride, folds each stride x stride block into the channels, and
    returns a flattened token sequence (N, W*H / stride^2, stride^2 * C)
    computed on the padded grid."""
    n, w, h, c = x.size()
    assert w == h
    # Zero-pad the spatial dims so they divide evenly by scale_stride.
    pl = (scale_stride - (h % scale_stride)) % scale_stride
    x = torch.nn.functional.pad(x, (0, 0, 0, pl, 0, pl), "constant", 0)
    h += pl
    w += pl

    # Split each spatial dim into (blocks, stride), move the two stride axes
    # next to the channels, then flatten them into the channel dimension.
    x = x.reshape(n, w // scale_stride, scale_stride, h // scale_stride, scale_stride, c)
    x = x.permute(0, 1, 3, 2, 4, 5)
    x = x.flatten(3)
    x = x.reshape(n, -1, scale_stride * scale_stride * c)
    return x
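
# Illustrative: pixel_shuffle_v2(torch.randn(2, 33, 33, 64)) pads the 33x33 grid
# to 34x34, then returns a (2, 289, 256) sequence (17*17 tokens, 2*2*64 channels).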

def init_weights(m):
    # Standard transformer-style init: Normal(0, 0.02) for linear weights,
    # identity (ones/zeros) for LayerNorm.
    if isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        torch.nn.init.ones_(m.weight)
        torch.nn.init.zeros_(m.bias)
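

# Minimal smoke test: a sketch, not part of the released module. The
# SimpleNamespace configs are hypothetical stand-ins exposing only the
# attributes ResamplerProjector actually reads.
if __name__ == "__main__":
    from types import SimpleNamespace

    vision_cfg = SimpleNamespace(image_size=448, patch_size=14, hidden_size=1024)
    text_cfg = SimpleNamespace(hidden_size=4096)

    projector = ResamplerProjector(text_cfg, vision_cfg)
    patches = torch.randn(2, (448 // 14) ** 2, 1024)  # (batch, 32*32 tokens, hidden)
    out = projector(patches)
    assert out.shape == (2, 256, 4096), out.shape

    tokens = pixel_shuffle_v2(torch.randn(2, 33, 33, 64), scale_stride=2)
    assert tokens.shape == (2, 17 * 17, 2 * 2 * 64), tokens.shape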