Spaces:
Runtime error
Runtime error
################################################################### | |
# File Name: RNN.py | |
# Author: S.X.Zhang | |
################################################################### | |
from __future__ import print_function | |
from __future__ import division | |
from __future__ import absolute_import | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import init | |
class RNN(nn.Module):
    """Bidirectional-LSTM sequence model with a per-step two-channel head.

    The input is batch-normalised over its feature channels, passed through a
    bidirectional LSTM along the last (sequence) axis, and then projected at
    every sequence step by a stack of kernel-size-1 ``Conv1d`` layers (a
    position-wise MLP) down to two output channels.

    Args:
        input: number of feature channels in the input; used both as the
            ``BatchNorm1d`` feature count and as the LSTM ``input_size``.
            (The name shadows the builtin but is kept so keyword callers
            keep working.)
        state_dim: hidden size of each LSTM direction.
        num_layers: number of stacked LSTM layers. Defaults to 1, matching
            the original single-layer model.
        dropout: dropout probability applied *between* stacked LSTM layers.
            PyTorch's LSTM never applies dropout after the last layer, so
            with a single layer this value would be a silent no-op that also
            emits a ``UserWarning``. It is therefore forwarded only when
            ``num_layers > 1``.
    """

    def __init__(self, input, state_dim, num_layers=1, dropout=0.1):
        super(RNN, self).__init__()
        # affine=False: normalise only, with no learnable scale/shift.
        self.bn0 = nn.BatchNorm1d(input, affine=False)
        self.rnn = nn.LSTM(
            input,
            state_dim,
            num_layers,
            # Inter-layer dropout is meaningless (and warns) for one layer.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Input channels are state_dim * 2 because the LSTM is bidirectional
        # (forward and backward hidden states are concatenated).
        self.prediction = nn.Sequential(
            nn.Conv1d(state_dim * 2, 128, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 64, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 2, 1),
        )

    def forward(self, x, adj):
        """Return per-step predictions for ``x``.

        Args:
            x: 3-D tensor laid out as ``(batch, input, seq_len)``.
            adj: accepted but not used by this model (presumably kept for a
                shared model interface; confirm against callers).

        Returns:
            Tensor of shape ``(batch, 2, seq_len)``.
        """
        x = self.bn0(x)
        # (batch, channels, seq) -> (seq, batch, channels): the LSTM is not
        # batch_first.
        x = x.permute(2, 0, 1)
        x, _ = self.rnn(x)
        # (seq, batch, 2*state_dim) -> (batch, 2*state_dim, seq) for Conv1d.
        x = x.permute(1, 2, 0)
        return self.prediction(x)