# Uploaded by shreyasvaidya using huggingface_hub (commit 01bb3bb, verified).
import torch.nn as nn
import torch
class CircConv(nn.Module):
    """1-D convolution over a circular (wrap-around) sequence.

    Before the convolution, the last ``n_adj`` elements of the input are
    prepended and the first ``n_adj`` elements are appended along the last
    axis. The ``nn.Conv1d`` kernel has size ``2 * n_adj + 1``, so the output
    keeps the input length (provided the sequence has at least ``n_adj``
    elements).

    Args:
        state_dim: Number of input channels.
        out_state_dim: Number of output channels; defaults to ``state_dim``.
        n_adj: Number of neighbours taken on each side of a position.
        dilation: Ignored. It is accepted only so that this class can be
            constructed with the same positional arguments as
            ``DilatedCircConv`` (``BasicBlock`` passes four positionally).
    """

    def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
        super(CircConv, self).__init__()

        self.n_adj = n_adj
        out_state_dim = state_dim if out_state_dim is None else out_state_dim
        self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj * 2 + 1)

    def forward(self, input, adj):
        """Apply the circular convolution.

        Args:
            input: Tensor whose last axis is the circular sequence axis.
            adj: Unused; kept so all convolution types share one call shape.

        Returns:
            The output of the underlying ``nn.Conv1d``.
        """
        # ``input[..., -0:]`` selects the whole axis rather than nothing, so
        # the wrap-around padding has to be skipped when n_adj is zero.
        if self.n_adj != 0:
            input = torch.cat(
                [input[..., -self.n_adj:], input, input[..., :self.n_adj]],
                dim=-1,  # same axis the ``...`` slices above address
            )
        return self.fc(input)
class DilatedCircConv(nn.Module):
    """Dilated 1-D convolution over a circular (wrap-around) sequence.

    The input is wrapped by ``n_adj * dilation`` elements on each side of its
    last axis and then passed through an ``nn.Conv1d`` with kernel size
    ``2 * n_adj + 1`` and the given dilation.

    Args:
        state_dim: Number of input channels.
        out_state_dim: Number of output channels; defaults to ``state_dim``.
        n_adj: Number of (dilated) neighbours taken on each side.
        dilation: Spacing between kernel taps.
    """

    def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
        super().__init__()

        self.n_adj = n_adj
        self.dilation = dilation
        if out_state_dim is None:
            out_state_dim = state_dim
        self.fc = nn.Conv1d(
            state_dim,
            out_state_dim,
            kernel_size=2 * self.n_adj + 1,
            dilation=self.dilation,
        )

    def forward(self, input, adj):
        """Wrap the sequence axis and convolve; ``adj`` is not used."""
        # With no neighbours there is nothing to wrap around.
        if self.n_adj == 0:
            return self.fc(input)

        wrap = self.n_adj * self.dilation
        tail = input[..., -wrap:]
        head = input[..., :wrap]
        return self.fc(torch.cat([tail, input, head], dim=2))
# Maps a ``conv_type`` string to the convolution class that implements it.
_conv_factory = dict(
    grid=CircConv,
    dgrid=DilatedCircConv,
)
class BasicBlock(nn.Module):
    """Convolution followed by ReLU and ``BatchNorm1d``.

    Args:
        state_dim: Number of input channels.
        out_state_dim: Number of output channels.
        conv_type: Key into ``_conv_factory`` selecting the convolution class.
            The selected class is called positionally with
            ``(state_dim, out_state_dim, n_adj, dilation)``.
        n_adj: Forwarded to the convolution class.
        dilation: Forwarded to the convolution class.
    """

    def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
        super().__init__()

        conv_cls = _conv_factory[conv_type]
        self.conv = conv_cls(state_dim, out_state_dim, n_adj, dilation)
        self.relu = nn.ReLU(inplace=True)
        self.norm = nn.BatchNorm1d(out_state_dim)

    def forward(self, x, adj=None):
        """Run conv -> ReLU -> batch norm and return the result."""
        activated = self.relu(self.conv(x, adj))
        return self.norm(activated)
class DeepSnake(nn.Module):
    """Stack of circular-convolution blocks with a fused global feature.

    The network consists of a head ``BasicBlock``, seven residual
    ``BasicBlock`` layers (dilations 1, 1, 1, 2, 2, 4, 4), a 1x1 fusion
    convolution whose output is max-pooled over the sequence axis, and a
    1x1-convolution prediction head ending in 2 output channels.

    Args:
        state_dim: Channel width of the head and residual blocks.
        feature_dim: Number of channels of the input features.
        conv_type: Key of the convolution class used by every block.
    """

    def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
        super().__init__()

        self.head = BasicBlock(feature_dim, state_dim, conv_type)

        self.res_layer_num = 7
        dilations = (1, 1, 1, 2, 2, 4, 4)
        for idx in range(self.res_layer_num):
            block = BasicBlock(
                state_dim, state_dim, conv_type, n_adj=4, dilation=dilations[idx]
            )
            # Registered as res0 .. res6 so the submodule names stay stable.
            setattr(self, 'res{}'.format(idx), block)

        fusion_state_dim = 256
        # Channels obtained by concatenating the head output with every
        # residual layer output.
        stacked_dim = state_dim * (self.res_layer_num + 1)
        self.fusion = nn.Conv1d(stacked_dim, fusion_state_dim, 1)
        self.prediction = nn.Sequential(
            nn.Conv1d(stacked_dim + fusion_state_dim, 256, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 64, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 2, 1),
        )

    def forward(self, x, adj):
        """Compute the 2-channel prediction for every sequence position.

        Args:
            x: Input features; channels on dim 1 and the sequence on dim 2
                (presumably ``(batch, feature_dim, length)`` - confirm with
                the caller).
            adj: Passed through to every block.

        Returns:
            Output of the prediction head (2 channels on dim 1).
        """
        x = self.head(x, adj)
        states = [x]
        for idx in range(self.res_layer_num):
            res_block = getattr(self, 'res{}'.format(idx))
            x = res_block(x, adj) + x
            states.append(x)

        state = torch.cat(states, dim=1)

        # Max over the sequence axis gives one global vector per sample,
        # which is then broadcast back to every position.
        fused = self.fusion(state)
        global_state = fused.max(dim=2, keepdim=True)[0]
        global_state = global_state.expand(-1, -1, state.size(2))

        combined = torch.cat([global_state, state], dim=1)
        return self.prediction(combined)