torinriley committed on
Commit 05469a1 · 1 Parent(s): ba632ba
Files changed (1):
  1. src/diffusion.py  +26 −53
src/diffusion.py CHANGED
@@ -13,25 +13,8 @@ class TimeEmbedding(nn.Module):
         x = F.silu(self.linear_1(x))
         return self.linear_2(x)
 
-class SqueezeExcitation(nn.Module):
-    def __init__(self, channels, reduction=16):
-        super().__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Sequential(
-            nn.Linear(channels, channels // reduction, bias=False),
-            nn.ReLU(inplace=True),
-            nn.Linear(channels // reduction, channels, bias=False),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        b, c, _, _ = x.size()
-        y = self.avg_pool(x).view(b, c)
-        y = self.fc(y).view(b, c, 1, 1)
-        return x * y.expand_as(x)
-
 class UNET_ResidualBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, n_time=1280, use_se=False):
+    def __init__(self, in_channels, out_channels, n_time=1280):
         super().__init__()
         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
@@ -39,26 +22,16 @@ class UNET_ResidualBlock(nn.Module):
         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
         self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-        # Add Squeeze-Excitation blocks only if use_se is True
-        self.use_se = use_se
-        if use_se:
-            self.se1 = SqueezeExcitation(out_channels)
-            self.se2 = SqueezeExcitation(out_channels)
 
     def forward(self, feature, time):
         residue = feature
         feature = F.silu(self.groupnorm_feature(feature))
         feature = self.conv_feature(feature)
-        if self.use_se:
-            feature = self.se1(feature)  # Apply SE after first conv
 
         time = self.linear_time(F.silu(time))
         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
         merged = F.silu(self.groupnorm_merged(merged))
         merged = self.conv_merged(merged)
-        if self.use_se:
-            merged = self.se2(merged)  # Apply SE after second conv
 
         return merged + self.residual_layer(residue)
@@ -112,42 +85,42 @@ class SwitchSequential(nn.Sequential):
         return x
 
 class UNET(nn.Module):
-    def __init__(self, use_se=False):
+    def __init__(self):
         super().__init__()
         self.encoders = nn.ModuleList([
             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-            SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-            SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
+            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
+            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(UNET_ResidualBlock(320, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-            SwitchSequential(UNET_ResidualBlock(640, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
+            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
+            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(UNET_ResidualBlock(640, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
+            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
+            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
-            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
+            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
+            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
         ])
 
         self.bottleneck = SwitchSequential(
-            UNET_ResidualBlock(1280, 1280, use_se=use_se),
+            UNET_ResidualBlock(1280, 1280),
             UNET_AttentionBlock(8, 160),
-            UNET_ResidualBlock(1280, 1280, use_se=use_se),
+            UNET_ResidualBlock(1280, 1280),
         )
 
         self.decoders = nn.ModuleList([
-            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
-            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
-            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), Upsample(1280)),
-            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-            SwitchSequential(UNET_ResidualBlock(1920, 1280, use_se=use_se), UNET_AttentionBlock(8, 160), Upsample(1280)),
-            SwitchSequential(UNET_ResidualBlock(1920, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-            SwitchSequential(UNET_ResidualBlock(1280, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-            SwitchSequential(UNET_ResidualBlock(960, 640, use_se=use_se), UNET_AttentionBlock(8, 80), Upsample(640)),
-            SwitchSequential(UNET_ResidualBlock(960, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-            SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-            SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
+            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
+            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
+            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
+            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
+            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
+            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
+            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
+            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
+            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
+            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
+            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
+            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
         ])
 
     def forward(self, x, context, time):
@@ -175,10 +148,10 @@ class UNET_OutputLayer(nn.Module):
         return self.conv(x)
 
 class Diffusion(nn.Module):
-    def __init__(self, use_se=False):
+    def __init__(self):
         super().__init__()
         self.time_embedding = TimeEmbedding(320)
-        self.unet = UNET(use_se=use_se)
+        self.unet = UNET()
         self.final = UNET_OutputLayer(320, 4)
 
     def forward(self, latent, context, time):
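
For reference, a minimal smoke-test sketch of the simplified entry point after this change. The 4-channel latent and the TimeEmbedding(320) width come from the code above; the 64x64 latent, the (1, 77, 768) context, and the import path are assumptions for illustration, since the attention-block internals and the project layout are not part of this diff.

    import torch
    from src.diffusion import Diffusion  # assumed import path (repo root on PYTHONPATH)

    model = Diffusion()                  # no use_se flag anymore
    latent = torch.randn(1, 4, 64, 64)   # 4-channel latent; spatial size must survive three stride-2 downsamples
    context = torch.randn(1, 77, 768)    # assumed CLIP-style cross-attention context
    time = torch.randn(1, 320)           # matches TimeEmbedding(320)

    out = model(latent, context, time)   # expected shape (1, 4, 64, 64), per UNET_OutputLayer(320, 4)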