import string
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import cuda, nn
from torch.nn import functional as F

warnings.filterwarnings("ignore")
device = 'cuda' if cuda.is_available() else 'cpu'
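
# Additive attention: score each timestep as tanh(x @ W + b) @ u, then
# softmax-normalize the scores over the sequence dimension.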
class Attention(nn.Module):
    def __init__(self, in_features, *args, bias=True, **kwargs):
        super(Attention, self).__init__(*args, **kwargs)
        self.bias = bias
        self.W = nn.Parameter(torch.randn(in_features, in_features))
        if self.bias:
            self.b = nn.Parameter(torch.randn(in_features))
        self.u = nn.Parameter(torch.randn(in_features))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        uit = torch.matmul(x, self.W)
        if self.bias:
            uit += self.b
        ait = torch.matmul(self.tanh(uit), self.u)
        attention = self.softmax(ait)
        return attention
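
# Character-level classifier: embedding -> 2-layer BiLSTM -> attention pooling
# -> fully connected head that outputs two logits (female / male).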
class GenderClassifierWithAttention(nn.Module):
    def __init__(self, input_size, hidden_size, embed_size, *args, **kwargs):
        super(GenderClassifierWithAttention, self).__init__(*args, **kwargs)
        self.hid_dim = hidden_size
        # embedding layer
        self.embedding = nn.Embedding(
            num_embeddings=input_size, embedding_dim=embed_size)
        # attention layer
        self.attention = Attention(self.hid_dim * 2)
        # LSTM layer
        self.lstm = nn.LSTM(embed_size, hidden_size=self.hid_dim,
                            num_layers=2, bidirectional=True, batch_first=True)
        # dropout layers
        self.dropout1 = nn.Dropout(p=0.4)
        self.dropout2 = nn.Dropout(p=0.4)
        # normalization layers
        self.batch_norm1 = nn.BatchNorm1d(num_features=15)
        self.batch_norm2 = nn.BatchNorm1d(num_features=7)
        self.layer_norm = nn.LayerNorm(40)
        # linear layers
        self.fc1 = nn.Linear(in_features=self.hid_dim * 2, out_features=15)
        self.fc2 = nn.Linear(in_features=15, out_features=7)
        self.fc3 = nn.Linear(in_features=7, out_features=2)
        # activation functions
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.embedding(x).float()
        out, (h, c) = self.lstm(x)
        out = self.layer_norm(out)
        attention = self.attention(self.tanh(out))
        # attention-weighted sum of the LSTM outputs over the time dimension
        x = torch.einsum('ijk,ij->ik', out, attention)
        x = self.tanh(self.fc1(x))
        x = self.batch_norm1(x)
        x = self.dropout1(x)
        x = self.tanh(self.fc2(x))
        x = self.batch_norm2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x, attention
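
# Wraps the loaded model in a closure for Gradio: maps a name to character
# indices, runs the model, and plots the per-character attention weights.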
def get_attention(model):
    def inner(name):
        # character -> index mapping; index 0 is reserved for "<PAD>"
        char_stoi = {val: key for key, val in dict(
            enumerate(["<PAD>"] + list(string.ascii_lowercase))).items()}
        name_mapped = [char_stoi[char] for char in name.lower()]
        with torch.no_grad():
            probs, attention = model(torch.tensor([name_mapped]).to(device))
        probs = F.softmax(probs, dim=-1)
        probs_dict = dict(
            zip(["female", "male"], probs.cpu().detach().numpy().flatten().tolist()))
        # bar plot of the attention weights, one bar per character of the name
        fig, ax = plt.subplots(nrows=1, ncols=1)
        _ = sns.barplot(x=list(range(attention.shape[1])),
                        y=attention.squeeze().cpu().detach().numpy(),
                        color="#1FCECB", ax=ax)
        _ = ax.set_xticklabels(list(name))
        return probs_dict, fig
    return inner
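
# Entry point: load the pickled model and serve the demo as a Gradio Interface.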
if __name__ == "__main__":
    model = torch.load("lstm_attention.pt", map_location=device)
    model.eval()  # inference mode: disable dropout, use batch-norm running stats
    interface = gr.Interface(get_attention(model),
                             inputs="text",
                             outputs=[gr.outputs.Label(label="gender"),
                                      gr.outputs.Image(type="plot", label="attention")],
                             title="Visualizing Attention in Gender Classification",
                             examples=["annabella", "ewart", "blancha", "bronson"],
                             allow_flagging="never").launch()