Update pages/17_RNN_News.py
pages/17_RNN_News.py (+13 -17)
@@ -51,25 +51,22 @@ def load_data():
     vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
     vocab.set_default_index(vocab["<unk>"])
 
-
-    global text_pipeline, label_pipeline
-    text_pipeline = lambda x: vocab(tokenizer(x))
-    label_pipeline = lambda x: int(x) - 1
+    return vocab, tokenizer, list(train_iter), list(test_iter)
 
-
-
-
+# Initialize global pipelines
+vocab, tokenizer, train_dataset, test_dataset = load_data()
+text_pipeline = lambda x: vocab(tokenizer(x))
+label_pipeline = lambda x: int(x) - 1
 
-
-
-
+# Create DataLoaders
+train_size = int(0.8 * len(train_dataset))
+valid_size = len(train_dataset) - train_size
+train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
 
-
-
-
-
-
-    return vocab, train_loader, valid_loader, test_loader
+BATCH_SIZE = 64
+train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
 
 # Function to train the network
 def train_network(net, iterator, optimizer, criterion, epochs):

@@ -116,7 +113,6 @@ def evaluate_network(net, iterator, criterion):
 
 # Load data
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-vocab, train_loader, valid_loader, test_loader = load_data()
 
 # Streamlit interface
 st.title("RNN for Text Classification on AG News Dataset")
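The new DataLoaders all pass collate_fn=collate_batch, but collate_batch itself lies outside the hunks shown above. A minimal sketch of what such a collate function could look like for this setup, assuming the module-level vocab, text_pipeline and label_pipeline from the diff are in scope and that AG_NEWS yields (label, text) pairs; the use of pad_sequence and the "<pad>" padding index are assumptions, not taken from this change:

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = vocab["<pad>"]  # assumes the "<pad>" special added when the vocab was built

def collate_batch(batch):
    # Turn a list of (label, text) pairs into a padded id tensor and a label tensor
    labels, texts = [], []
    for label, text in batch:
        labels.append(label_pipeline(label))                               # 1..4 -> 0..3
        texts.append(torch.tensor(text_pipeline(text), dtype=torch.int64))
    labels = torch.tensor(labels, dtype=torch.int64)
    texts = pad_sequence(texts, batch_first=True, padding_value=PAD_IDX)   # pad to longest in batch
    return texts, labels

# Illustrative use of the two pipelines (the sample text is made up):
ids = text_pipeline("stocks rally as markets rebound")  # list of token ids from the vocab
cls = label_pipeline(3)                                  # AG News label 3 -> class index 2

Materializing the iterators with list(...) inside load_data is what makes this module-level split possible: random_split and the shuffling DataLoaders both need a dataset with a known length, which a purely streaming iterator would not provide.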