import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module
    '''

    # Current version only supports input whose size is a power of 2, such as 32, 64, 128 etc.
    # You can adapt it to any input size by changing the padding or stride.
    def __init__(self,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 maxT=25,
                 depth=4,
                 num_channels=128):
        super(CAM, self).__init__()
        # cascade multiscale features
        fpn = []
        for i in range(1, len(channels_list)):
            fpn.append(
                nn.Sequential(
                    nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
                              (strides_list[i - 1][0], strides_list[i - 1][1]),
                              1), nn.BatchNorm2d(channels_list[i]),
                    nn.ReLU(True)))
        self.fpn = nn.Sequential(*fpn)
        # convolutional alignment
        # convs
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
        # in_shape = scales[-1]
        strides = []
        conv_ksizes = []
        deconv_ksizes = []
        h, w = in_shape[0], in_shape[1]
        for i in range(0, int(depth / 2)):
            stride = [2] if 2**(depth / 2 - i) <= h else [1]
            stride = stride + [2] if 2**(depth / 2 - i) <= w else stride + [1]
            strides.append(stride)
            conv_ksizes.append([3, 3])
            deconv_ksizes.append([_**2 for _ in stride])
        convs = [
            nn.Sequential(
                nn.Conv2d(channels_list[-1], num_channels,
                          tuple(conv_ksizes[0]), tuple(strides[0]),
                          (int((conv_ksizes[0][0] - 1) / 2),
                           int((conv_ksizes[0][1] - 1) / 2))),
                nn.BatchNorm2d(num_channels), nn.ReLU(True))
        ]
        for i in range(1, int(depth / 2)):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(num_channels, num_channels,
                              tuple(conv_ksizes[i]), tuple(strides[i]),
                              (int((conv_ksizes[i][0] - 1) / 2),
                               int((conv_ksizes[i][1] - 1) / 2))),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        self.convs = nn.Sequential(*convs)
        # deconvs
        deconvs = []
        for i in range(1, int(depth / 2)):
            deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        num_channels, num_channels,
                        tuple(deconv_ksizes[int(depth / 2) - i]),
                        tuple(strides[int(depth / 2) - i]),
                        (int(deconv_ksizes[int(depth / 2) - i][0] / 4.),
                         int(deconv_ksizes[int(depth / 2) - i][1] / 4.))),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        deconvs.append(
            nn.Sequential(
                nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
                                   tuple(strides[0]),
                                   (int(deconv_ksizes[0][0] / 4.),
                                    int(deconv_ksizes[0][1] / 4.))),
                nn.Sigmoid()))
        self.deconvs = nn.Sequential(*deconvs)

    def forward(self, input):
        x = input[0]
        # fuse the multi-scale backbone features top-down
        for i in range(0, len(self.fpn)):
            # print(self.fpn[i](x).shape, input[i+1].shape)
            x = self.fpn[i](x) + input[i + 1]
        # downsampling convolutions, keeping intermediate features for skip connections
        conv_feats = []
        for i in range(0, len(self.convs)):
            x = self.convs[i](x)
            conv_feats.append(x)
        # upsampling deconvolutions with skip connections
        for i in range(0, len(self.deconvs) - 1):
            x = self.deconvs[i](x)
            x = x + conv_feats[len(conv_feats) - 2 - i]
        # final deconvolution produces maxT attention maps in [0, 1]
        x = self.deconvs[-1](x)
        return x
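
# Rough usage sketch (not part of the module): CAM expects a list of
# multi-scale backbone feature maps whose channel counts match
# `channels_list`. The shapes below are an assumption about a typical
# backbone for a 32x128 input image, not values defined in this file.
#
#     cam = CAM(channels_list=[64, 128, 256, 512],
#               strides_list=[[2, 2], [1, 1], [1, 1]],
#               in_shape=[8, 32], maxT=26, depth=4, num_channels=128)
#     feats = [torch.rand(2, 64, 16, 64), torch.rand(2, 128, 8, 32),
#              torch.rand(2, 256, 8, 32), torch.rand(2, 512, 8, 32)]
#     A = cam(feats)  # (2, 26, 8, 32): one attention map per decoding step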


class CAMSimp(nn.Module):
    '''
    Simplified alignment module: a 1x1 convolution followed by a sigmoid
    that maps a single feature map directly to maxT attention maps.
    '''

    def __init__(self, maxT=25, num_channels=128):
        super(CAMSimp, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
                                  nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        return x
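
# Rough usage sketch (assumed shapes, not taken from this file):
#
#     cam = CAMSimp(maxT=26, num_channels=512)
#     A = cam(torch.rand(2, 512, 8, 32))  # (2, 26, 8, 32)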


class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder
    '''

    def __init__(self,
                 out_channels,
                 in_channels,
                 use_cam=True,
                 max_len=25,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 depth=4,
                 dropout=0.3,
                 **kwargs):
        super(DANDecoder, self).__init__()
        self.eos = 0
        self.bos = out_channels - 2
        self.ignore_index = out_channels - 1
        nchannel = in_channels
        self.nchannel = in_channels
        self.use_cam = use_cam
        if use_cam:
            self.cam = CAM(channels_list=channels_list,
                           strides_list=strides_list,
                           in_shape=in_shape,
                           maxT=max_len + 1,
                           depth=depth,
                           num_channels=nchannel)
        else:
            self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
        self.pre_lstm = nn.LSTM(nchannel,
                                int(nchannel / 2),
                                bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, out_channels - 2))
        self.char_embeddings = nn.Embedding(out_channels,
                                            embedding_dim=in_channels,
                                            padding_idx=out_channels - 1)

    def forward(self, inputs, data=None):
        # attention maps: (B, maxT, H, W)
        A = self.cam(inputs)
        if isinstance(inputs, list):
            feature = inputs[-1]
        else:
            feature = inputs
        nB, nC, nH, nW = feature.shape
        nT = A.shape[1]
        # normalize each attention map so it sums to 1 over the spatial grid
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # weighted sum of the feature map under each attention map
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)  # T, B, C
        C, _ = self.pre_lstm(C)  # T, B, C
        C = F.dropout(C, p=0.3, training=self.training)
        if self.training:
            # teacher forcing with the ground-truth character sequence
            text = data[0]
            text_length = data[-1]
            nsteps = int(text_length.max())
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(text[:, 0])
            for i in range(0, nsteps + 1):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                prev_emb = self.char_embeddings(text[:, i + 1])
            gru_res = self.generator(gru_res)
            return gru_res[:nsteps + 1, :, :].transpose(1, 0)
        else:
            # greedy autoregressive decoding at inference time
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(
                torch.zeros(nB, dtype=torch.int64, device=feature.device) +
                self.bos)
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)
            for i in range(0, nT):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                mid_res = self.generator(hidden).argmax(-1)
                dec_seq[:, i] = mid_res.squeeze(0)
                # stop once every sequence in the batch has produced <eos>
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
                prev_emb = self.char_embeddings(mid_res)
            gru_res = self.generator(gru_res)
            return F.softmax(gru_res.transpose(1, 0), -1)
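

if __name__ == '__main__':
    # Minimal inference smoke test (illustrative sketch only): the character-set
    # size (38) and the backbone feature shapes below are assumptions for a
    # 32x128 input image, not values defined elsewhere in this file.
    decoder = DANDecoder(out_channels=38, in_channels=512)
    feats = [
        torch.rand(2, 64, 16, 64),
        torch.rand(2, 128, 8, 32),
        torch.rand(2, 256, 8, 32),
        torch.rand(2, 512, 8, 32),
    ]
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)
    # expected: (2, 26, 36) -> (batch, max_len + 1, out_channels - 2)
    print(probs.shape)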