import streamlit as st
import torch
import torch.nn as nn
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re
import string
from collections import Counter
import numpy as np
from typing import List
import time

# Architecture of the pretrained model
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # Additive (Bahdanau) attention: score(q, k) = Wv * tanh(Wa*q + Wk*k)
        self.Wa = nn.Linear(hidden_size, hidden_size)  # projects the query
        self.Wk = nn.Linear(hidden_size, hidden_size)  # projects the keys
        self.Wv = nn.Linear(hidden_size, 1)            # scalar score per position

    def forward(self, query, keys):
        query = query.unsqueeze(1)  # (batch_size, 1, hidden_size)
        scores = self.Wv(torch.tanh(self.Wa(query) + self.Wk(keys))).squeeze(2)  # (batch_size, seq_len)
        attention_weights = torch.softmax(scores, dim=1)  # (batch_size, seq_len)
        context = torch.bmm(attention_weights.unsqueeze(1), keys).squeeze(1)  # (batch_size, hidden_size)
        return context, attention_weights
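
# Minimal shape walk-through for BahdanauAttention (illustrative only; the
# batch size, sequence length, and hidden size below are arbitrary assumptions,
# not values used by this app):
#
#   attn = BahdanauAttention(hidden_size=64)
#   query = torch.randn(8, 64)        # last hidden state, (batch_size, hidden_size)
#   keys = torch.randn(8, 20, 64)     # LSTM outputs, (batch_size, seq_len, hidden_size)
#   context, weights = attn(query, keys)
#   context.shape  -> torch.Size([8, 64])
#   weights.shape  -> torch.Size([8, 20])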

class LSTM_Word2Vec_Attention(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.attn = BahdanauAttention(hidden_size)
        self.clf = nn.Sequential(       # classification head over the attention context
            nn.Linear(hidden_size, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 3)           # 3 output classes
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        embedded = self.embedding(x)               # (batch_size, seq_len, embedding_dim)
        output, (hidden, _) = self.lstm(embedded)  # output: (batch_size, seq_len, hidden_size)
        context, attention_weights = self.attn(hidden[-1], output)
        logits = self.clf(context)                 # context is already (batch_size, hidden_size)
        output = self.sigmoid(logits)              # element-wise sigmoid over the 3 class scores
        return output, attention_weights
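
# Illustrative sketch only: how the classes above might be instantiated and
# loaded for inference in the app. The hyperparameters and checkpoint path are
# assumptions for demonstration, not values taken from this repository.
@st.cache_resource
def _load_demo_model(weights_path: str = "model_weights.pt"):  # hypothetical path
    model = LSTM_Word2Vec_Attention(hidden_size=64, vocab_size=20000, embedding_dim=300)
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()  # disable dropout for inference
    return model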