Update functions_preprocess.py
functions_preprocess.py  +21 −2

The commit adds the four PyTorch imports at the top of the module and appends a CNN text classifier after training_data:
```diff
@@ -13,7 +13,10 @@ from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 from nltk.tokenize import word_tokenize
 from nltk.corpus import wordnet
-
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
 
 def download_if_non_existent(res_path, res_name):
     try:
@@ -112,4 +115,20 @@ def training_data(dataset_1, dataset_2, dataset_3):
     X_test = np.array(X_test)
     X_train = np.array(X_train)
 
-    return X_train, y_train, X_test, y_test
+    return X_train, y_train, X_test, y_test
+
+class CNN(nn.Module):
+    def __init__(self, vocab_size, embed_size, n_filters, filter_sizes, dropout, num_classes):
+        super(CNN, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_size)
+        self.convs = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(fs, embed_size)) for fs in filter_sizes])
+        self.dropout = nn.Dropout(dropout)
+        self.fc1 = nn.Linear(len(filter_sizes) * n_filters, num_classes)
+
+    def forward(self, text):
+        embedded = self.embedding(text)
+        embedded = embedded.unsqueeze(1)
+        conved = [F.leaky_relu(conv(embedded)).squeeze(3) for conv in self.convs]
+        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
+        cat = self.dropout(torch.cat(pooled, dim=1))
+        return self.fc1(cat)
```
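For a quick sanity check of the new module, the snippet below instantiates the class and runs a dummy batch through the forward pass. The hyperparameters, batch shape, and import path are hypothetical, chosen only for illustration; the commit itself pins none of them. The model expects a batch of token-id sequences of shape (batch, seq_len), and seq_len must be at least max(filter_sizes) for the convolutions to be valid.

```python
import torch

from functions_preprocess import CNN  # the class added in this commit

# Hypothetical hyperparameters, for illustration only.
model = CNN(vocab_size=5000, embed_size=100, n_filters=64,
            filter_sizes=[3, 4, 5], dropout=0.5, num_classes=2)

batch = torch.randint(0, 5000, (8, 50))  # 8 sequences of 50 token ids
logits = model(batch)                    # embed -> per-width Conv2d -> max-pool -> concat -> linear
print(logits.shape)                      # torch.Size([8, 2])
```

Each Conv2d kernel spans the full embedding dimension, so a kernel of height fs acts as an fs-gram detector over the token sequence, and max-pooling keeps the strongest activation per filter, making the concatenated feature vector independent of sequence length; this is the standard multi-width convolutional text classifier.

Note that torch.optim is imported but not yet used in this change, so presumably a later commit wires the model to the arrays returned by training_data. A minimal sketch of one optimization step, with the loss, optimizer, and labels all assumed rather than taken from the repo:

```python
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()                    # the model returns raw logits, so no softmax is needed
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # hypothetical optimizer and learning rate

labels = torch.randint(0, 2, (8,))                   # dummy targets matching the batch above
optimizer.zero_grad()
loss = criterion(model(batch), labels)
loss.backward()
optimizer.step()
```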