import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encodings.torch_encodings import PositionalEncoding2D

class LayerNorm2D(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""
    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)   # [N, C, H, W] -> [N, H, W, C]
        x = self.layer_norm(x)
        x = x.permute(0, 3, 1, 2)   # [N, H, W, C] -> [N, C, H, W]
        return x

class Image_Adaptor(nn.Module):
    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()
        self.adaptor = nn.Sequential(
            nn.Conv2d(in_channels, adp_channels // 4, kernel_size=4, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels // 4, kernel_size=2, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels, kernel_size=2, padding='same')
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        """
        input:  [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        adapt_imgs = self.adaptor(images)
        return self.dropout(adapt_imgs)

class Positional_Encoding(nn.Module):
    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs):
        """
        input:  [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        x = adapt_imgs.permute(0, -2, -1, -3)   # [N, C, H, W] -> [N, H, W, C]
        encode = self.pe(x)
        encode = encode.permute(0, -1, -3, -2)  # [N, H, W, C] -> [N, C, H, W]
        return encode
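
# Note (added sketch, not in the original Space): PositionalEncoding2D from the
# `positional-encodings` package expects channels-last input [N, H, W, C] and
# returns only the encoding tensor, so the caller is responsible for adding it
# to the features (as Trans_UNet.forward does below). Assumed usage:
#
#   pe = Positional_Encoding(adp_channels=64)
#   enc = pe(torch.randn(2, 64, 32, 32))    # -> [2, 64, 32, 32], encoding only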

class GeGLU(nn.Module):
    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        x_gelu = self.act(self.wi_0(x))
        x_linear = self.wi_1(x)
        x = x_gelu * x_linear
        return x

class Feed_Forward(nn.Module):
    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()
        self.ln1 = GeGLU(in_channels, ffw_channels)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)

    def forward(self, x):
        '''
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        '''
        x = self.ln1(x)
        x = self.dropout(x)
        x = self.ln2(x)
        return x
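
# Illustrative sketch (assumed shapes, not in the original file): GeGLU computes
# GELU(x @ W0) * (x @ W1), so Feed_Forward gates an expansion to `ffw_channels`
# and a second GeGLU projects back, keeping the trailing channel dimension:
#
#   ffw = Feed_Forward(in_channels=64, ffw_channels=256)
#   y = ffw(torch.randn(2, 32, 32, 64))     # -> [2, 32, 32, 64]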

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()
        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads
        self.scale = self.attn_size ** -0.5
        assert num_attn_heads * self.attn_size == channels, "Input channels must be divisible by the number of attention heads!"
        self.lq = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size * self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        '''
        input:  [N, H, W, channels] for each of q, k, v
        output: [N, H, W, channels]
        '''
        bz, H, W, C = q.shape
        # Flatten the spatial grid into a sequence of H*W tokens first
        q = q.view(bz, -1, C)   # [N, H*W, C]
        k = k.view(bz, -1, C)   # [N, H*W, C]
        v = v.view(bz, -1, C)   # [N, H*W, C]
        q = self.lq(q).view(bz, -1, self.head_size, self.attn_size)   # [N, H*W, hz, az]
        k = self.lk(k).view(bz, -1, self.head_size, self.attn_size)   # [N, H*W, hz, az]
        v = self.lv(v).view(bz, -1, self.head_size, self.attn_size)   # [N, H*W, hz, az]
        q = q.transpose(1, 2)                    # [N, hz, H*W, az]
        k = k.transpose(1, 2).transpose(-1, -2)  # [N, hz, az, H*W]
        v = v.transpose(1, 2)                    # [N, hz, H*W, az]
        q = q * self.scale
        x = torch.matmul(q, k)                   # [N, hz, H*W, H*W]
        x = torch.softmax(x, dim=-1)
        x = self.dropout(x)
        x = x.matmul(v)                          # [N, hz, H*W, az]
        x = x.transpose(1, 2).contiguous()       # [N, H*W, hz, az]
        x = x.view(bz, H, W, C)                  # [N, H, W, C]
        x = self.lout(x)                         # [N, H, W, C]
        return x
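
# Shape sketch (illustrative values, not in the original file): the attention
# treats every spatial position as a token, so the score matrix is
# [N, heads, H*W, H*W] and memory grows quadratically with the grid size.
#
#   mha = MultiHeadAttention(channels=64, num_attn_heads=8)
#   t = torch.randn(2, 16, 16, 64)          # channels-last feature map
#   y = mha(t, t, t)                        # -> [2, 16, 16, 64]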

class Transformer_Encoder_Layer(nn.Module):
    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)
        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        # Pre-norm self-attention block with residual connection
        _x = adp_pos_imgs
        x = self.attn_norm(adp_pos_imgs)
        x = self.attn_layer(x, x, x)
        x = self.attn_dropout(x)
        x = x + _x
        # Pre-norm feed-forward block with residual connection
        _x = x
        x = self.ffw_norm(x)
        x = self.ffw_layer(x)
        x = self.ffw_dropout(x)
        x = x + _x
        return x

class Transformer_Encoder(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout) for _ in range(num_layers)
        ])
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        x = adp_pos_imgs.permute(0, -2, -1, -3)   # [N, H, W, in_channels]
        for layer in self.encoder_layers:
            x = layer(x)
        x = self.linear(x)                        # [N, H, W, out_channels]
        x = x.permute(0, -1, -3, -2)              # [N, out_channels, H, W]
        x = self.last_norm(x)
        out = self.dropout(x)
        return out

class Double_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        return self.double_conv(X)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            Double_Conv(in_channels, out_channels)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        return self.down(X)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1, X2):
        """
        input:  X1 : [N, in_channels, H // 2, W // 2]
                X2 : [N, in_channels // 2, H, W]
        output: X  : [N, out_channels, H, W]
        """
        X1 = self.up(X1)
        # Pad the upsampled tensor so its spatial size matches the skip connection
        diffY = X2.shape[-2] - X1.shape[-2]
        diffX = X2.shape[-1] - X1.shape[-1]
        pad_top = diffY // 2
        pad_bottom = diffY - pad_top
        pad_left = diffX // 2
        pad_right = diffX - pad_left
        X1 = F.pad(X1, (pad_left, pad_right, pad_top, pad_bottom))
        X = torch.cat((X2, X1), dim=-3)   # concatenate along the channel dimension
        return self.conv(X)
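
# Illustrative sketch (assumed shapes): Up doubles the spatial size of X1 with a
# transposed convolution, pads it when the skip tensor X2 has odd height/width,
# and concatenates along channels before the double convolution:
#
#   up = Up(in_channels=128, out_channels=64)
#   y = up(torch.randn(1, 128, 15, 15), torch.randn(1, 64, 31, 31))   # -> [1, 64, 31, 31]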

class Out_Conv(nn.Module):
    def __init__(self, adp_channels, out_channels):
        super().__init__()
        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X):
        return self.out_conv(X)

class Trans_UNet(nn.Module):
    def __init__(self,
                 in_channels,
                 adp_channels,
                 out_channels,
                 trans_num_layers=5,
                 trans_num_attn_heads=8,
                 trans_ffw_channels=1024,
                 dropout=0.1):
        super().__init__()
        self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
        self.pos_encoding = Positional_Encoding(adp_channels)
        self.down1 = Down(adp_channels * 1, adp_channels * 2)
        self.down2 = Down(adp_channels * 2, adp_channels * 4)
        self.down3 = Down(adp_channels * 4, adp_channels * 8)
        self.down4 = Down(adp_channels * 8, adp_channels * 16)
        self.down5 = Down(adp_channels * 16, adp_channels * 32)
        self.trans_encoder = Transformer_Encoder(adp_channels * 32, adp_channels * 32, trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout)
        self.up5 = Up(adp_channels * 32, adp_channels * 16)
        self.up4 = Up(adp_channels * 16, adp_channels * 8)
        self.up3 = Up(adp_channels * 8, adp_channels * 4)
        self.up2 = Up(adp_channels * 4, adp_channels * 2)
        self.up1 = Up(adp_channels * 2, adp_channels * 1)
        self.out_conv = Out_Conv(adp_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        adp_imgs = self.img_adaptor(images)
        pos_enc = self.pos_encoding(adp_imgs)
        adp_imgs = adp_imgs + pos_enc
        d1 = self.down1(adp_imgs)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        x = self.trans_encoder(d5)   # transformer bottleneck
        u5 = self.up5(x, d4)
        u4 = self.up4(u5, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, adp_imgs)
        x = self.out_conv(u1)
        out = self.sigmoid(x)
        return out
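

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original Space). The
    # channel counts and the 224x224 input size are arbitrary assumptions chosen
    # so that the five pooling stages divide the spatial dimensions evenly.
    model = Trans_UNet(in_channels=3, adp_channels=16, out_channels=1,
                       trans_num_layers=1, trans_num_attn_heads=4,
                       trans_ffw_channels=128)
    with torch.no_grad():
        masks = model(torch.randn(1, 3, 224, 224))
    print(masks.shape)   # expected: torch.Size([1, 1, 224, 224])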