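"""Gradio demo that visualizes character-level attention in a gender-from-name classifier.

Loads a pickled bidirectional-LSTM-with-attention model ("lstm_attention.pt")
and serves a small web UI: given a first name, it shows the predicted
female/male probabilities and a bar plot of the attention weight assigned to
each character of the name.
"""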
import string
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import cuda, nn
from torch.nn import functional as F

warnings.filterwarnings("ignore")

device = 'cuda' if cuda.is_available() else 'cpu'


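# Attention scorer used below: for a sequence of feature vectors x, scores are
# computed per time step as u . tanh(x @ W + b) and normalized with a softmax
# over the sequence, giving one weight per input character.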
class Attention(nn.Module):
    def __init__(self, in_features, *args, bias=True, **kwargs):
        super(Attention, self).__init__(*args, **kwargs)
        self.bias = bias
        self.W = nn.Parameter(torch.randn(in_features, in_features))
        if self.bias:
            self.b = nn.Parameter(torch.randn(in_features))
        self.u = nn.Parameter(torch.randn(in_features))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        uit = torch.matmul(x, self.W)
        if self.bias:
            uit += self.b
        ait = torch.matmul(self.tanh(uit), self.u)
        attention = self.softmax(ait)
        return attention


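# Name classifier with attention pooling: character embeddings feed a 2-layer
# bidirectional LSTM; the LSTM outputs are layer-normalized, scored by the
# Attention module, summed with those weights, and passed through three fully
# connected layers (with BatchNorm and Dropout) to produce 2 logits. The
# attention weights are returned alongside the logits for visualization.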
class GenderClassifierWithAttention(nn.Module):
    def __init__(self, input_size, hidden_size, embed_size, *args, **kwargs):
        super(GenderClassifierWithAttention, self).__init__(*args, **kwargs)
        self.hid_dim = hidden_size

        self.embedding = nn.Embedding(
            num_embeddings=input_size, embedding_dim=embed_size)

        # The bidirectional LSTM outputs hid_dim * 2 features per time step.
        self.attention = Attention(self.hid_dim * 2)

        self.lstm = nn.LSTM(embed_size, hidden_size=self.hid_dim,
                            num_layers=2, bidirectional=True, batch_first=True)

        self.dropout1 = nn.Dropout(p=0.4)
        self.dropout2 = nn.Dropout(p=0.4)

        self.batch_norm1 = nn.BatchNorm1d(num_features=15)
        self.batch_norm2 = nn.BatchNorm1d(num_features=7)
        self.layer_norm = nn.LayerNorm(self.hid_dim * 2)

        self.fc1 = nn.Linear(in_features=self.hid_dim * 2, out_features=15)
        self.fc2 = nn.Linear(in_features=15, out_features=7)
        self.fc3 = nn.Linear(in_features=7, out_features=2)

        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.embedding(x).float()

        out, (h, c) = self.lstm(x)
        out = self.layer_norm(out)
        attention = self.attention(self.tanh(out))
        # Attention-weighted sum over the sequence dimension:
        # (batch, seq, feat) x (batch, seq) -> (batch, feat).
        x = torch.einsum('ijk,ij->ik', out, attention)

        x = self.tanh(self.fc1(x))
        x = self.batch_norm1(x)
        x = self.dropout1(x)

        x = self.tanh(self.fc2(x))
        x = self.batch_norm2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        return x, attention


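# Returns a closure suitable as a Gradio handler: it maps the input name to
# character indices (vocabulary: "<PAD>" plus lowercase a-z), runs the model,
# and returns the softmaxed class probabilities together with a bar plot of
# the attention weight assigned to each character.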
def get_attention(model):
    def inner(name):
        # Only lowercase a-z are mapped; index 0 is reserved for "<PAD>".
        char_stoi = {char: idx for idx, char in enumerate(
            ["<PAD>"] + list(string.ascii_lowercase))}

        name_mapped = [char_stoi[char] for char in name.lower()]

        probs, attention = model(torch.tensor([name_mapped]).to(device))
        probs = F.softmax(probs, dim=-1)
        probs_dict = dict(
            zip(["female", "male"], probs.cpu().detach().numpy().flatten().tolist()))

        fig, ax = plt.subplots(nrows=1, ncols=1)
        _ = sns.barplot(x=list(range(attention.shape[1])),
                        y=attention.squeeze().cpu().detach().numpy(),
                        color="#1FCECB", ax=ax)
        _ = ax.set_xticklabels(list(name))

        return probs_dict, fig

    return inner


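# Note: "lstm_attention.pt" is assumed to be a full pickled model produced by
# torch.save(model, "lstm_attention.pt") after training; its embed_size and
# hidden_size are not recorded in this script. A hypothetical instantiation
# consistent with the 27-token vocabulary used above would be, e.g.:
#
#     model = GenderClassifierWithAttention(input_size=27, hidden_size=20, embed_size=32)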
if __name__ == "__main__":
    model = torch.load("lstm_attention.pt", map_location=device)
    # eval() is required here: Dropout must be disabled and BatchNorm1d must use
    # its running statistics, since inference runs on a single name at a time.
    model.eval()

    # Note: gr.outputs.* is the legacy Gradio output API; recent Gradio releases
    # expose these components as gr.Label and gr.Plot.
    interface = gr.Interface(get_attention(model),
                             inputs="text",
                             outputs=[gr.outputs.Label(label="gender"),
                                      gr.outputs.Image(type="plot", label="attention")],
                             title="Visualizing Attention in Gender Classification",
                             examples=["annabella", "ewart", "blancha", "bronson"],
                             allow_flagging="never")
    interface.launch()