eaglelandsonce commited on
Commit
81bfc93
·
verified ·
1 Parent(s): 8462f4c

Delete pages/17_RNN_News.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN_News.py +0 -190
pages/17_RNN_News.py DELETED
@@ -1,190 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- from torchtext.data.utils import get_tokenizer
6
- from torchtext.vocab import build_vocab_from_iterator
7
- from torchtext.datasets import AG_NEWS
8
- from torch.utils.data import DataLoader, random_split
9
- from torch.nn.utils.rnn import pad_sequence
10
- import matplotlib.pyplot as plt
11
- import pandas as pd
12
-
13
- # Define the RNN model
14
- class RNN(nn.Module):
15
- def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
16
- super(RNN, self).__init__()
17
- self.embedding = nn.Embedding(vocab_size, embed_size)
18
- self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
19
- self.fc = nn.Linear(hidden_size, output_size)
20
- self.dropout = nn.Dropout(dropout)
21
-
22
- def forward(self, x):
23
- x = self.dropout(self.embedding(x))
24
- h0 = torch.zeros(n_layers, x.size(0), hidden_size).to(device)
25
- out, _ = self.rnn(x, h0)
26
- out = self.fc(out[:, -1, :])
27
- return out
28
-
29
- # Create a custom collate function to pad sequences
30
- def collate_batch(batch):
31
- label_list, text_list = [], []
32
- for _label, _text in batch:
33
- label_list.append(label_pipeline(_label))
34
- processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
35
- text_list.append(processed_text)
36
- labels = torch.tensor(label_list, dtype=torch.int64)
37
- texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
38
- return texts, labels
39
-
40
- # Function to load the data
41
- @st.cache_data
42
- def load_data():
43
- tokenizer = get_tokenizer("basic_english")
44
- train_iter = AG_NEWS(split='train')
45
- test_iter = AG_NEWS(split='test')
46
-
47
- def yield_tokens(data_iter):
48
- for _, text in data_iter:
49
- yield tokenizer(text)
50
-
51
- vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
52
- vocab.set_default_index(vocab["<unk>"])
53
-
54
- return vocab, tokenizer, list(train_iter), list(test_iter)
55
-
56
- # Initialize global pipelines
57
- vocab, tokenizer, train_dataset, test_dataset = load_data()
58
- text_pipeline = lambda x: vocab(tokenizer(x))
59
- label_pipeline = lambda x: int(x) - 1
60
-
61
- # Create DataLoaders
62
- train_size = int(0.8 * len(train_dataset))
63
- valid_size = len(train_dataset) - train_size
64
- train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
65
-
66
- BATCH_SIZE = 64
67
- train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
68
- valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
69
- test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
70
-
71
- # Function to train the network
72
- def train_network(net, iterator, optimizer, criterion, epochs):
73
- loss_values = []
74
- for epoch in range(epochs):
75
- epoch_loss = 0
76
- net.train()
77
- for texts, labels in iterator:
78
- texts, labels = texts.to(device), labels.to(device)
79
- optimizer.zero_grad()
80
- predictions = net(texts)
81
- loss = criterion(predictions, labels)
82
- loss.backward()
83
- optimizer.step()
84
- epoch_loss += loss.item()
85
- epoch_loss /= len(iterator)
86
- loss_values.append(epoch_loss)
87
- st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
88
- st.write('Finished Training')
89
- return loss_values
90
-
91
- # Function to evaluate the network
92
- def evaluate_network(net, iterator, criterion):
93
- epoch_loss = 0
94
- correct = 0
95
- total = 0
96
- all_labels = []
97
- all_predictions = []
98
- net.eval()
99
- with torch.no_grad():
100
- for texts, labels in iterator:
101
- texts, labels = texts.to(device), labels.to(device)
102
- predictions = net(texts)
103
- loss = criterion(predictions, labels)
104
- epoch_loss += loss.item()
105
- _, predicted = torch.max(predictions, 1)
106
- correct += (predicted == labels).sum().item()
107
- total += len(labels)
108
- all_labels.extend(labels.cpu().numpy())
109
- all_predictions.extend(predicted.cpu().numpy())
110
- accuracy = 100 * correct / total
111
- st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
112
- return accuracy, all_labels, all_predictions
113
-
114
- # Load data
115
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
116
-
117
- # Streamlit interface
118
- st.title("RNN for Text Classification on AG News Dataset")
119
-
120
- st.write("""
121
- This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the AG News dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance.
122
- """)
123
-
124
- # Sidebar for input parameters
125
- st.sidebar.header('Model Hyperparameters')
126
- embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
127
- hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
128
- n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
129
- dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
130
- learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
131
- epochs = st.sidebar.slider('Epochs', 1, 20, 5)
132
-
133
- # Create the network
134
- vocab_size = len(vocab)
135
- output_size = 4 # Number of classes in AG_NEWS
136
- net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
137
- criterion = nn.CrossEntropyLoss()
138
- optimizer = optim.Adam(net.parameters(), lr=learning_rate)
139
-
140
- # Add vertical space
141
- st.write('\n' * 10)
142
-
143
- # Train the network
144
- if st.sidebar.button('Train Network'):
145
- loss_values = train_network(net, train_loader, optimizer, criterion, epochs)
146
-
147
- # Plot the loss values
148
- plt.figure(figsize=(10, 5))
149
- plt.plot(range(1, epochs + 1), loss_values, marker='o')
150
- plt.title('Training Loss Over Epochs')
151
- plt.xlabel('Epoch')
152
- plt.ylabel('Loss')
153
- plt.grid(True)
154
- st.pyplot(plt)
155
-
156
- # Store the trained model in the session state
157
- st.session_state['trained_model'] = net
158
-
159
- # Test the network
160
- if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
161
- accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion)
162
- st.write(f'Test Accuracy: {accuracy:.2f}%')
163
-
164
- # Display results in a table
165
- st.write('Ground Truth vs Predicted')
166
- results = pd.DataFrame({
167
- 'Ground Truth': all_labels,
168
- 'Predicted': all_predictions
169
- })
170
- st.table(results.head(50)) # Display first 50 results for brevity
171
-
172
- # Visualize some test results
173
- def visualize_text_predictions(iterator, net):
174
- net.eval()
175
- samples = []
176
- with torch.no_grad():
177
- for texts, labels in iterator:
178
- predictions = torch.max(net(texts), 1)[1]
179
- samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
180
- if len(samples) >= 10:
181
- break
182
- return samples[:10]
183
-
184
- if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
185
- samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
186
- st.write('Ground Truth vs Predicted for Sample Texts')
187
- for i, (text, true_label, predicted) in enumerate(samples):
188
- st.write(f'Sample {i+1}')
189
- st.text(' '.join([vocab.get_itos()[token] for token in text]))
190
- st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')