English
File size: 5,024 Bytes
0467378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from mmseg.models.builder import BACKBONES, MODELS
from torch import nn as nn
from .cloud_adapter import CloudAdapter
from .dino_v2 import DinoVisionTransformer
from .utils import set_requires_grad, set_train
import torch
import torch.nn.functional as F


@BACKBONES.register_module()
class CloudAdapterDinoVisionTransformer(DinoVisionTransformer):
    def __init__(
        self,
        cloud_adapter_config=None,
        has_cat=False,
        # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, ],
        adapter_index=[0, 6, 12, 18],  # Transformer Block 的索引
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cloud_adapter: CloudAdapter = MODELS.build(cloud_adapter_config)
        self.has_cat = has_cat
        self.adapter_index = adapter_index

    def forward_features(self, x, masks=None):
        B, _, h, w = x.shape
        cache = self.cloud_adapter.cnn(x)  # 得到多尺度特征或者单个特征
        H, W = h // self.patch_size, w // self.patch_size
        x = self.prepare_tokens_with_masks(x, masks)
        outs = []
        cur_idx = 0  # 交互模块的索引
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx in self.adapter_index:
                x = self.cloud_adapter.forward(
                    x,
                    cur_idx,
                    batch_first=True,
                    has_cls_token=True,
                    cache=cache,
                )
                cur_idx += 1
            if idx in self.out_indices:
                outs.append(
                    x[:, 1:, :].permute(0, 2, 1).reshape(
                        B, -1, H, W).contiguous()
                )
        return outs, cache

    def process_cache(self,ret,cache):
        cache = F.interpolate(
            cache,size=(ret.shape[-2],ret.shape[-1]),mode="bilinear",align_corners=False)
        return cache

    def forward(self, *args, **kwargs):
        ret, cache = self.forward_features(*args, **kwargs)
        if isinstance(ret[0], torch.Tensor):
            ret[0] = F.interpolate(
                ret[0], scale_factor=4, mode="bilinear", align_corners=False
            )
            ret[1] = F.interpolate(
                ret[1], scale_factor=2, mode="bilinear", align_corners=False
            )
            ret[3] = F.interpolate(
                ret[3], scale_factor=0.5, mode="bilinear", align_corners=False
            )
            if self.has_cat:
                if isinstance(cache,tuple) or isinstance(cache,list):
                    ret[0] = torch.cat((ret[0], cache[0]), dim=1)
                    ret[1] = torch.cat((ret[1], cache[1]), dim=1)
                    ret[2] = torch.cat((ret[2], cache[2]), dim=1)
                    ret[3] = torch.cat((ret[3], cache[3]), dim=1)
                else:
                    ret[0] = torch.cat((ret[0], self.process_cache(ret[0],cache)), dim=1)
                    ret[1] = torch.cat((ret[1], self.process_cache(ret[1],cache)), dim=1)
                    ret[2] = torch.cat((ret[2], self.process_cache(ret[2],cache)), dim=1)
                    ret[3] = torch.cat((ret[3], self.process_cache(ret[3],cache)), dim=1)
            # ret[0] = torch.cat(ret[0], cache[0], dim=1) # bs 1024 128 128, bs 256 128 128
        else:
            ret[0][0] = F.interpolate(
                ret[0][0], scale_factor=4, mode="bilinear", align_corners=False
            )
            ret[0][1] = F.interpolate(
                ret[0][1], scale_factor=2, mode="bilinear", align_corners=False
            )
            ret[0][3] = F.interpolate(
                ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False
            )
            if self.has_cat:
                if isinstance(cache,tuple) or isinstance(cache,list):
                    ret[0][0] = torch.cat((ret[0][0], cache[0]), dim=1)
                    ret[0][1] = torch.cat((ret[0][1], cache[1]), dim=1)
                    ret[0][2] = torch.cat((ret[0][2], cache[2]), dim=1)
                    ret[0][3] = torch.cat((ret[0][3], cache[3]), dim=1)
                else:
                    ret[0][0] = torch.cat((ret[0][0], self.process_cache(ret[0][0],cache)), dim=1)
                    ret[0][1] = torch.cat((ret[0][1], self.process_cache(ret[0][1],cache)), dim=1)
                    ret[0][2] = torch.cat((ret[0][2], self.process_cache(ret[0][2],cache)), dim=1)
                    ret[0][3] = torch.cat((ret[0][3], self.process_cache(ret[0][3],cache)), dim=1)
        return ret

    def train(self, mode: bool = True):
        if not mode:
            return super().train(mode)
        set_requires_grad(self, ["cloud_adapter"])
        set_train(self, ["cloud_adapter"])

    def state_dict(self, destination, prefix, keep_vars):
        state = super().state_dict(destination, prefix, keep_vars)
        keys = [k for k in state.keys() if "cloud_adapter" not in k]
        for key in keys:
            state.pop(key)
            if key in destination:
                destination.pop(key)
        return state