import torch
import torch.nn as nn


class Model(nn.Module):
    """Framewise classifier: an optional LSTM encoder followed by a linear prediction head."""

    def __init__(self, input_dim, output_class_num, rnn_layers, hidden_size, **kwargs):
        super().__init__()

        # With rnn_layers == 0 the model reduces to a per-frame linear classifier.
        self.use_rnn = rnn_layers > 0
        if self.use_rnn:
            self.rnn = nn.LSTM(
                input_dim, hidden_size, num_layers=rnn_layers, batch_first=True
            )
            self.linear = nn.Linear(hidden_size, output_class_num)
        else:
            self.linear = nn.Linear(input_dim, output_class_num)

    def forward(self, features):
        # features: (batch_size, seq_len, input_dim)
        features = features.float()
        if self.use_rnn:
            # nn.LSTM returns (output, (h_n, c_n)); keep the full per-timestep
            # output sequence so the head yields one prediction per frame.
            output, _ = self.rnn(features)
            predicted = self.linear(output)
        else:
            predicted = self.linear(features)
        # predicted: (batch_size, seq_len, output_class_num)
        return predicted
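

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; the batch size, sequence
# length, and hyperparameters below are illustrative assumptions): run a random
# batch through both configurations to check the framewise output shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, input_dim = 4, 100, 80
    features = torch.randn(batch_size, seq_len, input_dim)

    # LSTM variant: encoder outputs feed a per-timestep linear head.
    lstm_model = Model(input_dim=input_dim, output_class_num=10, rnn_layers=2, hidden_size=256)
    assert lstm_model(features).shape == (batch_size, seq_len, 10)

    # rnn_layers=0 disables the LSTM, leaving a per-frame linear classifier.
    linear_model = Model(input_dim=input_dim, output_class_num=10, rnn_layers=0, hidden_size=256)
    assert linear_model(features).shape == (batch_size, seq_len, 10)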