|
from torch import nn |
|
from abc import abstractmethod |
|
|
|
import torch |
|
|
|
class FBankResBlock(nn.Module): |
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1): |
|
super().__init__() |
|
padding = (kernel_size - 1) // 2 |
|
self.network = nn.Sequential( |
|
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride), |
|
nn.BatchNorm2d(in_channels), |
|
nn.ReLU(), |
|
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride), |
|
nn.BatchNorm2d(out_channels) |
|
) |
|
self.relu = nn.ReLU() |
|
|
|
def forward(self, x): |
|
out = self.network(x) |
|
out = out + x |
|
out = self.relu(out) |
|
return out |
|
class FBankNet(nn.Module): |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.network = nn.Sequential( |
|
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=(5 - 1)//2, stride=2), |
|
FBankResBlock(in_channels=32, out_channels=32, kernel_size=3), |
|
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=(5 - 1)//2, stride=2), |
|
FBankResBlock(in_channels=64, out_channels=64, kernel_size=3), |
|
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=(5 - 1) // 2, stride=2), |
|
FBankResBlock(in_channels=128, out_channels=128, kernel_size=3), |
|
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, padding=(5 - 1) // 2, stride=2), |
|
FBankResBlock(in_channels=256, out_channels=256, kernel_size=3), |
|
nn.AvgPool2d(kernel_size=4) |
|
) |
|
self.linear_layer = nn.Sequential( |
|
nn.Linear(256, 250) |
|
) |
|
|
|
@abstractmethod |
|
def forward(self, *input_): |
|
raise NotImplementedError('Call one of the subclasses of this class') |
|
|
|
|
|
class FBankCrossEntropyNet(FBankNet): |
|
def __init__(self, reduction='mean'): |
|
super().__init__() |
|
self.loss_layer = nn.CrossEntropyLoss(reduction=reduction) |
|
|
|
def forward(self, x): |
|
n = x.shape[0] |
|
out = self.network(x) |
|
out = out.reshape(n, -1) |
|
out = self.linear_layer(out) |
|
return out |
|
|
|
|
|
def loss(self, predictions, labels): |
|
loss_val = self.loss_layer(predictions, labels) |
|
return loss_val |
|
|
|
class FBankNetV2(nn.Module): |
|
def __init__(self, num_layers=4, embedding_size = 250): |
|
super().__init__() |
|
layers = [] |
|
in_channels = 1 |
|
out_channels = 32 |
|
|
|
for i in range(num_layers): |
|
|
|
|
|
layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, padding=(5 - 1) // 2, stride=2)) |
|
layers.append(FBankResBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3)) |
|
if i < num_layers - 1 : |
|
in_channels = out_channels |
|
out_channels *= 2 |
|
|
|
|
|
layers.append(nn.AdaptiveAvgPool2d(output_size=(1,1))) |
|
self.network = nn.Sequential(*layers) |
|
self.linear_layer = nn.Sequential( |
|
nn.Linear(in_features=out_channels, out_features=embedding_size) |
|
) |
|
|
|
@abstractmethod |
|
def forward(self, *input_): |
|
raise NotImplementedError('Call one of the subclasses of this class') |
|
|
|
|
|
|
|
|
|
|
|
class FBankCrossEntropyNetV2(FBankNetV2): |
|
def __init__(self, num_layers=3, reduction='mean'): |
|
super().__init__(num_layers=num_layers) |
|
self.loss_layer = nn.CrossEntropyLoss(reduction=reduction) |
|
|
|
def forward(self, x): |
|
n = x.shape[0] |
|
out = self.network(x) |
|
out = out.reshape(n, -1) |
|
out = self.linear_layer(out) |
|
return out |
|
|
|
def loss(self, predictions, labels): |
|
loss_val = self.loss_layer(predictions, labels) |
|
return loss_val |
|
|
|
def main(): |
|
num_layers = 1 |
|
model = FBankCrossEntropyNetV2(num_layers = num_layers, reduction='mean') |
|
print(model) |
|
input_data = torch.randn(8, 1, 64, 64) |
|
|
|
output = model(input_data) |
|
|
|
print("Output shape:", output.shape) |
|
labels = torch.randint(0, 250, (8,)) |
|
|
|
loss = model.loss(output, labels) |
|
|
|
print("Loss:", loss.item()) |
|
|
|
if __name__ == "__main__": |
|
main() |