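"""Gradio demo: visualize the per-character attention weights of an
LSTM-based gender classifier for first names (loads lstm_attention.pt)."""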
import string
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import cuda, nn
from torch.nn import functional as F

warnings.filterwarnings("ignore")

device = 'cuda' if cuda.is_available() else 'cpu'


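# Additive attention (in the style of Bahdanau et al. / Yang et al.):
# score each timestep with u_it = tanh(x_t @ W + b), project the scores
# onto a learned context vector u, and softmax over the sequence so the
# weights sum to 1.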
class Attention(nn.Module):
    def __init__(self, in_features, *args, bias=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.bias = bias
        self.W = nn.Parameter(torch.randn(in_features, in_features))
        if self.bias:
            self.b = nn.Parameter(torch.randn(in_features))
        self.u = nn.Parameter(torch.randn(in_features))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (batch, seq_len, in_features)
        uit = torch.matmul(x, self.W)
        if self.bias:
            uit = uit + self.b
        ait = torch.matmul(self.tanh(uit), self.u)  # (batch, seq_len)
        attention = self.softmax(ait)
        return attention


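# Character-level classifier: embeds each character, runs a 2-layer
# bidirectional LSTM, pools the hidden states with attention, and maps
# the pooled context vector through a small MLP to two logits
# (female / male). The attention weights are returned for plotting.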
class GenderClassifierWithAttention(nn.Module):
    def __init__(self, input_size, hidden_size, embed_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hid_dim = hidden_size

        # embedding layer
        self.embedding = nn.Embedding(
            num_embeddings=input_size, embedding_dim=embed_size)

        # attention layer
        self.attention = Attention(self.hid_dim * 2)

        # LSTM layer
        self.lstm = nn.LSTM(embed_size, hidden_size=self.hid_dim,
                            num_layers=2, bidirectional=True, batch_first=True)

        # dropout layers
        self.dropout1 = nn.Dropout(p=0.4)
        self.dropout2 = nn.Dropout(p=0.4)

        # normalization layers
        self.batch_norm1 = nn.BatchNorm1d(num_features=15)
        self.batch_norm2 = nn.BatchNorm1d(num_features=7)
        self.layer_norm = nn.LayerNorm(self.hid_dim * 2)

        # linear layers
        self.fc1 = nn.Linear(in_features=self.hid_dim * 2, out_features=15)
        self.fc2 = nn.Linear(in_features=15, out_features=7)
        self.fc3 = nn.Linear(in_features=7, out_features=2)

        # activation functions
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.embedding(x).float()

        out, (h, c) = self.lstm(x)  # out: (batch, seq_len, hid_dim * 2)
        out = self.layer_norm(out)
        attention = self.attention(self.tanh(out))  # (batch, seq_len)
        # attention-weighted sum of the LSTM states gives one context
        # vector per example: (batch, hid_dim * 2)
        x = torch.einsum('ijk,ij->ik', out, attention)

        x = self.tanh(self.fc1(x))
        x = self.batch_norm1(x)
        x = self.dropout1(x)

        x = self.tanh(self.fc2(x))
        x = self.batch_norm2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        return x, attention

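# A minimal smoke test (the hidden/embedding sizes below are illustrative
# assumptions, not the checkpoint's actual configuration; the 27-token
# vocabulary of "<PAD>" + a-z matches char_stoi in get_attention):
#
#     model = GenderClassifierWithAttention(
#         input_size=27, hidden_size=20, embed_size=16)
#     model.eval()  # BatchNorm1d needs a batch > 1 in training mode
#     logits, attn = model(torch.randint(1, 27, (4, 10)))
#     assert logits.shape == (4, 2) and attn.shape == (4, 10)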

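# get_attention closes over the loaded model so gr.Interface can call a
# plain name -> (label, figure) function.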
def get_attention(model):
    def inner(name):
        # map each character of the lowercased name to its vocabulary index
        vocab = ["<PAD>"] + list(string.ascii_lowercase)
        char_stoi = {char: idx for idx, char in enumerate(vocab)}
        name_mapped = [char_stoi[char] for char in name.lower()]

        with torch.no_grad():
            logits, attention = model(torch.tensor([name_mapped]).to(device))
        probs = F.softmax(logits, dim=-1)
        probs_dict = dict(
            zip(["female", "male"], probs.cpu().numpy().flatten().tolist()))

        # bar chart of the per-character attention weights
        fig, ax = plt.subplots(nrows=1, ncols=1)
        _ = sns.barplot(x=list(range(attention.shape[1])),
                        y=attention.squeeze().cpu().numpy(),
                        color="#1FCECB", ax=ax)
        _ = ax.set_xticklabels(list(name))

        return probs_dict, fig

    return inner


if __name__ == "__main__":
    # the checkpoint stores the full module, so weights_only=False is
    # needed on torch >= 2.6, where weights_only defaults to True
    model = torch.load("lstm_attention.pt",
                       map_location=device, weights_only=False)
    model.eval()  # BatchNorm1d would fail on a batch of one in train mode

    # gr.outputs was removed in Gradio 4; use the component classes directly
    interface = gr.Interface(get_attention(model),
                             inputs="text",
                             outputs=[gr.Label(label="gender"),
                                      gr.Plot(label="attention")],
                             title="Visualizing Attention in Gender Classification",
                             examples=["annabella", "ewart",
                                       "blancha", "bronson"],
                             allow_flagging="never")
    interface.launch()