vasevooo committed on
Commit
b618451
·
1 Parent(s): dd68f2f

Update pages/imdb.py

Browse files
Files changed (1) hide show
  1. pages/imdb.py +9 -4
pages/imdb.py CHANGED
@@ -16,15 +16,21 @@ from data.rnn_preprocessing import (
16
 
17
  # Load Word2Vec model
18
  wv = Word2Vec.load('models/word2vec32.model')
19
- embedding_matrix = wv.wv.vectors
20
  vocab_to_int = {word: idx + 1 for idx, word in enumerate(wv.wv.index_to_key)}
 
 
 
 
 
 
21
 
22
  # Load TF-IDF model
23
  tfidf_model = pickle.load(open('models/modeltfidf.sav', 'rb'))
24
 
25
  # Load LSTM model
26
- embedding_layer32 = torch.nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
27
- VOCAB_SIZE, EMBEDDING_DIM = embedding_matrix.shape
28
  HIDDEN_DIM = 64
29
  SEQ_LEN = 32
30
 
@@ -63,7 +69,6 @@ model = LSTMClassifierBi32(embedding_dim=EMBEDDING_DIM, hidden_size=HIDDEN_DIM)
63
  model.load_state_dict(torch.load('models/ltsm_bi1.pt'))
64
  model.eval()
65
 
66
-
67
  def predict_sentence(text: str, model: nn.Module):
68
  result = model(preprocess_single_string(text, seq_len=SEQ_LEN, vocab_to_int=vocab_to_int).unsqueeze(0)).sigmoid().round().item()
69
  return 'negative' if result == 0.0 else 'positive'
 
16
 
17
  # Load Word2Vec model
18
  wv = Word2Vec.load('models/word2vec32.model')
19
+ embedding_matrix = np.zeros((VOCAB_SIZE, EMBEDDING_DIM))
20
  vocab_to_int = {word: idx + 1 for idx, word in enumerate(wv.wv.index_to_key)}
21
+ for word, i in vocab_to_int.items():
22
+ try:
23
+ embedding_vector = wv.wv[word]
24
+ embedding_matrix[i] = embedding_vector
25
+ except KeyError:
26
+ pass
27
 
28
  # Load TF-IDF model
29
  tfidf_model = pickle.load(open('models/modeltfidf.sav', 'rb'))
30
 
31
  # Load LSTM model
32
+ embedding_layer32 = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
33
+ VOCAB_SIZE = len(vocab_to_int) + 1 # add 1 for the padding token
34
  HIDDEN_DIM = 64
35
  SEQ_LEN = 32
36
 
 
69
  model.load_state_dict(torch.load('models/ltsm_bi1.pt'))
70
  model.eval()
71
 
 
72
  def predict_sentence(text: str, model: nn.Module):
73
  result = model(preprocess_single_string(text, seq_len=SEQ_LEN, vocab_to_int=vocab_to_int).unsqueeze(0)).sigmoid().round().item()
74
  return 'negative' if result == 0.0 else 'positive'