|
from mmseg.models.builder import BACKBONES, MODELS |
|
from torch import nn as nn |
|
from .cloud_adapter import CloudAdapter |
|
from .dino_v2 import DinoVisionTransformer |
|
from .utils import set_requires_grad, set_train |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
@BACKBONES.register_module()
class CloudAdapterDinoVisionTransformer(DinoVisionTransformer):
    """DINOv2 backbone augmented with a Cloud-Adapter side network.

    A frozen DINOv2 ViT trunk is interleaved with a lightweight adapter:
    after each block listed in ``adapter_index`` the token sequence is
    passed through ``self.cloud_adapter`` together with a CNN feature
    cache computed once from the input image. Multi-scale outputs are
    rescaled to a standard feature pyramid and, optionally, concatenated
    channel-wise with the adapter's CNN cache.

    Args:
        cloud_adapter_config (dict | None): mmseg registry config used to
            build the :class:`CloudAdapter` module.
        has_cat (bool): if True, concatenate cache features onto each
            pyramid level in :meth:`forward`.
        adapter_index (Sequence[int]): indices of ViT blocks after which
            the adapter is applied. A tuple default is used to avoid the
            shared-mutable-default pitfall (the original used a list).
        **kwargs: forwarded to :class:`DinoVisionTransformer`.
    """

    def __init__(
        self,
        cloud_adapter_config=None,
        has_cat=False,
        adapter_index=(0, 6, 12, 18),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Built via the registry so configs stay declarative.
        self.cloud_adapter: CloudAdapter = MODELS.build(cloud_adapter_config)
        self.has_cat = has_cat
        self.adapter_index = adapter_index

    def forward_features(self, x, masks=None):
        """Run the ViT blocks, injecting the adapter after selected blocks.

        Args:
            x (Tensor): input image batch, shape (B, C, h, w).
            masks: optional token masks forwarded to
                ``prepare_tokens_with_masks``.

        Returns:
            tuple: ``(outs, cache)`` where ``outs`` is a list of spatial
            feature maps (B, D, H, W) taken at ``self.out_indices`` and
            ``cache`` is the adapter's CNN feature cache.
        """
        B, _, h, w = x.shape
        # CNN cache is computed once and reused at every adapter step.
        cache = self.cloud_adapter.cnn(x)
        H, W = h // self.patch_size, w // self.patch_size
        x = self.prepare_tokens_with_masks(x, masks)
        outs = []
        cur_idx = 0  # number of adapter interactions performed so far
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx in self.adapter_index:
                x = self.cloud_adapter.forward(
                    x,
                    cur_idx,
                    batch_first=True,
                    has_cls_token=True,
                    cache=cache,
                )
                cur_idx += 1
            if idx in self.out_indices:
                # Drop the CLS token and restore the 2-D spatial layout.
                outs.append(
                    x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, H, W).contiguous()
                )
        return outs, cache

    def process_cache(self, ret, cache):
        """Bilinearly resize ``cache`` to the spatial size of ``ret``."""
        return F.interpolate(
            cache,
            size=(ret.shape[-2], ret.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

    @staticmethod
    def _rescale_pyramid(feats):
        """Rescale a 4-level pyramid in place to strides (4, 8, 16, 32).

        Levels 0, 1 and 3 are resampled by factors 4, 2 and 0.5; level 2
        is left at its native resolution.
        """
        for i, factor in ((0, 4), (1, 2), (3, 0.5)):
            feats[i] = F.interpolate(
                feats[i], scale_factor=factor, mode="bilinear", align_corners=False
            )

    def _cat_cache(self, feats, cache):
        """Channel-concatenate cache features onto each pyramid level in place.

        A tuple/list cache is assumed to already be per-level; a single
        tensor cache is resized to each level via :meth:`process_cache`.
        """
        if isinstance(cache, (tuple, list)):
            for i in range(4):
                feats[i] = torch.cat((feats[i], cache[i]), dim=1)
        else:
            for i in range(4):
                feats[i] = torch.cat(
                    (feats[i], self.process_cache(feats[i], cache)), dim=1
                )

    def forward(self, *args, **kwargs):
        """Produce the feature pyramid consumed by the segmentation head.

        ``ret`` is either a flat list of tensors or a nested container
        whose first element is that list; both layouts receive identical
        rescale/concat treatment (the original duplicated this logic).
        """
        ret, cache = self.forward_features(*args, **kwargs)
        feats = ret if isinstance(ret[0], torch.Tensor) else ret[0]
        self._rescale_pyramid(feats)
        if self.has_cat:
            self._cat_cache(feats, cache)
        return ret

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping only the cloud adapter trainable.

        In eval mode, defer entirely to ``nn.Module.train``. In train
        mode, freeze everything except the adapter. Returns ``self`` to
        honor the ``nn.Module.train`` chaining contract (the original
        returned ``None`` on the training path).
        """
        if not mode:
            return super().train(mode)
        set_requires_grad(self, ["cloud_adapter"])
        set_train(self, ["cloud_adapter"])
        return self

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return a state dict containing only cloud-adapter entries.

        Standard ``nn.Module`` defaults were added so the canonical
        ``model.state_dict()`` call works (the original required all
        three arguments positionally and crashed on ``key in None``).
        Non-adapter keys are stripped from both the returned dict and,
        when provided, the caller-supplied ``destination``.
        """
        state = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        for key in [k for k in state if "cloud_adapter" not in k]:
            state.pop(key)
            if destination is not None and key in destination:
                destination.pop(key)
        return state
|
|