# File size: 4,421 Bytes
# f831146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from torch import nn
from abc import abstractmethod

import torch

class FBankResBlock(nn.Module):
    """Residual block: two convolutions with batch norm plus a skip connection.

    The block computes ``relu(network(x) + x)``, where ``network`` is
    conv -> batch norm -> ReLU -> conv -> batch norm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Side length of the square convolution kernels. Padding is
            derived from it as ``(kernel_size - 1) // 2``.
        stride: Stride passed to *both* convolutions.

    Note:
        The skip connection adds the untouched input to the convolution
        output, and no projection is applied to the input. The addition
        therefore only has matching shapes when ``in_channels ==
        out_channels``, ``stride == 1`` and ``kernel_size`` is odd; other
        configurations are likely to fail (or silently broadcast) at the
        addition in ``forward``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
        super().__init__()
        # "Same"-style padding for odd kernels: spatial size is preserved
        # when stride is 1.
        padding = (kernel_size - 1) // 2
        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            # This layer normalises the output of the first convolution, so it
            # must be sized by out_channels (it was previously sized by
            # in_channels, which breaks whenever the two differ).
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolutional branch, add the input back, then ReLU."""
        out = self.network(x)
        out = out + x  # residual (skip) connection
        out = self.relu(out)
        return out
class FBankNet(nn.Module):
    """Base network: four conv + residual stages, average pooling, linear head.

    ``network`` alternates a stride-2 5x5 convolution (which widens the
    channels 1 -> 32 -> 64 -> 128 -> 256) with an ``FBankResBlock`` of the
    same width, and ends with a 4x4 average pool. ``linear_layer`` maps 256
    features to 250 outputs.

    ``forward`` is not implemented here; subclasses provide it.
    """

    def __init__(self):
        super().__init__()
        # (input channels, output channels) of each down-sampling convolution.
        channel_plan = ((1, 32), (32, 64), (64, 128), (128, 256))

        stages = []
        for src_channels, dst_channels in channel_plan:
            stages.append(
                nn.Conv2d(
                    in_channels=src_channels,
                    out_channels=dst_channels,
                    kernel_size=5,
                    padding=2,
                    stride=2,
                )
            )
            stages.append(
                FBankResBlock(
                    in_channels=dst_channels,
                    out_channels=dst_channels,
                    kernel_size=3,
                )
            )
        stages.append(nn.AvgPool2d(kernel_size=4))

        self.network = nn.Sequential(*stages)
        self.linear_layer = nn.Sequential(nn.Linear(256, 250))

    @abstractmethod
    def forward(self, *input_):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError('Call one of the subclasses of this class')


class FBankCrossEntropyNet(FBankNet):
    """``FBankNet`` with a concrete ``forward`` and a cross-entropy loss.

    Args:
        reduction: Passed through to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_layer = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the linear head."""
        batch_size = x.shape[0]
        features = self.network(x).reshape(batch_size, -1)
        return self.linear_layer(features)

    def loss(self, predictions, labels):
        """Return the cross-entropy loss of ``predictions`` against ``labels``."""
        return self.loss_layer(predictions, labels)

class FBankNetV2(nn.Module):
    """Configurable-depth variant of the conv + residual feature network.

    Each of the ``num_layers`` stages is a stride-2 5x5 convolution followed
    by an ``FBankResBlock`` of the same width. The first stage maps 1 channel
    to 32 and every later stage doubles the width. An adaptive average pool
    reduces the feature map to 1x1 before ``linear_layer``.

    Args:
        num_layers: Number of conv + residual stages.
        embedding_size: Output size of the final linear layer.

    ``forward`` is not implemented here; subclasses provide it.
    """

    def __init__(self, num_layers=4, embedding_size = 250):
        super().__init__()
        base_width = 32
        # Output width of every stage: 32, 64, 128, ...
        stage_widths = [base_width * (2 ** idx) for idx in range(num_layers)]

        modules = []
        prev_width = 1
        for width in stage_widths:
            modules.append(
                nn.Conv2d(
                    in_channels=prev_width,
                    out_channels=width,
                    kernel_size=5,
                    padding=2,
                    stride=2,
                )
            )
            modules.append(FBankResBlock(in_channels=width, out_channels=width, kernel_size=3))
            prev_width = width
        modules.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))

        # With no stages at all the head keeps the base width of 32.
        head_features = stage_widths[-1] if stage_widths else base_width

        self.network = nn.Sequential(*modules)
        self.linear_layer = nn.Sequential(
            nn.Linear(in_features=head_features, out_features=embedding_size)
        )

    @abstractmethod
    def forward(self, *input_):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError('Call one of the subclasses of this class')





class FBankCrossEntropyNetV2(FBankNetV2):
    """``FBankNetV2`` with a concrete ``forward`` and a cross-entropy loss.

    Args:
        num_layers: Number of conv + residual stages, forwarded to
            ``FBankNetV2``. Note the default here is 3, not the base
            class default.
        reduction: Passed through to ``nn.CrossEntropyLoss``.
        embedding_size: Output size of the final linear layer, forwarded to
            ``FBankNetV2``. When the output of ``forward`` is fed to
            ``loss``, this is effectively the number of classes. Defaults to
            250, the value the base class used implicitly before this
            parameter was exposed.
    """

    def __init__(self, num_layers=3, reduction='mean', embedding_size=250):
        super().__init__(num_layers=num_layers, embedding_size=embedding_size)
        self.loss_layer = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the linear head."""
        n = x.shape[0]
        out = self.network(x)
        # The adaptive pool leaves a 1x1 map; collapse everything but the
        # batch dimension for the linear layer.
        out = out.reshape(n, -1)
        out = self.linear_layer(out)
        return out

    def loss(self, predictions, labels):
        """Return the cross-entropy loss of ``predictions`` against ``labels``."""
        loss_val = self.loss_layer(predictions, labels)
        return loss_val

def main():
    """Smoke-test ``FBankCrossEntropyNetV2`` on random inputs and labels.

    Builds a one-stage model, prints it, runs a random batch of eight
    single-channel 64x64 inputs through it, and prints the output shape and
    the loss against random labels in ``[0, 250)``.
    """
    batch_size = 8
    num_classes = 250

    net = FBankCrossEntropyNetV2(num_layers=1, reduction='mean')
    print(net)

    fake_batch = torch.randn(batch_size, 1, 64, 64)
    logits = net(fake_batch)
    print("Output shape:", logits.shape)

    targets = torch.randint(0, num_classes, (batch_size,))
    batch_loss = net.loss(logits, targets)
    print("Loss:", batch_loss.item())

# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()