eaglelandsonce committed · verified
Commit 284b3b6 · 1 Parent(s): b4e298c

Update pages/17_RNN_News.py

Files changed (1):
  1. pages/17_RNN_News.py +66 -40
pages/17_RNN_News.py CHANGED
@@ -2,7 +2,10 @@ import streamlit as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchtext.legacy import data, datasets
+from torchtext.data.utils import get_tokenizer
+from torchtext.vocab import build_vocab_from_iterator
+from torchtext.datasets import AG_NEWS
+from torch.utils.data import DataLoader, random_split
 import matplotlib.pyplot as plt
 import pandas as pd
 import numpy as np
@@ -23,38 +26,62 @@ class RNN(nn.Module):
         out = self.fc(out[:, -1, :])
         return out
 
-# Load the data
-@st.cache(allow_output_mutation=True)
+# Create a custom collate function to pad sequences
+def collate_batch(batch):
+    label_list, text_list, lengths = [], [], []
+    for _label, _text in batch:
+        label_list.append(label_pipeline(_label))
+        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
+        text_list.append(processed_text)
+        lengths.append(processed_text.size(0))
+    labels = torch.tensor(label_list, dtype=torch.int64)
+    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
+    return texts, labels
+
+# Function to load the data
+@st.cache_data
 def load_data():
-    TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm', include_lengths=True)
-    LABEL = data.LabelField(dtype=torch.long)
-    train_data, test_data = datasets.AG_NEWS.splits(TEXT, LABEL)
-    train_data, valid_data = train_data.split(split_ratio=0.8)
+    tokenizer = get_tokenizer("basic_english")
+    train_iter = AG_NEWS(split='train')
+    test_iter = AG_NEWS(split='test')
 
-    TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
-    LABEL.build_vocab(train_data)
+    def yield_tokens(data_iter):
+        for _, text in data_iter:
+            yield tokenizer(text)
 
-    BATCH_SIZE = 64
+    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
+    vocab.set_default_index(vocab["<unk>"])
 
-    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
-        (train_data, valid_data, test_data),
-        batch_size=BATCH_SIZE,
-        sort_within_batch=True,
-        device=device)
+    global text_pipeline, label_pipeline
+    text_pipeline = lambda x: vocab(tokenizer(x))
+    label_pipeline = lambda x: int(x) - 1
 
-    return TEXT, LABEL, train_iterator, valid_iterator, test_iterator
+    # Create DataLoaders
+    train_dataset = list(train_iter)
+    test_dataset = list(test_iter)
 
-# Train the network
+    train_size = int(0.8 * len(train_dataset))
+    valid_size = len(train_dataset) - train_size
+    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
+
+    BATCH_SIZE = 64
+    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+
+    return vocab, train_loader, valid_loader, test_loader
+
+# Function to train the network
 def train_network(net, iterator, optimizer, criterion, epochs):
     loss_values = []
     for epoch in range(epochs):
         epoch_loss = 0
         net.train()
-        for batch in iterator:
+        for texts, labels in iterator:
+            texts, labels = texts.to(device), labels.to(device)
             optimizer.zero_grad()
-            text, text_lengths = batch.text
-            predictions = net(text).squeeze(1)
-            loss = criterion(predictions, batch.label)
+            predictions = net(texts).squeeze(1)
+            loss = criterion(predictions, labels)
             loss.backward()
             optimizer.step()
             epoch_loss += loss.item()
@@ -64,7 +91,7 @@ def train_network(net, iterator, optimizer, criterion, epochs):
     st.write('Finished Training')
     return loss_values
 
-# Evaluate the network
+# Function to evaluate the network
 def evaluate_network(net, iterator, criterion):
     epoch_loss = 0
     correct = 0
@@ -73,15 +100,15 @@ def evaluate_network(net, iterator, criterion):
     all_predictions = []
     net.eval()
     with torch.no_grad():
-        for batch in iterator:
-            text, text_lengths = batch.text
-            predictions = net(text).squeeze(1)
-            loss = criterion(predictions, batch.label)
+        for texts, labels in iterator:
+            texts, labels = texts.to(device), labels.to(device)
+            predictions = net(texts).squeeze(1)
+            loss = criterion(predictions, labels)
             epoch_loss += loss.item()
             _, predicted = torch.max(predictions, 1)
-            correct += (predicted == batch.label).sum().item()
-            total += len(batch.label)
-            all_labels.extend(batch.label.cpu().numpy())
+            correct += (predicted == labels).sum().item()
+            total += len(labels)
+            all_labels.extend(labels.cpu().numpy())
             all_predictions.extend(predicted.cpu().numpy())
     accuracy = 100 * correct / total
     st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
@@ -89,7 +116,7 @@ def evaluate_network(net, iterator, criterion):
 
 # Load data
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-TEXT, LABEL, train_iterator, valid_iterator, test_iterator = load_data()
+vocab, train_loader, valid_loader, test_loader = load_data()
 
 # Streamlit interface
 st.title("RNN for Text Classification on AG News Dataset")
@@ -108,8 +135,8 @@ learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
 epochs = st.sidebar.slider('Epochs', 1, 20, 5)
 
 # Create the network
-vocab_size = len(TEXT.vocab)
-output_size = len(LABEL.vocab)
+vocab_size = len(vocab)
+output_size = 4  # Number of classes in AG_NEWS
 net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(net.parameters(), lr=learning_rate)
@@ -119,7 +146,7 @@ st.write('\n' * 10)
 
 # Train the network
 if st.sidebar.button('Train Network'):
-    loss_values = train_network(net, train_iterator, optimizer, criterion, epochs)
+    loss_values = train_network(net, train_loader, optimizer, criterion, epochs)
 
     # Plot the loss values
     plt.figure(figsize=(10, 5))
@@ -135,7 +162,7 @@ if st.sidebar.button('Train Network'):
 
 # Test the network
 if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
-    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_iterator, criterion)
+    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion)
     st.write(f'Test Accuracy: {accuracy:.2f}%')
 
 # Display results in a table
@@ -151,18 +178,17 @@ def visualize_text_predictions(iterator, net):
     net.eval()
     samples = []
     with torch.no_grad():
-        for batch in iterator:
-            text, text_lengths = batch.text
-            predictions = torch.max(net(text), 1)[1]
-            samples.extend(zip(text.cpu(), batch.label.cpu(), predictions.cpu()))
+        for texts, labels in iterator:
+            predictions = torch.max(net(texts), 1)[1]
+            samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
             if len(samples) >= 10:
                 break
     return samples[:10]
 
 if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
-    samples = visualize_text_predictions(test_iterator, st.session_state['trained_model'])
+    samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
     st.write('Ground Truth vs Predicted for Sample Texts')
     for i, (text, true_label, predicted) in enumerate(samples):
         st.write(f'Sample {i+1}')
-        st.text(' '.join([TEXT.vocab.itos[token] for token in text]))
+        st.text(' '.join([vocab.get_itos()[token] for token in text]))
         st.write(f'Ground Truth: {LABEL.vocab.itos[true_label.item()]}, Predicted: {LABEL.vocab.itos[predicted.item()]}')
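A few notes on the rewritten code follow. First, the new collate_batch calls pad_sequence and reads the module-level vocab, but no hunk in this commit adds the corresponding import, so the app raises a NameError as soon as the first batch is collated. A minimal sketch of the missing import and of the padding behavior, assuming the specials order ["<unk>", "<pad>"] used above (which puts <pad> at index 1):

    import torch
    from torch.nn.utils.rnn import pad_sequence  # the import the commit appears to be missing

    # Hypothetical mini-batch: three numericalized texts of different lengths.
    text_list = [torch.tensor([4, 2, 9]),
                 torch.tensor([7, 1]),
                 torch.tensor([3, 5, 8, 6])]

    # padding_value=1 stands in for vocab["<pad>"]; with specials=["<unk>", "<pad>"]
    # the pad token gets index 1.
    texts = pad_sequence(text_list, batch_first=True, padding_value=1)
    print(texts.shape)  # torch.Size([3, 4]): every sequence padded to the longest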
 
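Second, load_data is now cached with @st.cache_data, which serializes return values by pickling; DataLoader objects wrap live datasets and a collate function and are better treated as shared resources. A hedged alternative, assuming Streamlit >= 1.18 where st.cache_resource is available:

    import streamlit as st

    # st.cache_resource keeps one shared, unserialized instance across reruns,
    # which suits return values, such as DataLoaders, that st.cache_data would pickle.
    @st.cache_resource
    def load_data():
        ...  # body unchanged from the diff above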
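Third, the rewritten visualize_text_predictions feeds batch tensors to the network without moving them to device, so 'Show Test Results' fails on a CUDA machine. A sketch of the function with the transfer added, otherwise identical to the diff:

    def visualize_text_predictions(iterator, net):
        net.eval()
        samples = []
        with torch.no_grad():
            for texts, labels in iterator:
                texts = texts.to(device)  # the transfer the diff omits
                predictions = torch.max(net(texts), 1)[1]
                samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
                if len(samples) >= 10:
                    break
        return samples[:10]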
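Finally, the last line of the file is left unchanged by this commit and still references LABEL.vocab.itos, which only existed in the removed torchtext.legacy pipeline, so it too ends in a NameError. A hedged replacement that maps the 0-based labels (after label_pipeline's int(x) - 1) back to the four AG_NEWS class names:

    import torch

    # AG_NEWS label ids 1-4 denote World, Sports, Business and Sci/Tech; after
    # label_pipeline shifts them to 0-3, a plain list replaces LABEL.vocab.itos.
    ag_news_classes = ["World", "Sports", "Business", "Sci/Tech"]

    true_label, predicted = torch.tensor(2), torch.tensor(0)  # hypothetical sample
    print(f'Ground Truth: {ag_news_classes[true_label.item()]}, '
          f'Predicted: {ag_news_classes[predicted.item()]}')
    # -> Ground Truth: Business, Predicted: World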