# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ model.py ]
# Synopsis [ Baseline model for speaker diarization ]
# Author [ Jiatong ]
# Copyright [ Copyright(c), Johns Hopkins University ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import torch
import torch.nn as nn
#########
# MODEL #
#########
class Model(nn.Module):
    """Baseline classifier: an optional LSTM followed by a linear projection.

    When ``rnn_layers`` is greater than zero the input is first passed through
    a batch-first ``nn.LSTM`` and the LSTM outputs are projected to
    ``output_class_num`` values by a linear layer. Otherwise the linear layer
    is applied directly to the input features.
    """

    def __init__(self, input_dim, output_class_num, rnn_layers, hidden_size, **kwargs):
        """Build the (optional) LSTM and the output projection.

        Args:
            input_dim: Size of the last dimension of the input features.
            output_class_num: Output size of the final linear layer.
            rnn_layers: Number of stacked LSTM layers. A value of zero or
                less disables the LSTM entirely.
            hidden_size: Hidden size of the LSTM. Unused when the LSTM is
                disabled.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()

        self.use_rnn = rnn_layers > 0

        # The projection consumes LSTM outputs when the LSTM is present and
        # raw features otherwise, so its input width depends on the mode.
        projection_in = hidden_size if self.use_rnn else input_dim

        # Build the LSTM before the linear layer so module registration order
        # (and seeded parameter initialisation) stays rnn -> linear.
        if self.use_rnn:
            self.rnn = nn.LSTM(
                input_dim,
                hidden_size,
                num_layers=rnn_layers,
                batch_first=True,
            )
        self.linear = nn.Linear(projection_in, output_class_num)

    def forward(self, features):
        """Project ``features`` to per-position outputs of size ``output_class_num``.

        Args:
            features: Input tensor whose last dimension is ``input_dim``.
                Presumably shaped (batch, time, input_dim) given the
                batch-first LSTM; confirm against the caller.

        Returns:
            Tensor with the same leading dimensions as ``features`` and a
            last dimension of ``output_class_num``.
        """
        # Cast to float32 so integer or double inputs match the layer weights.
        features = features.float()

        if not self.use_rnn:
            return self.linear(features)

        # Only the per-step outputs are needed; the final (h, c) state is dropped.
        hidden, _ = self.rnn(features)
        return self.linear(hidden)