import torch.nn as nn import torch class CircConv(nn.Module): def __init__(self, state_dim, out_state_dim=None, n_adj=4): super(CircConv, self).__init__() self.n_adj = n_adj out_state_dim = state_dim if out_state_dim is None else out_state_dim self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1) def forward(self, input, adj): input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2) return self.fc(input) class DilatedCircConv(nn.Module): def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1): super(DilatedCircConv, self).__init__() self.n_adj = n_adj self.dilation = dilation out_state_dim = state_dim if out_state_dim is None else out_state_dim self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation) def forward(self, input, adj): if self.n_adj != 0: input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2) return self.fc(input) _conv_factory = { 'grid': CircConv, 'dgrid': DilatedCircConv } class BasicBlock(nn.Module): def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1): super(BasicBlock, self).__init__() self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation) self.relu = nn.ReLU(inplace=True) self.norm = nn.BatchNorm1d(out_state_dim) def forward(self, x, adj=None): x = self.conv(x, adj) x = self.relu(x) x = self.norm(x) return x class DeepSnake(nn.Module): def __init__(self, state_dim, feature_dim, conv_type='dgrid'): super(DeepSnake, self).__init__() self.head = BasicBlock(feature_dim, state_dim, conv_type) self.res_layer_num = 7 dilation = [1, 1, 1, 2, 2, 4, 4] for i in range(self.res_layer_num): conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i]) self.__setattr__('res'+str(i), conv) fusion_state_dim = 256 self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1) self.prediction = nn.Sequential( nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1), nn.ReLU(inplace=True), nn.Conv1d(256, 64, 1), nn.ReLU(inplace=True), nn.Conv1d(64, 2, 1) ) def forward(self, x, adj): states = [] x = self.head(x, adj) states.append(x) for i in range(self.res_layer_num): x = self.__getattr__('res'+str(i))(x, adj) + x states.append(x) state = torch.cat(states, dim=1) global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0] global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2)) state = torch.cat([global_state, state], dim=1) x = self.prediction(state) return x