cloud-adapter-models / cloud_adapter / cloud_adapter_dinov2.py
import torch
import torch.nn.functional as F
from mmseg.models.builder import BACKBONES, MODELS
from torch import nn

from .cloud_adapter import CloudAdapter
from .dino_v2 import DinoVisionTransformer
from .utils import set_requires_grad, set_train

@BACKBONES.register_module()
class CloudAdapterDinoVisionTransformer(DinoVisionTransformer):
    def __init__(
        self,
        cloud_adapter_config=None,
        has_cat=False,
        # e.g. [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        adapter_index=[0, 6, 12, 18],  # indices of the Transformer blocks that interact with the adapter
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cloud_adapter: CloudAdapter = MODELS.build(cloud_adapter_config)
        self.has_cat = has_cat
        self.adapter_index = adapter_index
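
    # Shape conventions used by forward_features below:
    #   x      : (B, 3, h, w) input image
    #   cache  : CNN features from the cloud adapter, either a tuple/list of multi-scale maps or a single map
    #   tokens : (B, 1 + H*W, C) after prepare_tokens_with_masks, with H = h // patch_size, W = w // patch_size
    #   outs   : one (B, C, H, W) feature map per index in self.out_indices (CLS token removed)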
    def forward_features(self, x, masks=None):
        B, _, h, w = x.shape
        # Run the adapter's CNN once to obtain multi-scale features (or a single feature map).
        cache = self.cloud_adapter.cnn(x)
        H, W = h // self.patch_size, w // self.patch_size
        x = self.prepare_tokens_with_masks(x, masks)
        outs = []
        cur_idx = 0  # index of the current interaction module
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx in self.adapter_index:
                x = self.cloud_adapter.forward(
                    x,
                    cur_idx,
                    batch_first=True,
                    has_cls_token=True,
                    cache=cache,
                )
                cur_idx += 1
            if idx in self.out_indices:
                # Drop the CLS token and reshape the patch tokens to (B, C, H, W).
                outs.append(
                    x[:, 1:, :].permute(0, 2, 1).reshape(
                        B, -1, H, W).contiguous()
                )
        return outs, cache

    def process_cache(self, ret, cache):
        # Resize the cached CNN feature to match the spatial size of the given output feature.
        cache = F.interpolate(
            cache, size=(ret.shape[-2], ret.shape[-1]), mode="bilinear", align_corners=False)
        return cache

    def forward(self, *args, **kwargs):
        ret, cache = self.forward_features(*args, **kwargs)
        if isinstance(ret[0], torch.Tensor):
            # Rescale the single-stride ViT outputs into a feature pyramid (x4, x2, x1, x0.5).
            ret[0] = F.interpolate(
                ret[0], scale_factor=4, mode="bilinear", align_corners=False
            )
            ret[1] = F.interpolate(
                ret[1], scale_factor=2, mode="bilinear", align_corners=False
            )
            ret[3] = F.interpolate(
                ret[3], scale_factor=0.5, mode="bilinear", align_corners=False
            )
            if self.has_cat:
                if isinstance(cache, (tuple, list)):
                    # Multi-scale CNN features: concatenate scale by scale.
                    ret[0] = torch.cat((ret[0], cache[0]), dim=1)
                    ret[1] = torch.cat((ret[1], cache[1]), dim=1)
                    ret[2] = torch.cat((ret[2], cache[2]), dim=1)
                    ret[3] = torch.cat((ret[3], cache[3]), dim=1)
                else:
                    # Single CNN feature map: resize it to each scale before concatenating,
                    # e.g. ret[0]: (bs, 1024, 128, 128), cache: (bs, 256, 128, 128).
                    ret[0] = torch.cat((ret[0], self.process_cache(ret[0], cache)), dim=1)
                    ret[1] = torch.cat((ret[1], self.process_cache(ret[1], cache)), dim=1)
                    ret[2] = torch.cat((ret[2], self.process_cache(ret[2], cache)), dim=1)
                    ret[3] = torch.cat((ret[3], self.process_cache(ret[3], cache)), dim=1)
        else:
            # ret[0] is itself a list/tuple of feature maps; apply the same rescaling and concatenation.
            ret[0][0] = F.interpolate(
                ret[0][0], scale_factor=4, mode="bilinear", align_corners=False
            )
            ret[0][1] = F.interpolate(
                ret[0][1], scale_factor=2, mode="bilinear", align_corners=False
            )
            ret[0][3] = F.interpolate(
                ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False
            )
            if self.has_cat:
                if isinstance(cache, (tuple, list)):
                    ret[0][0] = torch.cat((ret[0][0], cache[0]), dim=1)
                    ret[0][1] = torch.cat((ret[0][1], cache[1]), dim=1)
                    ret[0][2] = torch.cat((ret[0][2], cache[2]), dim=1)
                    ret[0][3] = torch.cat((ret[0][3], cache[3]), dim=1)
                else:
                    ret[0][0] = torch.cat((ret[0][0], self.process_cache(ret[0][0], cache)), dim=1)
                    ret[0][1] = torch.cat((ret[0][1], self.process_cache(ret[0][1], cache)), dim=1)
                    ret[0][2] = torch.cat((ret[0][2], self.process_cache(ret[0][2], cache)), dim=1)
                    ret[0][3] = torch.cat((ret[0][3], self.process_cache(ret[0][3], cache)), dim=1)
        return ret

    def train(self, mode: bool = True):
        if not mode:
            return super().train(mode)
        # In training mode, freeze the DINOv2 backbone and train only the cloud adapter.
        set_requires_grad(self, ["cloud_adapter"])
        set_train(self, ["cloud_adapter"])

    def state_dict(self, destination, prefix, keep_vars):
        state = super().state_dict(destination, prefix, keep_vars)
        # Keep only the cloud adapter's parameters in the checkpoint; the frozen backbone is dropped.
        keys = [k for k in state.keys() if "cloud_adapter" not in k]
        for key in keys:
            state.pop(key)
            if key in destination:
                destination.pop(key)
        return state
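

# Usage sketch: a minimal, hypothetical example of building this backbone through the
# mmseg registry. The DinoVisionTransformer kwargs and the fields of `cloud_adapter_config`
# are assumptions (defaults are used here); take the real values from the project's configs.
#
# backbone = BACKBONES.build(
#     dict(
#         type="CloudAdapterDinoVisionTransformer",
#         cloud_adapter_config=dict(type="CloudAdapter"),  # assumed: CloudAdapter defaults
#         has_cat=False,
#         adapter_index=[0, 6, 12, 18],
#     )
# )
# feats = backbone(torch.randn(1, 3, 512, 512))
# # `feats` is a list of rescaled feature maps, one per entry in out_indices (assumed to be four).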