import torch
from torch import nn

from transformers import SamModel, SamProcessor, SamVisionConfig
from PIL import Image


# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.

    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,

    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).

    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

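# Usage sketch: SamLayerNorm normalizes over the channel dimension in either layout.
#   norm = SamLayerNorm(256, data_format="channels_first")
#   y = norm(torch.randn(2, 256, 64, 64))   # shape preserved, normalized over dim 1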


class ShortSamVisionNeck(nn.Module):
    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config

        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # (B, H, W, C) -> (B, C, H, W) for the 1x1 convolution.
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)
        # Back to channels_last: (B, C, H, W) -> (B, H, W, C).
        hidden_states = hidden_states.permute(0, 2, 3, 1)
        return hidden_states

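# Shape sketch for ShortSamVisionNeck (assuming SAM ViT-L: hidden_size=1024, output_channels=256):
#   in  : (B, 64, 64, 1024) channels_last patch grid from the ViT
#   out : (B, 64, 64, 256)  after the 1x1 conv and channels_first LayerNorm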

class SAMVisionTower(nn.Module):
    def __init__(self, vision_tower, args):
        super().__init__()

        self.args = args
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.input_image_size = args.input_image_size
        self.pixel_shuffle = getattr(args, 'add_pixel_shuffle', False)

        self.freeze = args.freeze_vision

        self.load_model()

    def load_model(self):
        if self.is_loaded:
            return

        self.image_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
        sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
        # Swap the default SAM neck (two convs + two norms) for the single-conv variant above.
        sam_model.neck = ShortSamVisionNeck(sam_model.config)
        # Alias __call__ as `preprocess` so callers can use the CLIP-style processor interface.
        self.image_processor.preprocess = self.image_processor.__call__
        self.image_processor.image_mean = [0.485, 0.456, 0.406]
        self.vision_tower = sam_model

        if self.freeze:
            self.vision_tower.requires_grad_(False)

        self.is_loaded = True

        
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(image.to(device=self.device).unsqueeze(0)).last_hidden_state
                # (1, H, W, C) patch grid -> (1, H*W, C) token sequence, matching the batched path.
                image_features.append(image_feature.flatten(start_dim=1, end_dim=2))
        else:
            # (B, H, W, C) -> (B, H*W, C)
            image_features = self.vision_tower(images.to(device=self.device)).last_hidden_state.flatten(
                start_dim=1, end_dim=2
            )

        if self.pixel_shuffle:
            # Merge each 2x2 block of neighboring tokens into one token with 4x the channels.
            b, n, c = image_features.shape
            h = w = int(n ** 0.5)
            image_features = image_features.transpose(1, 2).reshape(b, c, h, w)
            image_features = nn.functional.pixel_unshuffle(image_features, 2)
            # Back to a token sequence: (B, 4*C, H/2, W/2) -> (B, H*W/4, 4*C), consistent with hidden_size.
            image_features = image_features.flatten(2).transpose(1, 2)

        return image_features
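
    # Output shape sketch (assuming 1024x1024 inputs, SAM ViT-L, ShortSamVisionNeck):
    #   pixel_shuffle=False: (B, 4096, 256)
    #   pixel_shuffle=True : (B, 1024, 1024) after the 2x2 pixel unshuffle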

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        return SamVisionConfig()

    @property
    def hidden_size(self):
        # 256 channels from the neck; the 2x2 pixel unshuffle multiplies this by 4.
        if self.pixel_shuffle:
            return 256 * 4
        return 256

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


# Smoke test: run one image through the frozen SAM ViT-L vision encoder.
if __name__ == "__main__":
    sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
    for name, param in sam_model.named_parameters():
        param.requires_grad = False

    raw_image = Image.open('/lustre/fsw/portfolios/llmservice/users/fuxiaol/image/me.jpg').convert("RGB")
    inputs = sam_processor(raw_image, return_tensors="pt")
    out = sam_model(inputs['pixel_values'])

    print(out[0].size())
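
    # Sketch: wiring the SAMVisionTower wrapper defined above (assumes only the minimal
    # attributes it reads; SimpleNamespace stands in for the real training args object).
    #   from types import SimpleNamespace
    #   tower_args = SimpleNamespace(input_image_size=1024, add_pixel_shuffle=False, freeze_vision=True)
    #   tower = SAMVisionTower("facebook/sam-vit-large", tower_args)
    #   feats = tower(inputs['pixel_values'])   # expected (1, 4096, 256) for 1024x1024 inputs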