venkatesh-thiru commited on
Commit
c74cff1
1 Parent(s): 8ef2d03

Upload model

Browse files
Files changed (6) hide show
  1. RRDB.py +118 -0
  2. SRMRIModels.py +41 -0
  3. SRMRIModelsConfigs.py +52 -0
  4. config.json +23 -0
  5. model.safetensors +3 -0
  6. unet3DMSS.py +189 -0
RRDB.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+
7
+ class make_dense(nn.Module):
8
+ def __init__(self,nChannels,GrowthRate,kernel_size=3):
9
+ super(make_dense,self).__init__()
10
+ self.conv = nn.Conv3d(nChannels,GrowthRate,kernel_size=kernel_size,padding=(kernel_size-1)//2,bias=True)
11
+ # self.norm = nn.BatchNorm3d(nChannels)
12
+ def forward(self,x):
13
+ # out = self.norm(x)
14
+ out = F.relu(self.conv(x))
15
+ out = torch.cat([x,out],dim=1)
16
+ return out
17
+
18
+ # class RDB(nn.Module):
19
+ # def __init__(self,inChannels,outChannels,nDenseLayer,GrowthRate,KernelSize = 3):
20
+ # super(RDB,self).__init__()
21
+ # nChannels_ = inChannels
22
+ # modules = []
23
+ # for i in range (nDenseLayer):
24
+ # modules.append(make_dense(nChannels_,GrowthRate,kernel_size=KernelSize))
25
+ # nChannels_ += GrowthRate
26
+ # self.dense_layers = nn.Sequential(*modules)
27
+ # self.conv_1x1 = nn.Conv3d(nChannels_,outChannels,kernel_size=1,padding=0,bias = False)
28
+ # def forward(self,x):
29
+ # out = self.dense_layers(x)
30
+ # out = self.conv_1x1(out)
31
+ # # out = out + x
32
+ # return out
33
+
34
+ class RDB(nn.Module):
35
+ def __init__(self,inChannels,outChannels,nDenseLayer,GrowthRate,KernelSize = 3,
36
+ block_dropout = True, block_dropout_rate = 0.2):
37
+ super(RDB,self).__init__()
38
+ nChannels_ = inChannels
39
+ modules = []
40
+ for i in range (nDenseLayer):
41
+ modules.append(make_dense(nChannels_,GrowthRate,kernel_size=KernelSize))
42
+ nChannels_ += GrowthRate
43
+ if block_dropout:
44
+ modules.append(nn.Dropout3d(block_dropout_rate))
45
+ self.dense_layers = nn.Sequential(*modules)
46
+ self.conv_1x1 = nn.Conv3d(nChannels_,outChannels,kernel_size=1,padding=0,bias = False)
47
+ def forward(self,x):
48
+ out = self.dense_layers(x)
49
+ out = self.conv_1x1(out)
50
+ # out = out + x
51
+ return out
52
+
53
+
54
+ class RRDB(nn.Module):
55
+ def __init__(self,nChannels,nDenseLayers,nInitFeat,GrowthRate,featureFusion=True,kernel_config = [3,3,3,3]):
56
+ super(RRDB,self).__init__()
57
+ nChannels_ = nChannels
58
+ nDenseLayers_ = nDenseLayers
59
+ nInitFeat_ = nInitFeat
60
+ GrowthRate_ = GrowthRate
61
+ self.featureFusion = featureFusion
62
+
63
+ #First Convolution
64
+ self.C1 = nn.Conv3d(nChannels_,nInitFeat_,kernel_size=kernel_config[0],padding=(kernel_config[0]-1)//2,bias=True)
65
+ # Initialize RDB
66
+ if self.featureFusion:
67
+ self.RDB1 = RDB(nInitFeat_,nInitFeat_,nDenseLayers_,GrowthRate_,kernel_config[1])
68
+ # print(f"RDB1 =========================================== \n {self.RDB1}")
69
+ self.RDB2 = RDB(nInitFeat_*2,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[2])
70
+ # print(f"RDB2 =========================================== \n {self.RDB2}")
71
+ self.RDB3 = RDB(nInitFeat_*3,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[3])
72
+ # print(f"RDB3 =========================================== \n {self.RDB3}")
73
+ self.FF_1X1 = nn.Conv3d(nInitFeat_*4, 1, kernel_size=1, padding=0, bias=True)
74
+ # print(f"FF1x1 =========================================== \n {self.FF_1X1}")
75
+ else:
76
+ self.RDB1 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[1])
77
+ self.RDB2 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[2])
78
+ self.RDB3 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[3])
79
+ self.FF_1X1 = nn.Conv3d(nInitFeat_, 1, kernel_size=1, padding=0, bias=True)
80
+
81
+
82
+ # Feature Fusion
83
+
84
+
85
+ # self.FF_3X3 = nn.Conv3d(nInitFeat_,nInitFeat_,kernel_size=3,padding=1,bias=True)
86
+
87
+ # self.final_layer = nn.Conv3d(nInitFeat_,nChannels_,kernel_size=1,padding=0,bias=False)
88
+
89
+ def forward(self,x):
90
+ First = F.relu(self.C1(x))
91
+ R_1 = self.RDB1(First)
92
+
93
+ if self.featureFusion:
94
+ FF0 = torch.cat([First,R_1],dim = 1)
95
+ R_2 = self.RDB2(FF0)
96
+ FF1 = torch.cat([First,R_1,R_2],dim=1)
97
+ R_3 = self.RDB3(FF1)
98
+ FF2 = torch.cat([First,R_1, R_2, R_3], dim=1)
99
+ FF1X1 = self.FF_1X1(FF2)
100
+ else:
101
+ R_2 = self.RDB2(R_1)
102
+ R_3 = self.RDB3(R_2)
103
+ FF1X1 = self.FF_1X1(R_3)
104
+
105
+ # FF2 = torch.cat([R_1,R_2,R_3],dim=1)
106
+ # FF1X1 = self.FF_1X1(FF2)
107
+ # FF3X3 = self.FF_3X3(FF1X1)
108
+ # output = self.final_layer(FF3X3)
109
+
110
+ return FF1X1
111
+
112
+ if __name__ == '__main__':
113
+ model = RRDB(nChannels=1,nDenseLayers=6,nInitFeat=6,GrowthRate=12,featureFusion=True,kernel_config = [3,3,3,3]).cuda()
114
+ dimensions = 1, 1, 64, 64, 64
115
+ x = torch.rand(dimensions)
116
+ x = x.cuda()
117
+ out = model(x)
118
+ print(out.shape)
SRMRIModels.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .RRDB import RRDB
3
+ from .unet3DMSS import UNetMSS
4
+ from .SRMRIModelsConfigs import RRDBConfiguration, UNetMSSConfiguration
5
+
6
+ class SRMRIModelUNetMSS(PreTrainedModel):
7
+ config_class = UNetMSSConfiguration
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+ self.model = UNetMSS(
11
+ in_channels=config.in_channels,
12
+ n_classes=config.n_classes,
13
+ depth=config.depth,
14
+ wf=config.wf,
15
+ padding=config.padding,
16
+ batch_norm=config.batch_norm,
17
+ up_mode=config.up_mode,
18
+ dropout=config.dropout,
19
+ mss_level=config.mss_level,
20
+ mss_fromlatent=config.mss_fromlatent,
21
+ mss_up=config.mss_up,
22
+ mss_interpb4=config.mss_interpb4)
23
+ def forward(self, x):
24
+ return self.model.forward(x)
25
+
26
+
27
+ class SRMRIModelRRDB(PreTrainedModel):
28
+ config_class = RRDBConfiguration
29
+ def __init__(self, config):
30
+ super().__init__(config)
31
+ self.model = RRDB(
32
+ nChannels=config.nChannels,
33
+ nDenseLayers=config.nDenseLayers,
34
+ nInitFeat=config.nInitFeat,
35
+ GrowthRate=config.GrowthRate,
36
+ featureFusion=config.featureFusion,
37
+ kernel_config=config.kernel_config,
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.model.forward(x)
SRMRIModelsConfigs.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class RRDBConfiguration(PretrainedConfig):
5
+ model_type = "SRMRIModelRRDB"
6
+ def __init__(
7
+ self,
8
+ nChannels=1,
9
+ nDenseLayers=6,
10
+ nInitFeat=6,
11
+ GrowthRate=12,
12
+ featureFusion=True,
13
+ kernel_config=[3, 3, 3, 3],
14
+ **kwargs):
15
+ self.nChannels = nChannels
16
+ self.nDenseLayers = nDenseLayers
17
+ self.nInitFeat = nInitFeat
18
+ self.GrowthRate = GrowthRate
19
+ self.featureFusion = featureFusion
20
+ self.kernel_config = kernel_config
21
+ super().__init__(**kwargs)
22
+
23
+ class UNetMSSConfiguration(PretrainedConfig):
24
+ model_type = "SRMRIModelUNetMSS"
25
+ def __init__(
26
+ self,
27
+ in_channels=1,
28
+ n_classes=1,
29
+ depth=3,
30
+ wf=6,
31
+ padding=True,
32
+ batch_norm=False,
33
+ up_mode='upconv',
34
+ dropout=False,
35
+ mss_level=2,
36
+ mss_fromlatent=True,
37
+ mss_up="trilinear",
38
+ mss_interpb4=True,
39
+ **kwargs):
40
+ self.in_channels = in_channels
41
+ self.n_classes = n_classes
42
+ self.depth = depth
43
+ self.wf = wf
44
+ self.padding = padding
45
+ self.batch_norm = batch_norm
46
+ self.up_mode = up_mode
47
+ self.dropout = dropout
48
+ self.mss_level = mss_level
49
+ self.mss_fromlatent = mss_fromlatent
50
+ self.mss_up = mss_up
51
+ self.mss_interpb4 = mss_interpb4
52
+ super().__init__(**kwargs)
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "GrowthRate": 12,
3
+ "architectures": [
4
+ "SRMRIModelRRDB"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "SRMRIModelsConfigs.RRDBConfiguration",
8
+ "AutoModel": "SRMRIModels.SRMRIModelRRDB"
9
+ },
10
+ "featureFusion": true,
11
+ "kernel_config": [
12
+ 3,
13
+ 3,
14
+ 3,
15
+ 3
16
+ ],
17
+ "model_type": "SRMRIModelRRDB",
18
+ "nChannels": 1,
19
+ "nDenseLayers": 6,
20
+ "nInitFeat": 6,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.44.0"
23
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a85a532ce98e193005edcada09bf58b24e9746346432e21a2d8695be310d88c
3
+ size 991796
unet3DMSS.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://discuss.pytorch.org/t/unet-implementation/426
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ # import torchcomplex.nn.functional as cF
7
+
8
+ __author__ = "Soumick Chatterjee, Chompunuch Sarasaen"
9
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
10
+ __credits__ = ["Soumick Chatterjee", "Chompunuch Sarasaen"]
11
+ __license__ = "GPL"
12
+ __version__ = "1.0.0"
13
+ __maintainer__ = "Soumick Chatterjee"
14
+ __email__ = "[email protected]"
15
+ __status__ = "Production"
16
+
17
+
18
+ class UNetMSS(nn.Module):
19
+ """
20
+ Implementation of
21
+ U-Net: Convolutional Networks for Biomedical Image Segmentation
22
+ (Ronneberger et al., 2015)
23
+ https://arxiv.org/abs/1505.04597
24
+
25
+ Using the default arguments will yield the exact version used
26
+ in the original paper
27
+
28
+ Args:
29
+ in_channels (int): number of input channels
30
+ n_classes (int): number of output channels
31
+ depth (int): depth of the network
32
+ wf (int): number of filters in the first layer is 2**wf
33
+ padding (bool): if True, apply padding such that the input shape
34
+ is the same as the output.
35
+ This may introduce artifacts
36
+ batch_norm (bool): Use BatchNorm after layers with an
37
+ activation function
38
+ up_mode (str): one of 'upconv' or 'upsample'.
39
+ 'upconv' will use transposed convolutions for
40
+ learned upsampling.
41
+ 'upsample' will use bilinear upsampling.
42
+ """
43
+ def __init__(self, in_channels=1, n_classes=1, depth=3, wf=6, padding=True,
44
+ batch_norm=False, up_mode='upconv', dropout=False, mss_level=2, mss_fromlatent=True,
45
+ mss_up="trilinear", mss_interpb4=False):
46
+ super(UNetMSS, self).__init__()
47
+ assert up_mode in ('upconv', 'upsample')
48
+ self.padding = padding
49
+ self.depth = depth
50
+ self.dropout = nn.Dropout3d() if dropout else nn.Sequential()
51
+ prev_channels = in_channels
52
+ self.down_path = nn.ModuleList()
53
+ up_out_features = []
54
+ for i in range(depth):
55
+ self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
56
+ padding, batch_norm))
57
+ prev_channels = 2**(wf+i)
58
+
59
+ if mss_fromlatent:
60
+ mss_features = [prev_channels]
61
+ else:
62
+ mss_features = []
63
+
64
+ self.up_path = nn.ModuleList()
65
+ for i in reversed(range(depth - 1)):
66
+ self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
67
+ padding, batch_norm))
68
+ prev_channels = 2**(wf+i)
69
+ up_out_features.append(prev_channels)
70
+
71
+ self.last = nn.Conv3d(prev_channels, n_classes, kernel_size=1)
72
+
73
+ mss_features += up_out_features[len(up_out_features)-1-mss_level if not mss_fromlatent
74
+ else len(up_out_features)-1-mss_level+1:-1]
75
+
76
+ self.mss_level = mss_level
77
+ self.mss_up = mss_up
78
+ self.mss_fromlatent = mss_fromlatent
79
+ self.mss_interpb4 = mss_interpb4
80
+ self.mss_convs = nn.ModuleList()
81
+ for i in range(self.mss_level):
82
+ self.mss_convs.append(nn.Conv3d(mss_features[i], n_classes, kernel_size=1))
83
+ if self.mss_level == 1:
84
+ self.mss_coeff = [0.5]
85
+ else:
86
+ lmbda = []
87
+ for i in range(self.mss_level-1, -1, -1):
88
+ lmbda.append(2**i)
89
+ self.mss_coeff = []
90
+ fact = 1.0 / sum(lmbda)
91
+ for i in range(self.mss_level-1):
92
+ self.mss_coeff.append(fact*lmbda[i])
93
+ self.mss_coeff.append(1.0 - sum(self.mss_coeff))
94
+ self.mss_coeff.reverse()
95
+
96
+
97
+ def forward(self, x):
98
+ blocks = []
99
+ for i, down in enumerate(self.down_path):
100
+ x = down(x)
101
+ if i != len(self.down_path)-1:
102
+ blocks.append(x)
103
+ x = F.avg_pool3d(x, 2)
104
+ x = self.dropout(x)
105
+
106
+ if self.mss_fromlatent:
107
+ mss = [x]
108
+ else:
109
+ mss = []
110
+
111
+ for i, up in enumerate(self.up_path):
112
+ x = up(x, blocks[-i-1])
113
+ if self.training and ((len(self.up_path)-1-i <= self.mss_level) and not(i+1 == len(self.up_path))):
114
+ mss.append(x)
115
+
116
+ if self.training:
117
+ for i in range(len(mss)):
118
+ if not self.mss_interpb4:
119
+ mss[i] = F.interpolate(self.mss_convs[i](mss[i]), size=x.shape[2:], mode=self.mss_up)
120
+ else:
121
+ mss[i] = self.mss_convs[i](F.interpolate(mss[i], size=x.shape[2:], mode=self.mss_up))
122
+
123
+ return self.last(x), mss
124
+ else:
125
+ return self.last(x)
126
+
127
+ class UNetConvBlock(nn.Module):
128
+ def __init__(self, in_size, out_size, padding, batch_norm):
129
+ super(UNetConvBlock, self).__init__()
130
+ block = []
131
+
132
+ block.append(nn.Conv3d(in_size, out_size, kernel_size=3,
133
+ padding=int(padding)))
134
+ block.append(nn.ReLU())
135
+ if batch_norm:
136
+ block.append(nn.BatchNorm3d(out_size))
137
+
138
+ block.append(nn.Conv3d(out_size, out_size, kernel_size=3,
139
+ padding=int(padding)))
140
+ block.append(nn.ReLU())
141
+ if batch_norm:
142
+ block.append(nn.BatchNorm3d(out_size))
143
+
144
+ self.block = nn.Sequential(*block)
145
+
146
+ def forward(self, x):
147
+ out = self.block(x)
148
+ return out
149
+
150
+
151
+ class UNetUpBlock(nn.Module):
152
+ def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
153
+ super(UNetUpBlock, self).__init__()
154
+ if up_mode == 'upconv':
155
+ self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=2,
156
+ stride=2)
157
+ elif up_mode == 'upsample':
158
+ self.up = nn.Sequential(nn.Upsample(mode='trilinear', scale_factor=2),
159
+ nn.Conv3d(in_size, out_size, kernel_size=1))
160
+
161
+ self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)
162
+
163
+ def center_crop(self, layer, target_size):
164
+ _, _, layer_depth, layer_height, layer_width = layer.size()
165
+ diff_z = (layer_depth - target_size[0]) // 2
166
+ diff_y = (layer_height - target_size[1]) // 2
167
+ diff_x = (layer_width - target_size[2]) // 2
168
+ return layer[:, :, diff_z:(diff_z + target_size[0]), diff_y:(diff_y + target_size[1]), diff_x:(diff_x + target_size[2])]
169
+ # _, _, layer_height, layer_width = layer.size() #for 2D data
170
+ # diff_y = (layer_height - target_size[0]) // 2
171
+ # diff_x = (layer_width - target_size[1]) // 2
172
+ # return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]
173
+
174
+ def forward(self, x, bridge):
175
+ up = self.up(x)
176
+ # bridge = self.center_crop(bridge, up.shape[2:]) #sending shape ignoring 2 digit, so target size start with 0,1,2
177
+ up = F.interpolate(up, size=bridge.shape[2:], mode='trilinear')
178
+ out = torch.cat([up, bridge], 1)
179
+ out = self.conv_block(out)
180
+
181
+ return out
182
+
183
+
184
+ if __name__ == "__main__":
185
+ model = UNetMSS(in_channels=1, n_classes=1, depth=4, wf=6, padding=True,
186
+ batch_norm=False, up_mode='upconv', dropout=True, mss_level=3,
187
+ mss_fromlatent=True, mss_up="trilinear", mss_interpb4=True).cuda()
188
+
189
+ print(model)