eaglelandsonce committed
Commit de0d854 · verified · Parent: 15bd774

Create RNN.py

Files changed (1): pages/RNN.py (+165, -0)
pages/RNN.py ADDED
@@ -0,0 +1,165 @@
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data, datasets  # legacy API; removed in torchtext 0.12
import matplotlib.pyplot as plt
import pandas as pd

# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        # keep sizes on the module instead of relying on globals in forward()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        x = self.dropout(self.embedding(x))
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # classify from the final time step
        return out

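# Illustrative shape check (a sketch with made-up sizes, not part of the app):
# a (batch, seq_len) LongTensor of token ids maps to (batch, output_size) logits.
#   m = RNN(vocab_size=100, embed_size=8, hidden_size=16, output_size=1, n_layers=2, dropout=0.2)
#   m(torch.randint(0, 100, (4, 12))).shape  # torch.Size([4, 1])
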
# Function to load the data and build the vocab and iterators
# (cache_resource rather than cache_data: Field and iterator objects are not picklable)
@st.cache_resource
def load_data():
    # batch_first=True so batch.text arrives as (batch, seq_len), matching the RNN above
    TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm', batch_first=True)
    LABEL = data.LabelField(dtype=torch.float)
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
    train_data, valid_data = train_data.split(split_ratio=0.8)

    MAX_VOCAB_SIZE = 25_000
    TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
    LABEL.build_vocab(train_data)

    BATCH_SIZE = 64
    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=BATCH_SIZE,
        device=device)

    return TEXT, LABEL, train_iterator, valid_iterator, test_iterator

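# Note: build_vocab downloads GloVe vectors into TEXT.vocab.vectors, but they are
# only used if copied into the embedding layer. A minimal sketch (valid only when
# embed_size == 100, the glove.6B.100d dimensionality):
#   net.embedding.weight.data.copy_(TEXT.vocab.vectors)
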
# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for batch in iterator:
            optimizer.zero_grad()
            predictions = net(batch.text).squeeze(1)  # (batch,) raw logits
            loss = criterion(predictions, batch.label)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(iterator)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values

# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    epoch_loss = 0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for batch in iterator:
            predictions = net(batch.text).squeeze(1)
            loss = criterion(predictions, batch.label)
            epoch_loss += loss.item()
            rounded_preds = torch.round(torch.sigmoid(predictions))  # threshold at 0.5
            correct += (rounded_preds == batch.label).sum().item()
            total += len(batch.label)
            all_labels.extend(batch.label.cpu().numpy())
            all_predictions.extend(rounded_preds.cpu().numpy())
    accuracy = 100 * correct / total
    st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions

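# The validation split is built but never consumed; the same helper could be used
# for model selection, e.g. (sketch): evaluate_network(net, valid_iterator, criterion)
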
# Load the data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TEXT, LABEL, train_iterator, valid_iterator, test_iterator = load_data()

# Streamlit interface
st.title("RNN for Text Classification on IMDb Dataset")

st.write("""
This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification on the IMDb dataset. You can adjust the hyperparameters, train the model, and inspect its predictions on the test set.
""")

# Sidebar for input parameters
st.sidebar.header('Model Hyperparameters')
embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)

# Create the network
vocab_size = len(TEXT.vocab)
output_size = 1
net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, so the net outputs raw logits
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Add vertical space (note: markdown collapses blank lines, so this adds little)
st.write('\n' * 10)

# Train the network
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, train_iterator, optimizer, criterion, epochs)

    # Plot the loss values
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), loss_values, marker='o')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    st.pyplot(plt)

    # Store the trained model in the session state so later reruns can reuse it
    st.session_state['trained_model'] = net

# Test the network
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_iterator, criterion)
    st.write(f'Test Accuracy: {accuracy:.2f}%')

    # Display results in a table
    st.write('Ground Truth vs Predicted')
    results = pd.DataFrame({
        'Ground Truth': all_labels,
        'Predicted': all_predictions
    })
    st.table(results.head(50))  # Display first 50 results for brevity

# Visualize some test results
def visualize_text_predictions(iterator, net):
    net.eval()
    samples = []
    with torch.no_grad():
        for batch in iterator:
            predictions = torch.round(torch.sigmoid(net(batch.text).squeeze(1)))
            samples.extend(zip(batch.text.cpu(), batch.label.cpu(), predictions.cpu()))
            if len(samples) >= 10:
                break
    return samples[:10]

if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
    samples = visualize_text_predictions(test_iterator, st.session_state['trained_model'])
    st.write('Ground Truth vs Predicted for Sample Texts')
    for i, (text, true_label, predicted) in enumerate(samples):
        st.write(f'Sample {i+1}')
        st.text(' '.join([TEXT.vocab.itos[token] for token in text]))
        st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')
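
Note: this page depends on the legacy torchtext API, which was removed in torchtext 0.12. Assuming an older stack is acceptable, something like `pip install torch==1.10.2 torchtext==0.11.2 spacy streamlit matplotlib pandas` followed by `python -m spacy download en_core_web_sm` should satisfy the imports (treat the exact version pins as an assumption to verify). The page is then served with `streamlit run` from the app's root, where anything under `pages/` appears in the multipage sidebar.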