Update pages/17_RNN_News.py

pages/17_RNN_News.py  CHANGED  (+66 -40)
@@ -2,7 +2,11 @@ import streamlit as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchtext.
+from torchtext.data.utils import get_tokenizer
+from torchtext.vocab import build_vocab_from_iterator
+from torchtext.datasets import AG_NEWS
+from torch.utils.data import DataLoader, random_split
+from torch.nn.utils.rnn import pad_sequence  # needed by collate_batch below
 import matplotlib.pyplot as plt
 import pandas as pd
 import numpy as np
@@ -23,38 +26,62 @@ class RNN(nn.Module):
         out = self.fc(out[:, -1, :])
         return out
 
-# 
-
+# Create a custom collate function to pad sequences
+def collate_batch(batch):
+    label_list, text_list, lengths = [], [], []
+    for _label, _text in batch:
+        label_list.append(label_pipeline(_label))
+        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
+        text_list.append(processed_text)
+        lengths.append(processed_text.size(0))
+    labels = torch.tensor(label_list, dtype=torch.int64)
+    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
+    return texts, labels
+
+# Function to load the data
+@st.cache_data
 def load_data():
-
-
-
-    train_data, valid_data = train_data.split(split_ratio=0.8)
+    tokenizer = get_tokenizer("basic_english")
+    train_iter = AG_NEWS(split='train')
+    test_iter = AG_NEWS(split='test')
 
-
-
+    def yield_tokens(data_iter):
+        for _, text in data_iter:
+            yield tokenizer(text)
 
-
+    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
+    vocab.set_default_index(vocab["<unk>"])
 
-
-
-
-        sort_within_batch=True,
-        device=device)
+    global text_pipeline, label_pipeline
+    text_pipeline = lambda x: vocab(tokenizer(x))
+    label_pipeline = lambda x: int(x) - 1
 
-
+    # Create DataLoaders
+    train_dataset = list(train_iter)
+    test_dataset = list(test_iter)
 
-
+    train_size = int(0.8 * len(train_dataset))
+    valid_size = len(train_dataset) - train_size
+    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
+
+    BATCH_SIZE = 64
+    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+
+    return vocab, train_loader, valid_loader, test_loader
+
+# Function to train the network
 def train_network(net, iterator, optimizer, criterion, epochs):
     loss_values = []
     for epoch in range(epochs):
         epoch_loss = 0
         net.train()
-        for 
+        for texts, labels in iterator:
+            texts, labels = texts.to(device), labels.to(device)
             optimizer.zero_grad()
-
-
-            loss = criterion(predictions, batch.label)
+            predictions = net(texts).squeeze(1)
+            loss = criterion(predictions, labels)
             loss.backward()
             optimizer.step()
             epoch_loss += loss.item()
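Note on the new data path: this commit replaces the legacy torchtext Field/BucketIterator flow (visible in the removed `sort_within_batch=True` iterator and `batch.label` accesses) with explicit pipelines plus a padding collate function. A minimal sketch of how those pieces fit together, assuming load_data() has already run; the sample headline and the printed shape are illustrative only:

    # Hypothetical sample; AG_NEWS yields (label, text) pairs with labels 1-4
    sample = (3, "Stocks rally as tech shares surge")
    token_ids = text_pipeline(sample[1])      # vocab(tokenizer(text)) -> list of token ids
    label = label_pipeline(sample[0])         # 3 -> 2: shifts labels 1-4 to 0-3 for CrossEntropyLoss
    texts, labels = collate_batch([sample])   # pads with vocab["<pad>"] to the longest sequence
    print(texts.shape, labels)                # e.g. torch.Size([1, 6]) tensor([2])

One caveat: `@st.cache_data` pickles return values, and DataLoader objects may not survive that round trip; `st.cache_resource` is Streamlit's usual decorator for unpicklable resources.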
@@ -64,7 +91,7 @@ def train_network(net, iterator, optimizer, criterion, epochs):
     st.write('Finished Training')
     return loss_values
 
-# 
+# Function to evaluate the network
 def evaluate_network(net, iterator, criterion):
     epoch_loss = 0
     correct = 0
@@ -73,15 +100,15 @@ def evaluate_network(net, iterator, criterion):
     all_predictions = []
     net.eval()
     with torch.no_grad():
-        for 
-
-            predictions = net(
-            loss = criterion(predictions, 
+        for texts, labels in iterator:
+            texts, labels = texts.to(device), labels.to(device)
+            predictions = net(texts).squeeze(1)
+            loss = criterion(predictions, labels)
             epoch_loss += loss.item()
             _, predicted = torch.max(predictions, 1)
-            correct += (predicted == 
-            total += len(
-            all_labels.extend(
+            correct += (predicted == labels).sum().item()
+            total += len(labels)
+            all_labels.extend(labels.cpu().numpy())
             all_predictions.extend(predicted.cpu().numpy())
     accuracy = 100 * correct / total
     st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
@@ -89,7 +116,7 @@ def evaluate_network(net, iterator, criterion):
 
 # Load data
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
+vocab, train_loader, valid_loader, test_loader = load_data()
 
 # Streamlit interface
 st.title("RNN for Text Classification on AG News Dataset")
@@ -108,8 +135,8 @@ learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
 epochs = st.sidebar.slider('Epochs', 1, 20, 5)
 
 # Create the network
-vocab_size = len(
-output_size = 
+vocab_size = len(vocab)
+output_size = 4  # Number of classes in AG_NEWS
 net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(net.parameters(), lr=learning_rate)
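The RNN class itself sits outside the changed hunks; only the constructor call above and the tail of forward (`out = self.fc(out[:, -1, :])`) are visible. A sketch of a module consistent with both, offered purely as an assumption about the unshown definition:

    class RNN(nn.Module):
        # Hypothetical reconstruction; only the last-timestep fc line appears in the diff
        def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.rnn = nn.RNN(embed_size, hidden_size, n_layers,
                              batch_first=True, dropout=dropout)
            self.fc = nn.Linear(hidden_size, output_size)

        def forward(self, x):
            out, _ = self.rnn(self.embedding(x))  # out: [batch, seq_len, hidden_size]
            out = self.fc(out[:, -1, :])          # classify from the last timestep
            return out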
@@ -119,7 +146,7 @@ st.write('\n' * 10)
 
 # Train the network
 if st.sidebar.button('Train Network'):
-    loss_values = train_network(net, 
+    loss_values = train_network(net, train_loader, optimizer, criterion, epochs)
 
     # Plot the loss values
     plt.figure(figsize=(10, 5))
@@ -135,7 +162,7 @@ if st.sidebar.button('Train Network'):
 
 # Test the network
 if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
-    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], 
+    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion)
     st.write(f'Test Accuracy: {accuracy:.2f}%')
 
     # Display results in a table
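The code under '# Display results in a table' falls outside this hunk. Since pandas is imported and evaluate_network collects all_labels and all_predictions, one plausible sketch is a confusion table via pd.crosstab; the class names and their order are an assumption about AG_NEWS's four categories:

    # Sketch under assumed class order (after label_pipeline's 1-4 -> 0-3 shift)
    class_names = ['World', 'Sports', 'Business', 'Sci/Tech']
    df = pd.DataFrame({
        'Ground Truth': [class_names[y] for y in all_labels],
        'Predicted': [class_names[p] for p in all_predictions],
    })
    st.table(pd.crosstab(df['Ground Truth'], df['Predicted']))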
@@ -151,18 +178,17 @@ def visualize_text_predictions(iterator, net):
     net.eval()
     samples = []
     with torch.no_grad():
-        for 
-
-
-            samples.extend(zip(text.cpu(), batch.label.cpu(), predictions.cpu()))
+        for texts, labels in iterator:
+            predictions = torch.max(net(texts), 1)[1]
+            samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
             if len(samples) >= 10:
                 break
     return samples[:10]
 
 if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
-    samples = visualize_text_predictions(
+    samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
     st.write('Ground Truth vs Predicted for Sample Texts')
     for i, (text, true_label, predicted) in enumerate(samples):
         st.write(f'Sample {i+1}')
-        st.text(' '.join([
+        st.text(' '.join([vocab.get_itos()[token] for token in text]))
         st.write(f'Ground Truth: {LABEL.vocab.itos[true_label.item()]}, Predicted: {LABEL.vocab.itos[predicted.item()]}')
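One leftover worth flagging: the final context line still calls `LABEL.vocab.itos[...]`, but `LABEL` was presumably a legacy torchtext Field that this commit removes, so 'Show Test Results' would raise a NameError there. A minimal fix under the same assumed class order as above:

    # Replace the stale LABEL.vocab.itos lookup with a plain list
    ag_news_labels = ['World', 'Sports', 'Business', 'Sci/Tech']
    st.write(f'Ground Truth: {ag_news_labels[true_label.item()]}, '
             f'Predicted: {ag_news_labels[predicted.item()]}')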
|