XavierJiezou committed
Commit a522d83 · verified · 1 Parent(s): 22d785f

Delete ktda/models/adapter

ktda/models/adapter/__init__.py DELETED
@@ -1,4 +0,0 @@
-from .fam import FAM
-from .fmm import FMM
-
-__all__ = ["FAM", "FMM"]

ktda/models/adapter/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (288 Bytes)
 
ktda/models/adapter/__pycache__/fam.cpython-311.pyc DELETED
Binary file (2.86 kB)
 
ktda/models/adapter/__pycache__/fmm.cpython-311.pyc DELETED
Binary file (5.88 kB)
 
ktda/models/adapter/fam.py DELETED
@@ -1,37 +0,0 @@
-from mmseg.registry import MODELS
-from mmengine.model import BaseModule
-from torch import nn as nn
-from torch.nn import functional as F
-from timm.models.layers import trunc_normal_
-
-
-@MODELS.register_module()
-class FAM(BaseModule):
-    def __init__(self, in_channels, out_channels, output_size, init_cfg=None):
-        super().__init__(init_cfg)
-        self.convert = nn.ModuleList()
-        self.output_size = output_size
-        if isinstance(out_channels, int):
-            out_channels = [out_channels] * len(in_channels)
-        for in_channel, out_channel in zip(in_channels, out_channels):
-            self.convert.append(
-                nn.Conv2d(in_channel, out_channel, kernel_size=1),
-            )
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv2d, nn.Linear)):
-            trunc_normal_(m.weight, std=.02)
-            nn.init.constant_(m.bias, 0)
-
-
-    def forward(self, inputs):
-        outs = []
-        for index, x in enumerate(inputs):
-            x = self.convert[index](x)
-            x = F.interpolate(
-                x, size=(self.output_size, self.output_size), align_corners=False, mode="bilinear"
-            )
-            outs.append(x)
-        return tuple(outs)
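For reference, the deleted FAM module projected each multi-scale feature map to a common channel width with a 1x1 convolution and bilinearly resized it to a fixed spatial size. Below is a minimal usage sketch of the class as it existed before this commit; the batch size, channel counts, and output size are illustrative assumptions, not values taken from the repository.

```python
# Hedged sketch: assumes the pre-deletion module path ktda.models.adapter
# and illustrative feature shapes; not an official example from this repo.
import torch

from ktda.models.adapter import FAM  # removed by this commit

# Hypothetical multi-scale backbone features, each of shape (B, C_i, H_i, W_i).
feats = [
    torch.randn(2, 192, 64, 64),
    torch.randn(2, 384, 32, 32),
    torch.randn(2, 768, 16, 16),
]

# Project every level to 256 channels and resample to a 32x32 grid.
fam = FAM(in_channels=[192, 384, 768], out_channels=256, output_size=32)
outs = fam(feats)
print([o.shape for o in outs])  # each: torch.Size([2, 256, 32, 32])
```

Because FAM was registered with @MODELS.register_module(), it could also have been built from an MMSegmentation config dict (e.g. dict(type="FAM", ...)) rather than instantiated directly.
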
ktda/models/adapter/fmm.py DELETED
@@ -1,109 +0,0 @@
-from mmseg.registry import MODELS
-from mmengine.model import BaseModule
-from torch import nn as nn
-from torch.nn import functional as F
-from typing import Callable, Optional
-from torch import Tensor
-from timm.models.layers import trunc_normal_
-from timm.models.vision_transformer import Block as TransformerBlock
-
-
-class Mlp(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer: Callable[..., nn.Module] = nn.GELU,
-        drop: float = 0.0,
-        bias: bool = True,
-    ) -> None:
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-@MODELS.register_module()
-class FMM(BaseModule):
-    def __init__(
-        self,
-        in_channels,
-        rank_dim=4,
-        mlp_nums=1,
-        model_type="mlp",
-        num_heads=8,
-        mlp_ratio=4,
-        qkv_bias=True,
-        qk_norm=False,
-        init_values=None,
-        proj_drop_rate: float = 0.0,
-        attn_drop_rate: float = 0.0,
-        init_cfg=None,
-    ):
-        super().__init__(init_cfg)
-        self.adapters = nn.ModuleList()
-        if model_type == "mlp":
-            for in_channel in in_channels:
-                mlp_list = []
-                for _ in range(mlp_nums):
-                    mlp_list.append(
-                        Mlp(
-                            in_channel,
-                            hidden_features=in_channel // rank_dim,
-                            out_features=in_channel,
-                        )
-                    )
-                mlp_model = nn.Sequential(*mlp_list)
-                self.adapters.append(mlp_model)
-
-        elif model_type == "vitBlock":
-            for in_channel in in_channels:
-                model_list = []
-                for _ in range(mlp_nums):
-                    model_list.append(
-                        TransformerBlock(
-                            in_channel,
-                            num_heads=num_heads,
-                            mlp_ratio=mlp_ratio,
-                            qkv_bias=qkv_bias,
-                            qk_norm=qk_norm,
-                            init_values=init_values,
-                            proj_drop=proj_drop_rate,
-                            attn_drop=attn_drop_rate,
-                        )
-                    )
-                self.adapters.append(nn.Sequential(*model_list))
-
-        else:
-            raise ValueError(f"model type must in ['mlp','vitBlock'],actually is {model_type}")
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv2d, nn.Linear)):
-            trunc_normal_(m.weight, std=0.02)
-            nn.init.constant_(m.bias, 0)
-
-    def forward(self, inputs):
-        outs = []
-        for index, x in enumerate(inputs):
-            B, C, H, W = x.shape
-            x = x.permute(0, 2, 3, 1)
-            x = x.reshape(B, -1, C)
-            x = self.adapters[index](x)
-            x = x.reshape(B, H, W, C)
-            x = x.permute(0, 3, 1, 2)
-            outs.append(x)
-        return tuple(outs)
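For reference, the deleted FMM module flattened each feature map into a token sequence of shape (B, H*W, C), passed it through a per-level stack of low-rank MLP adapters (or timm Transformer blocks when model_type="vitBlock"), and reshaped the result back to (B, C, H, W). A minimal usage sketch under the same caveats as above; the shapes and hyper-parameters are illustrative assumptions.

```python
# Hedged sketch: assumes the pre-deletion module path ktda.models.adapter
# and illustrative shapes/hyper-parameters; not an official example.
import torch

from ktda.models.adapter import FMM  # removed by this commit

# Two hypothetical feature levels, already aligned to (B, 256, 32, 32),
# e.g. by the FAM sketch shown earlier.
feats = (
    torch.randn(2, 256, 32, 32),
    torch.randn(2, 256, 32, 32),
)

# One low-rank MLP adapter per level: hidden width 256 // rank_dim = 64.
fmm = FMM(in_channels=[256, 256], rank_dim=4, mlp_nums=1, model_type="mlp")
outs = fmm(feats)
print([o.shape for o in outs])  # shapes are preserved: torch.Size([2, 256, 32, 32]) each
```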