Spaces:
Sleeping
Sleeping
Update pages/imdb.py
Browse files- pages/imdb.py +9 -4
pages/imdb.py
CHANGED
@@ -16,15 +16,21 @@ from data.rnn_preprocessing import (
|
|
16 |
|
17 |
# Load Word2Vec model
|
18 |
wv = Word2Vec.load('models/word2vec32.model')
|
19 |
-
embedding_matrix =
|
20 |
vocab_to_int = {word: idx + 1 for idx, word in enumerate(wv.wv.index_to_key)}
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Load TF-IDF model
|
23 |
tfidf_model = pickle.load(open('models/modeltfidf.sav', 'rb'))
|
24 |
|
25 |
# Load LSTM model
|
26 |
-
embedding_layer32 =
|
27 |
-
VOCAB_SIZE
|
28 |
HIDDEN_DIM = 64
|
29 |
SEQ_LEN = 32
|
30 |
|
@@ -63,7 +69,6 @@ model = LSTMClassifierBi32(embedding_dim=EMBEDDING_DIM, hidden_size=HIDDEN_DIM)
|
|
63 |
model.load_state_dict(torch.load('models/ltsm_bi1.pt'))
|
64 |
model.eval()
|
65 |
|
66 |
-
|
67 |
def predict_sentence(text: str, model: nn.Module):
|
68 |
result = model(preprocess_single_string(text, seq_len=SEQ_LEN, vocab_to_int=vocab_to_int).unsqueeze(0)).sigmoid().round().item()
|
69 |
return 'negative' if result == 0.0 else 'positive'
|
|
|
16 |
|
17 |
# Load the pretrained Word2Vec model; its vocabulary drives everything below.
wv = Word2Vec.load('models/word2vec32.model')

# Map every Word2Vec token to an integer id starting at 1, leaving id 0 free
# (the file later sizes the vocabulary as len + 1 "for the padding token").
vocab_to_int = {word: idx + 1 for idx, word in enumerate(wv.wv.index_to_key)}

# Size the matrix from the vocabulary built just above instead of reading
# VOCAB_SIZE here: that constant is only assigned further down in this file,
# so referencing it at this point would fail at import time.
embedding_matrix = np.zeros((len(vocab_to_int) + 1, EMBEDDING_DIM))
for word, i in vocab_to_int.items():
    try:
        # Row i holds the Word2Vec vector of the token whose id is i.
        embedding_matrix[i] = wv.wv[word]
    except KeyError:
        # Token has no vector: keep its row as zeros (best-effort, as before).
        continue
|
27 |
|
28 |
# Load TF-IDF model
# NOTE(review): pickle can execute arbitrary code while loading; this is only
# safe as long as 'models/modeltfidf.sav' is a trusted, repository-owned file.
# The context manager closes the file handle as soon as the model is read.
with open('models/modeltfidf.sav', 'rb') as tfidf_file:
    tfidf_model = pickle.load(tfidf_file)
|
30 |
|
31 |
# LSTM model configuration
SEQ_LEN = 32
HIDDEN_DIM = 64
VOCAB_SIZE = len(vocab_to_int) + 1  # add 1 for the padding token

# Embedding layer initialised from the Word2Vec-derived embedding matrix.
embedding_layer32 = nn.Embedding.from_pretrained(
    torch.FloatTensor(embedding_matrix)
)
|
36 |
|
|
|
69 |
# Restore the trained weights. map_location='cpu' remaps the stored tensors
# to CPU while reading, so a checkpoint saved from a GPU still loads on a
# CPU-only host; load_state_dict then copies them into the model's parameters.
model.load_state_dict(torch.load('models/ltsm_bi1.pt', map_location='cpu'))
# Put the module in evaluation mode before serving predictions.
model.eval()
|
71 |
|
|
|
72 |
def predict_sentence(text: str, model: nn.Module) -> str:
    """Classify ``text`` as ``'positive'`` or ``'negative'`` with ``model``.

    The text is encoded by ``preprocess_single_string`` using the module-level
    ``SEQ_LEN`` and ``vocab_to_int``, given a leading batch dimension, and fed
    through ``model``. The output is passed through a sigmoid and rounded.

    Args:
        text: Raw input sentence.
        model: Module whose forward pass yields a single logit for the input.

    Returns:
        ``'negative'`` when the rounded sigmoid output is 0.0, otherwise
        ``'positive'``.
    """
    encoded = preprocess_single_string(
        text, seq_len=SEQ_LEN, vocab_to_int=vocab_to_int
    ).unsqueeze(0)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        result = model(encoded).sigmoid().round().item()
    return 'negative' if result == 0.0 else 'positive'
|