File size: 2,227 Bytes
74b29ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import pytorch_lightning as pl

from torchmetrics.classification import BinaryAccuracy

class BiLSTM(pl.LightningModule):
    """Bidirectional LSTM followed by dropout and a linear output layer.

    The module is trained with ``nn.BCEWithLogitsLoss`` and reports
    ``BinaryAccuracy`` for the train, validation and test stages.

    Args:
        lr: Learning rate handed to the Adam optimizer.
        num_classes: Output width of the final linear layer.
        input_size: Feature size of each LSTM input element.
        hidden_size: Hidden size of the LSTM, per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the LSTM output before the
            linear layer (it is not passed to ``nn.LSTM`` itself).
    """

    def __init__(
        self,
        lr: float,
        num_classes: int,
        input_size: int,
        hidden_size: int = 300,
        num_layers: int = 2,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # The LSTM is bidirectional, so its output feature size is 2 * hidden_size.
        self.output_layer = nn.Linear(hidden_size * 2, num_classes)
        self.criterion = nn.BCEWithLogitsLoss()
        # NOTE(review): one metric instance is shared by all three stages; only
        # the per-batch value returned by calling it is logged below.
        self.accuracy_metric = BinaryAccuracy()
        self.lr = lr
        self.sigmoid = nn.Sigmoid()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return raw logits for ``X``.

        The linear layer is applied to the full LSTM output (every position the
        LSTM emits), not only to a final hidden state.
        """
        lstm_output, _ = self.lstm(X)
        return self.output_layer(self.dropout(lstm_output))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        """Compute loss and accuracy for one batch and log them under ``stage``.

        Args:
            batch: An ``(X, target)`` pair as produced by the dataloader.
            stage: Prefix for the logged keys (``'train'``, ``'val'`` or ``'test'``).

        Returns:
            The BCE-with-logits loss for the batch.
        """
        X, target = batch

        # squeeze(1) drops dimension 1 only when it has size 1.
        # NOTE(review): presumably this relies on num_classes == 1 (or a
        # length-1 sequence axis) so logits line up with target — TODO confirm.
        logits = self(X).squeeze(1)
        loss = self.criterion(logits, target.float())

        # The loss consumes raw logits; the metric is fed probabilities.
        probabilities = self.sigmoid(logits)
        accuracy = self.accuracy_metric(probabilities, target)

        self.log_dict(
            {f'{stage}_loss': loss, f'{stage}_accuracy': accuracy},
            prog_bar=True,
            on_epoch=True,
        )

        return loss

    def training_step(self, train_batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_accuracy``."""
        return self._shared_step(train_batch, 'train')

    def validation_step(self, valid_batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_accuracy``."""
        return self._shared_step(valid_batch, 'val')

    def test_step(self, test_batch, batch_idx):
        """Run one test batch; logs ``test_loss`` and ``test_accuracy``."""
        return self._shared_step(test_batch, 'test')