speaker_identify / models /cross_entropy_model.py
DuyTa's picture
Upload folder using huggingface_hub
f831146 verified
from torch import nn
from abc import abstractmethod
import torch
class FBankResBlock(nn.Module):
    """Residual block: conv -> batch norm -> ReLU -> conv -> batch norm, plus skip.

    The block input is added to the output of the convolutional branch and the
    sum is passed through a final ReLU. The skip connection is a plain
    addition, so the branch output must be broadcast-compatible with the
    input. With ``in_channels == out_channels``, ``stride=1`` and an odd
    ``kernel_size`` the two shapes are identical.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Side length of the square convolution kernel. Padding is
            derived from it as ``(kernel_size - 1) // 2``.
        stride: Stride passed to both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            # Normalizes the output of the first convolution, which has
            # ``out_channels`` channels (not ``in_channels``).
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(network(x) + x)``."""
        out = self.network(x)
        out = out + x
        out = self.relu(out)
        return out
class FBankNet(nn.Module):
    """Base network: four stride-2 convolutions, each followed by a residual block.

    Channel widths go 1 -> 32 -> 64 -> 128 -> 256. The stack ends with a
    4x4 average pool, and ``linear_layer`` maps 256 features to 250 outputs.
    ``forward`` is not implemented here; it raises ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()
        widths = (1, 32, 64, 128, 256)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            # Stride-2 convolution, then a residual block at the new width.
            stages.append(
                nn.Conv2d(
                    in_channels=n_in,
                    out_channels=n_out,
                    kernel_size=5,
                    padding=(5 - 1) // 2,
                    stride=2,
                )
            )
            stages.append(
                FBankResBlock(in_channels=n_out, out_channels=n_out, kernel_size=3)
            )
        stages.append(nn.AvgPool2d(kernel_size=4))
        self.network = nn.Sequential(*stages)

        self.linear_layer = nn.Sequential(nn.Linear(256, 250))

    @abstractmethod
    def forward(self, *input_):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError('Call one of the subclasses of this class')
class FBankCrossEntropyNet(FBankNet):
    """``FBankNet`` with a forward pass and a cross-entropy loss.

    Args:
        reduction: Passed through to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_layer = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, x):
        """Run ``network``, flatten per sample, and apply ``linear_layer``."""
        features = self.network(x)
        # Collapse everything except the leading (batch) dimension.
        flat = features.reshape(x.shape[0], -1)
        return self.linear_layer(flat)

    def loss(self, predictions, labels):
        """Return the cross-entropy loss of ``predictions`` against ``labels``."""
        return self.loss_layer(predictions, labels)
class FBankNetV2(nn.Module):
    """Configurable-depth base network of stride-2 convolutions and residual blocks.

    Each of the ``num_layers`` stages is a 5x5 stride-2 convolution followed by
    an ``FBankResBlock`` at the same width. The first stage maps 1 channel to
    32, and every later stage doubles the width. An adaptive average pool
    reduces the spatial size to 1x1, and ``linear_layer`` maps the final width
    to ``embedding_size``. ``forward`` is not implemented here; it raises
    ``NotImplementedError``.

    Args:
        num_layers: Number of conv + residual stages.
        embedding_size: Output feature count of ``linear_layer``.
    """

    def __init__(self, num_layers=4, embedding_size = 250):
        super().__init__()
        stages = []
        prev_width = 1
        width = 32
        for depth in range(num_layers):
            if depth > 0:
                # Every stage after the first doubles the channel width.
                prev_width, width = width, width * 2
            stages.append(
                nn.Conv2d(
                    in_channels=prev_width,
                    out_channels=width,
                    kernel_size=5,
                    padding=(5 - 1) // 2,
                    stride=2,
                )
            )
            stages.append(
                FBankResBlock(in_channels=width, out_channels=width, kernel_size=3)
            )
        stages.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.network = nn.Sequential(*stages)

        # ``width`` now holds the channel count of the last stage
        # (or the initial 32 if no stage was built).
        self.linear_layer = nn.Sequential(
            nn.Linear(in_features=width, out_features=embedding_size)
        )

    @abstractmethod
    def forward(self, *input_):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError('Call one of the subclasses of this class')
class FBankCrossEntropyNetV2(FBankNetV2):
    """``FBankNetV2`` with a forward pass and a cross-entropy loss.

    Args:
        num_layers: Number of conv + residual stages, forwarded to
            ``FBankNetV2``.
        reduction: Passed through to ``nn.CrossEntropyLoss``.
        embedding_size: Output feature count of the final linear layer,
            forwarded to ``FBankNetV2``. Defaults to 250, the value the base
            class used when this argument was not exposed here.
    """

    def __init__(self, num_layers=3, reduction='mean', embedding_size=250):
        super().__init__(num_layers=num_layers, embedding_size=embedding_size)
        self.loss_layer = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, x):
        """Run ``network``, flatten per sample, and apply ``linear_layer``."""
        n = x.shape[0]
        out = self.network(x)
        # Collapse everything except the leading (batch) dimension.
        out = out.reshape(n, -1)
        out = self.linear_layer(out)
        return out

    def loss(self, predictions, labels):
        """Return the cross-entropy loss of ``predictions`` against ``labels``."""
        loss_val = self.loss_layer(predictions, labels)
        return loss_val
def main():
    """Smoke test: build a one-stage model, run a random batch, print shape and loss."""
    net = FBankCrossEntropyNetV2(num_layers=1, reduction='mean')
    print(net)

    batch = torch.randn(8, 1, 64, 64)
    logits = net(batch)
    print("Output shape:", logits.shape)

    # Random integer targets in [0, 250), one per sample in the batch.
    targets = torch.randint(0, 250, (8,))
    loss_value = net.loss(logits, targets)
    print("Loss:", loss_value.item())


if __name__ == "__main__":
    main()