from torch import nn from torch.nn import Sequential from hw_asr.base import BaseModel class BaselineModel(BaseModel): def __init__(self, n_feats, n_class, fc_hidden=512, **batch): super().__init__(n_feats, n_class, **batch) self.net = Sequential( # people say it can aproximate any function... nn.Linear(in_features=n_feats, out_features=fc_hidden), nn.ReLU(), nn.Linear(in_features=fc_hidden, out_features=fc_hidden), nn.ReLU(), nn.Linear(in_features=fc_hidden, out_features=n_class) ) def forward(self, spectrogram, **batch): print(self.net(spectrogram.transpose(1, 2)).shape) return {"logits": self.net(spectrogram.transpose(1, 2))} def transform_input_lengths(self, input_lengths): return input_lengths # we don't reduce time dimension here