sahiba12 committed
Commit f13d1d6 · 1 Parent(s): d495eaa

Delete app.py

Files changed (1)
  1. app.py +0 -119
app.py DELETED
@@ -1,119 +0,0 @@
- import warnings
-
- import gradio as gr
- import matplotlib.pyplot as plt
- import seaborn as sns
- import torch
- from torch import cuda, nn
- from torch.nn import functional as F
-
- warnings.filterwarnings("ignore")
-
- device = 'cuda' if cuda.is_available() else 'cpu'
-
-
- class Attention(nn.Module):
-     def __init__(self, in_features, *args, bias=True, **kwargs):
-         super(Attention, self).__init__(*args, **kwargs)
-         self.bias = bias
-         self.W = nn.Parameter(torch.randn(in_features, in_features))
-         if self.bias:
-             self.b = nn.Parameter(torch.randn(in_features))
-         self.u = nn.Parameter(torch.randn(in_features))
-         self.tanh = nn.Tanh()
-         self.softmax = nn.Softmax(dim=-1)
-
-     def forward(self, x):
-         uit = torch.matmul(x, self.W)
-         if self.bias:
-             uit += self.b
-         ait = torch.matmul(self.tanh(uit), self.u)
-         attention = self.softmax(ait)
-         return attention
-
-
- class GenderClassifierWithAttention(nn.Module):
-     def __init__(self, input_size, hidden_size, embed_size, *args, **kwargs):
-         super(GenderClassifierWithAttention, self).__init__(*args, **kwargs)
-         self.hid_dim = hidden_size
-
-         # embedding layer
-         self.embedding = nn.Embedding(
-             num_embeddings=input_size, embedding_dim=embed_size)
-
-         # attention layer
-         self.attention = Attention(self.hid_dim * 2)
-
-         # LSTM layer
-         self.lstm = nn.LSTM(embed_size, hidden_size=self.hid_dim,
-                             num_layers=2, bidirectional=True, batch_first=True)
-
-         # dropout layers
-         self.dropout1 = nn.Dropout(p=0.4)
-         self.dropout2 = nn.Dropout(p=0.4)
-
-         # normalization layers
-         self.batch_norm1 = nn.BatchNorm1d(num_features=15)
-         self.batch_norm2 = nn.BatchNorm1d(num_features=7)
-         self.layer_norm = nn.LayerNorm(40)
-
-         # linear layers
-         self.fc1 = nn.Linear(in_features=self.hid_dim * 2, out_features=15)
-         self.fc2 = nn.Linear(in_features=15, out_features=7)
-         self.fc3 = nn.Linear(in_features=7, out_features=2)
-
-         # activation functions
-         self.tanh = nn.Tanh()
-         self.relu = nn.ReLU()
-
-     def forward(self, x):
-         x = self.embedding(x).float()
-
-         out, (h, c) = self.lstm(x)
-         out = self.layer_norm(out)
-         attention = self.attention(self.tanh(out))
-         x = torch.einsum('ijk,ij->ik', out, attention)
-
-         x = self.tanh(self.fc1(x))
-         x = self.batch_norm1(x)
-         x = self.dropout1(x)
-
-         x = self.tanh(self.fc2(x))
-         x = self.batch_norm2(x)
-         x = self.dropout2(x)
-
-         x = self.fc3(x)
-         return x, attention
-
-
- def get_attention(model):
-     def inner(name):
-         import string
-         char_stoi = {val: key for key, val in dict(
-             enumerate(["<PAD>"] + list(string.ascii_lowercase))).items()}
-
-         name_mapped = [char_stoi[char] for char in name.lower()]
-
-         probs, attention = model(torch.tensor([name_mapped]).to(device))
-         probs = F.softmax(probs, dim=-1)
-         probs_dict = dict(
-             zip(["female", "male"], probs.cpu().detach().numpy().flatten().tolist()))
-
-         fig, ax = plt.subplots(nrows=1, ncols=1)
-         _ = sns.barplot(x=list(range(attention.shape[1])),
-                         y=attention.squeeze().cpu().detach().numpy(),
-                         color="#1FCECB", ax=ax)
-         _ = ax.set_xticklabels(list(name))
-
-         return probs_dict, fig
-
-     return inner
-
-
- if __name__ == "__main__":
-     model = torch.load("lstm_attention.pt", map_location=device)
-     interface = gr.Interface(get_attention(model),
-                              inputs="text", outputs=[gr.outputs.Label(label="gender"), gr.outputs.Image(type="plot", label="attention")],
-                              title="Visualizing Attention in Gender Classification",
-                              examples=["annabella", "ewart",
-                                        "blancha", "bronson"], allow_flagging="never").launch()