eaglelandsonce committed on
Commit b4e298c · verified · 1 Parent(s): 713f0f4

Create 17_RNN_News.py

Files changed (1)
  1. pages/17_RNN_News.py +168 -0
pages/17_RNN_News.py ADDED
@@ -0,0 +1,168 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torchtext.legacy import data, datasets
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import numpy as np
+
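+ # NOTE: the torchtext.legacy API used below shipped with torchtext 0.9-0.11 and was
+ # removed in 0.12, so running this page assumes a pinned install (e.g. torchtext==0.11.2).
+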
+ # Define the RNN model
+ class RNN(nn.Module):
+     def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
+         super(RNN, self).__init__()
+         # Keep the sizes on the instance so forward() does not rely on globals
+         self.hidden_size = hidden_size
+         self.n_layers = n_layers
+         self.embedding = nn.Embedding(vocab_size, embed_size)
+         self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
+         self.fc = nn.Linear(hidden_size, output_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # x: (batch, seq_len) token indices
+         x = self.dropout(self.embedding(x))
+         h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size, device=x.device)
+         out, _ = self.rnn(x, h0)
+         # Classify from the hidden output at the last time step
+         out = self.fc(out[:, -1, :])
+         return out
+
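+ # Shape sketch for one forward pass (illustrative, given batch_first inputs):
+ #   tokens (batch, seq_len) -> embedding (batch, seq_len, embed_size)
+ #   -> RNN outputs (batch, seq_len, hidden_size) -> last step (batch, hidden_size)
+ #   -> logits (batch, output_size)
+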
+ # Load the data
+ @st.cache(allow_output_mutation=True)
+ def load_data():
+     # batch_first=True keeps batches as (batch, seq_len), matching the RNN above
+     TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm',
+                       include_lengths=True, batch_first=True)
+     LABEL = data.LabelField(dtype=torch.long)
+     train_data, test_data = datasets.AG_NEWS.splits(TEXT, LABEL)
+     train_data, valid_data = train_data.split(split_ratio=0.8)
+
+     TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
+     LABEL.build_vocab(train_data)
+
+     BATCH_SIZE = 64
+
+     # An explicit sort_key is needed for length bucketing if the dataset does not define one
+     train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
+         (train_data, valid_data, test_data),
+         batch_size=BATCH_SIZE,
+         sort_within_batch=True,
+         sort_key=lambda ex: len(ex.text),
+         device=device)
+
+     return TEXT, LABEL, train_iterator, valid_iterator, test_iterator
+
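+ # AG News is a 4-class topic dataset (World, Sports, Business, Sci/Tech), so
+ # output_size below resolves to 4 via len(LABEL.vocab).
+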
+ # Train the network
+ def train_network(net, iterator, optimizer, criterion, epochs):
+     loss_values = []
+     for epoch in range(epochs):
+         epoch_loss = 0
+         net.train()
+         for batch in iterator:
+             optimizer.zero_grad()
+             text, text_lengths = batch.text
+             predictions = net(text)
+             loss = criterion(predictions, batch.label)
+             loss.backward()
+             optimizer.step()
+             epoch_loss += loss.item()
+         epoch_loss /= len(iterator)
+         loss_values.append(epoch_loss)
+         st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
+     st.write('Finished Training')
+     return loss_values
+
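+ # Optional hardening (not in the original code): vanilla RNNs are prone to exploding
+ # gradients on long sequences; a common safeguard is to clip just before optimizer.step():
+ #     torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
+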
+ # Evaluate the network
+ def evaluate_network(net, iterator, criterion):
+     epoch_loss = 0
+     correct = 0
+     total = 0
+     all_labels = []
+     all_predictions = []
+     net.eval()
+     with torch.no_grad():
+         for batch in iterator:
+             text, text_lengths = batch.text
+             predictions = net(text)
+             loss = criterion(predictions, batch.label)
+             epoch_loss += loss.item()
+             _, predicted = torch.max(predictions, 1)
+             correct += (predicted == batch.label).sum().item()
+             total += len(batch.label)
+             all_labels.extend(batch.label.cpu().numpy())
+             all_predictions.extend(predicted.cpu().numpy())
+     accuracy = 100 * correct / total
+     st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
+     return accuracy, all_labels, all_predictions
+
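+ # The returned label/prediction lists also make a per-class breakdown cheap to derive,
+ # e.g. (a sketch, using the already-imported pandas):
+ #     pd.crosstab(pd.Series(all_labels), pd.Series(all_predictions))
+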
+ # Select the device, then load the data (the iterators are placed on this device)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ TEXT, LABEL, train_iterator, valid_iterator, test_iterator = load_data()
+
+ # Streamlit interface
+ st.title("RNN for Text Classification on AG News Dataset")
+
+ st.write("""
+ This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the AG News dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance.
+ """)
+
+ # Sidebar for input parameters
+ st.sidebar.header('Model Hyperparameters')
+ embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
+ hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
+ n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
+ dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
+ learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
+ epochs = st.sidebar.slider('Epochs', 1, 20, 5)
+
+ # Create the network
+ vocab_size = len(TEXT.vocab)
+ output_size = len(LABEL.vocab)
+ net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(net.parameters(), lr=learning_rate)
+
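+ # Note: build_vocab downloads glove.6B.100d above, but the embedding layer starts from
+ # random weights. To actually use the pretrained vectors (only meaningful when the
+ # Embedding Size slider stays at 100, since the GloVe vectors are 100-dimensional):
+ #     net.embedding.weight.data.copy_(TEXT.vocab.vectors)
+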
+ # Add vertical space
+ st.write('\n' * 10)
+
+ # Train the network
+ if st.sidebar.button('Train Network'):
+     loss_values = train_network(net, train_iterator, optimizer, criterion, epochs)
+
+     # Plot the loss values
+     fig = plt.figure(figsize=(10, 5))
+     plt.plot(range(1, epochs + 1), loss_values, marker='o')
+     plt.title('Training Loss Over Epochs')
+     plt.xlabel('Epoch')
+     plt.ylabel('Loss')
+     plt.grid(True)
+     st.pyplot(fig)
+
+     # Store the trained model in the session state
+     st.session_state['trained_model'] = net
+
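+ # Streamlit reruns this script from the top on every widget interaction, so `net` above
+ # is a fresh, untrained model on each rerun; st.session_state is what carries the
+ # trained weights over to the Test/Show buttons below.
+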
+ # Test the network
+ if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
+     accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_iterator, criterion)
+     st.write(f'Test Accuracy: {accuracy:.2f}%')
+
+     # Display results in a table
+     st.write('Ground Truth vs Predicted')
+     results = pd.DataFrame({
+         'Ground Truth': [LABEL.vocab.itos[label] for label in all_labels],
+         'Predicted': [LABEL.vocab.itos[label] for label in all_predictions]
+     })
+     st.table(results.head(50))  # Display first 50 results for brevity
+
+ # Visualize some test results
+ def visualize_text_predictions(iterator, net):
+     net.eval()
+     samples = []
+     with torch.no_grad():
+         for batch in iterator:
+             text, text_lengths = batch.text
+             predictions = torch.max(net(text), 1)[1]
+             # With batch_first tensors, zipping iterates over individual examples
+             samples.extend(zip(text.cpu(), batch.label.cpu(), predictions.cpu()))
+             if len(samples) >= 10:
+                 break
+     return samples[:10]
+
+ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
+     samples = visualize_text_predictions(test_iterator, st.session_state['trained_model'])
+     st.write('Ground Truth vs Predicted for Sample Texts')
+     for i, (text, true_label, predicted) in enumerate(samples):
+         st.write(f'Sample {i+1}')
+         # Note: any <pad> tokens in the batch are included in the printed text
+         st.text(' '.join([TEXT.vocab.itos[token] for token in text]))
+         st.write(f'Ground Truth: {LABEL.vocab.itos[true_label.item()]}, Predicted: {LABEL.vocab.itos[predicted.item()]}')
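
To try the page locally (assuming the usual Streamlit multipage layout, where this file sits under `pages/` next to the app's main script), launch the app with `streamlit run` on the main script; `streamlit run pages/17_RNN_News.py` also works for running this page on its own.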