# nlp_lstm_project/model/ltsm_att.py
import json
import time
from string import punctuation

import torch
import torch.nn as nn
import torch.nn.functional as F
from gensim.models import Word2Vec
with open('model/word_dict.json', 'r') as fp:
    vocab_to_int = json.load(fp)
wv = Word2Vec.load("model/word2vec_for_ltsm.model")  # kept from the training pipeline; not referenced below
VOCAB_SIZE = len(vocab_to_int) + 1
HIDDEN_SIZE = 128
SEQ_LEN = 128
DEVICE = 'cpu'
EMBEDDING_DIM = 128
embedding_matrix = torch.load('model/embedding_matrix.pt')
embedding_layer = nn.Embedding.from_pretrained(embedding_matrix.float())
class ConcatAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int = HIDDEN_SIZE
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.align = nn.Linear(hidden_size * 2, hidden_size)
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor   # BATCH_SIZE x HIDDEN_SIZE
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Score every timestep against the final hidden state.
        att_weights = self.linear(lstm_outputs)                                    # B x S x H
        att_weights = torch.bmm(att_weights, final_hidden.unsqueeze(2))            # B x S x 1
        att_weights = F.softmax(att_weights.squeeze(2), dim=1)                     # B x S
        # Context vector: attention-weighted sum of the LSTM outputs.
        cntxt = torch.bmm(lstm_outputs.transpose(1, 2), att_weights.unsqueeze(2))  # B x H x 1
        concatted = torch.cat((cntxt, final_hidden.unsqueeze(2)), dim=1)           # B x 2H x 1
        att_hidden = self.tanh(self.align(concatted.squeeze(-1)))                  # B x H
        return att_hidden, att_weights
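
# A minimal shape sanity check for ConcatAttention (illustrative sketch only;
# the batch size of 4 and the random inputs are arbitrary assumptions).
if __name__ == "__main__":
    _attn = ConcatAttention(HIDDEN_SIZE)
    _outs = torch.randn(4, SEQ_LEN, HIDDEN_SIZE)
    _final = torch.randn(4, HIDDEN_SIZE)
    _hidden, _weights = _attn(_outs, _final)
    assert _hidden.shape == (4, HIDDEN_SIZE)
    assert _weights.shape == (4, SEQ_LEN)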
class LSTMConcatAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = ConcatAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 3)  # three sentiment classes
        )

    def forward(self, x):
        embeddings = self.embedding(x)
        outputs, (h_n, _) = self.lstm(embeddings)
        # h_n: 1 x B x H -> B x H for the attention block
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(att_hidden)
        return out, att_weights
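
# Illustrative smoke test of the model wiring (assumes the files loaded above
# exist; the random token batch is an arbitrary stand-in for real input).
if __name__ == "__main__":
    _tokens = torch.randint(0, embedding_layer.num_embeddings, (2, SEQ_LEN))
    _logits, _attn_w = LSTMConcatAttention()(_tokens)
    assert _logits.shape == (2, 3)
    assert _attn_w.shape == (2, SEQ_LEN)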
model_concat = LSTMConcatAttention()
model_concat.load_state_dict(torch.load('model/lstm_att_weight.pt', map_location='cpu'))
model_concat.eval()
def pred(text):
    start_time = time.time()
    # Lowercase, strip punctuation, and map known words to vocabulary indices.
    text = text.lower()
    text = ''.join([c for c in text if c not in punctuation])
    tokens = [vocab_to_int[word] for word in text.split() if word in vocab_to_int]
    # Left-pad with zeros (or truncate) to the fixed sequence length.
    if len(tokens) <= SEQ_LEN:
        tokens = [0] * (SEQ_LEN - len(tokens)) + tokens
    else:
        tokens = tokens[:SEQ_LEN]
    batch = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        label_id = model_concat(batch)[0].argmax(1)
    labels = {0: 'Negative', 1: 'Positive', 2: 'Neutral'}
    inference_time = time.time() - start_time
    return f"***{labels[label_id.item()]}***, inference time: ***{inference_time:.4f} s***."
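
# Hedged usage sketch: the sample sentence below is an arbitrary placeholder,
# not taken from the project's data; words absent from vocab_to_int are simply
# dropped during preprocessing, so any string is safe to pass.
if __name__ == "__main__":
    print(pred("a short example review"))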