File size: 3,258 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn

from .processor import Blip2ImageTrainProcessor
from .eva_vit import create_eva_vit_g


class EvaClipVisionTower(nn.Module):
    """Vision tower wrapping an EVA ViT-g encoder and its image processor.

    Construction only records configuration; the encoder and processor are
    built by :meth:`load_model`, which runs immediately unless ``delay_load``.

    Args:
        vision_tower: Identifier of the vision tower; stored as
            ``vision_tower_name`` and used in log messages.
        args: Config namespace. ``load_model`` reads ``img_size``,
            ``drop_path_rate``, ``vit_precision`` and ``vit_model_path``, and
            optionally ``dynamic_resolution`` and ``freeze_vision_encoder``.
            ``mm_vision_select_feature`` (default ``'patch'``) selects which
            tokens :meth:`forward` returns.
        delay_load: If True, do not build the model in ``__init__``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.args = args

        if not delay_load:
            self.load_model()

    def load_model(self, device_map=None):
        """Build ``self.image_processor`` and the encoder ``self.vision_tower``.

        ``device_map`` is accepted for interface compatibility and is unused.
        If the model is already loaded, a message is printed and nothing is
        rebuilt.

        Raises:
            ValueError: if the encoder is trainable (``freeze_vision_encoder``
                is present and falsy) while ``vit_precision`` is not ``'fp32'``.
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        dynamic_resolution = getattr(self.args, 'dynamic_resolution', None)

        # Gradient checkpointing is enabled only when the encoder is trained;
        # a missing ``freeze_vision_encoder`` flag is treated as "frozen".
        use_checkpoint = not getattr(self.args, 'freeze_vision_encoder', True)
        if use_checkpoint and self.args.vit_precision != 'fp32':
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError('if the vision encoder is training, the type needs to be fp32')

        self.image_processor = Blip2ImageTrainProcessor(
            image_size=self.args.img_size,
            dynamic_resolution=dynamic_resolution,
        )
        self.vision_tower = create_eva_vit_g(
            img_size=self.args.img_size,
            drop_path_rate=self.args.drop_path_rate,
            precision=self.args.vit_precision,
            vit_model_path=self.args.vit_model_path,
            use_checkpoint=use_checkpoint,
        )

        self.is_loaded = True

    def feature_select(self, image_features):
        """Select output tokens according to ``self.select_feature``.

        ``'patch'`` drops the first token along dim 1 (presumably the CLS
        token -- confirm against the encoder); ``'cls_patch'`` keeps all tokens.

        Raises:
            ValueError: for any other ``select_feature`` value.
        """
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    # Intentionally not decorated with ``@torch.no_grad()`` so the encoder can
    # be fine-tuned when ``freeze_vision_encoder`` is falsy.
    def forward(self, images):
        """Encode images and return the selected token features.

        Args:
            images: Either a batched tensor, or a list of per-image tensors
                that are each encoded as a batch of one (``unsqueeze(0)``).

        Returns:
            A tensor (batched input) or a list of tensors (list input). Inputs
            are cast to the encoder's dtype before encoding, and outputs are
            cast back to the dtype of the corresponding input.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                # Cast to the encoder dtype, matching the batched branch below;
                # otherwise e.g. fp32 images hit an fp16 encoder unconverted.
                image_forward_out = self.vision_tower(image.to(dtype=self.dtype).unsqueeze(0))
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            return image_features

        image_features = self.vision_tower(images.to(dtype=self.dtype))
        return self.feature_select(image_features).to(images.dtype)

    @property
    def dummy_feature(self):
        """Zero float tensor of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size, dtype=torch.float)

    @property
    def hidden_size(self):
        """Feature width reported by the encoder (``vision_tower.hidden_size``)."""
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        """Total number of patches: ``num_patches_per_side ** 2``."""
        return self.num_patches_per_side ** 2

    @property
    def num_patches_per_side(self):
        """Patches along one side: ``image_size // patch_size`` of the encoder."""
        return self.vision_tower.image_size // self.vision_tower.patch_size

    @property
    def dtype(self):
        """Dtype of the encoder's ``pos_embed``, used as the input cast target."""
        return self.vision_tower.pos_embed.dtype