pytorch / pages /17_RNN.py
eaglelandsonce's picture
Update pages/17_RNN.py
713f0f4 verified
raw
history blame
8.08 kB
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import IMDB
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
# Define the RNN model
class RNN(nn.Module):
    """Embedding -> stacked Elman RNN -> linear head, for sequence classification.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Dimension of each token embedding.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of outputs produced per sequence.
        n_layers: Number of stacked RNN layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``n_layers > 1``, between the stacked RNN layers.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        # Keep the architecture on the instance so forward() never depends on
        # module-level globals (which change whenever the sidebar sliders do).
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # nn.RNN applies dropout only between layers; a non-zero value with a
        # single layer has no effect and just triggers a UserWarning.
        rnn_dropout = dropout if n_layers > 1 else 0.0
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=rnn_dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths=None):
        """Return a ``(batch, output_size)`` tensor of logits.

        Args:
            x: Token ids of shape ``(batch, seq_len)`` (the RNN is ``batch_first``).
            lengths: Optional unpadded length of each sequence (list or tensor).
                When given, the RNN output at each sequence's last real token is
                used. Otherwise the output at the final position is used, which
                for right-padded sequences is a padding step.
        """
        x = self.dropout(self.embedding(x))
        # Initial hidden state on the same device/dtype as the inputs.
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, h0)
        if lengths is None:
            last = out[:, -1, :]
        else:
            idx = torch.as_tensor(lengths, dtype=torch.long, device=out.device) - 1
            idx = idx.clamp(0, out.size(1) - 1)
            last = out[torch.arange(out.size(0), device=out.device), idx]
        return self.fc(last)
# Create a custom collate function to pad sequences
def collate_batch(batch, padding_value=None):
    """Pad a list of ``(token_ids, label)`` pairs into one batch.

    Args:
        batch: Sequence of ``(text, label)`` pairs, where ``text`` is a 1-D tensor
            of token ids and ``label`` is a number or a 0-dim tensor.
        padding_value: Token id used to pad shorter sequences. When omitted, the
            ``"<pad>"`` index of the module-level ``vocab`` is used. That lookup
            happens at call time, because ``vocab`` is only assigned after
            ``load_data()`` has run.

    Returns:
        ``(texts_padded, labels, text_lengths)``: a ``(batch, max_len)`` tensor of
        token ids, a float tensor of labels, and a list of the original lengths.
    """
    texts, labels = zip(*batch)
    text_lengths = [len(text) for text in texts]
    if padding_value is None:
        padding_value = vocab["<pad>"]
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=padding_value)
    # Convert each label to a Python float first, so plain numbers and 0-dim
    # tensors are handled alike (no tensor built from a list of tensors).
    label_tensor = torch.tensor([float(label) for label in labels], dtype=torch.float)
    return texts_padded, label_tensor, text_lengths
# Function to load the data
@st.cache_resource
def load_data():
    """Download IMDB, build the vocabulary and wrap the data in DataLoaders.

    The vocabulary is built from the training split with the special tokens
    ``<unk>`` (also the default index for unknown words) and ``<pad>``. The
    training split is divided 80/20 into training and validation subsets. All
    three loaders use batches of 64 and ``collate_batch`` for padding.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``. The return
    value holds a vocabulary object and DataLoaders (which reference
    ``collate_batch``), and ``st.cache_data`` tries to pickle every return value.

    Returns:
        ``(vocab, train_loader, valid_loader, test_loader)``.
    """
    tokenizer = get_tokenizer("basic_english")
    train_iter, test_iter = IMDB(split=('train', 'test'))
    # Materialize both splits once. The training data is read twice (the
    # vocabulary pass, then tensor conversion), and a one-shot iterator would
    # come back empty on the second pass.
    train_data = list(train_iter)
    test_data = list(test_iter)

    def yield_tokens(data):
        for _, text in data:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])

    # Define the text and label processing pipelines
    def text_pipeline(text):
        return vocab(tokenizer(text))

    def label_pipeline(label):
        # NOTE: depending on the torchtext version, IMDB labels are the strings
        # 'neg'/'pos' or the integers 1 (neg) / 2 (pos). Accept both forms so
        # positive reviews are never silently mapped to 0.
        return 1 if label in ('pos', 2) else 0

    # Process the data into tensors
    def process_data(data):
        texts = [torch.tensor(text_pipeline(text), dtype=torch.long) for _, text in data]
        labels = torch.tensor([label_pipeline(label) for label, _ in data], dtype=torch.float)
        return texts, labels

    train_texts, train_labels = process_data(train_data)
    test_texts, test_labels = process_data(test_data)

    # Create DataLoaders
    train_dataset = list(zip(train_texts, train_labels))
    test_dataset = list(zip(test_texts, test_labels))
    train_size = int(0.8 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
    BATCH_SIZE = 64
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    return vocab, train_loader, valid_loader, test_loader
# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    """Train ``net`` for ``epochs`` passes over ``iterator``, reporting progress.

    Args:
        net: Model returning one logit per example, shape ``(batch, 1)``.
        iterator: Sized iterable of ``(texts, labels, lengths)`` batches; the
            lengths element is ignored.
        optimizer: Optimizer over ``net``'s parameters.
        criterion: Loss applied to ``(logits, labels)``.
        epochs: Number of passes over ``iterator``.

    Returns:
        A list with the mean batch loss of each epoch.
    """
    # Send batches to wherever the model's parameters live instead of relying
    # on a module-level ``device`` variable.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    # Guard the per-epoch average against an empty iterator.
    num_batches = max(len(iterator), 1)
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        net.train()
        for texts, labels, _ in iterator:
            texts, labels = texts.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= num_batches
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values
# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    """Evaluate ``net`` on ``iterator`` and report loss and accuracy.

    Args:
        net: Model returning one logit per example, shape ``(batch, 1)``.
        iterator: Sized iterable of ``(texts, labels, lengths)`` batches; the
            lengths element is ignored.
        criterion: Loss applied to ``(logits, labels)``.

    Returns:
        ``(accuracy, all_labels, all_predictions)``:
        - ``accuracy`` is a percentage, and 0.0 when there were no examples.
        - ``all_labels`` and ``all_predictions`` are flat lists of per-example
          values.
        - Each prediction is the sigmoid of the logit rounded to 0/1.
    """
    # Send batches to wherever the model's parameters live instead of relying
    # on a module-level ``device`` variable.
    first_param = next(net.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device('cpu')
    epoch_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for texts, labels, _ in iterator:
            texts, labels = texts.to(eval_device), labels.to(eval_device)
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()
            rounded_preds = torch.round(torch.sigmoid(predictions))
            correct += (rounded_preds == labels).sum().item()
            total += len(labels)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(rounded_preds.cpu().numpy())
    # Avoid ZeroDivisionError when the iterator produced no batches/examples.
    num_batches = max(len(iterator), 1)
    accuracy = 100 * correct / total if total else 0.0
    st.write(f'Loss: {epoch_loss / num_batches:.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions
# Load the data
# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Display a loading message with some vertical space while the data is being
# prepared. It lives in a placeholder so it can be removed once loading has
# finished, instead of remaining on the page for the rest of the session.
loading_message = st.empty()
loading_message.markdown("<div style='margin-top: 50px;'><b>Loading data...</b></div>", unsafe_allow_html=True)
vocab, train_loader, valid_loader, test_loader = load_data()
loading_message.empty()
# Streamlit interface: page title and introductory text.
st.title("RNN for Text Classification on IMDb Dataset")
st.write("""
This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the IMDb dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance.
""")
# Sidebar for input parameters. Slider arguments are (label, min, max, default)
# plus an optional step. Widgets appear in the sidebar in the order created here.
# The resulting module-level values are read below to build the network and
# optimizer, and `epochs` is later passed to the training loop.
st.sidebar.header('Model Hyperparameters')
embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)
# Create the network from the current slider values.
# NOTE(review): Streamlit normally re-executes the script on every widget
# interaction, so this presumably builds a fresh, untrained network and
# optimizer on each rerun. The trained model is kept in st.session_state by the
# Train button handler.
vocab_size = len(vocab)  # one embedding row per vocabulary entry
output_size = 1  # a single logit per review, scored with BCEWithLogitsLoss below
net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
# Add vertical space. A string of newlines is collapsed by Markdown rendering,
# so use an explicit spacer element, as the loading message does.
st.markdown("<div style='margin-top: 50px;'></div>", unsafe_allow_html=True)
# Train the network when the sidebar button is clicked.
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, train_loader, optimizer, criterion, epochs)
    # Plot the loss values on an explicit figure, then close it so repeated
    # training runs do not accumulate open matplotlib figures.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(loss_values) + 1), loss_values, marker='o')
    ax.set_title('Training Loss Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    st.pyplot(fig)
    plt.close(fig)
    # Store the trained model in the session state so the test and
    # sample-inspection buttons can use it on later reruns.
    st.session_state['trained_model'] = net
# Test the network. The 'Test Network' button is only rendered once a trained
# model is in the session state, because the `and` short-circuits otherwise.
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
    model_under_test = st.session_state['trained_model']
    accuracy, all_labels, all_predictions = evaluate_network(
        model_under_test, test_loader, criterion
    )
    st.write(f'Test Accuracy: {accuracy:.2f}%')

    # Tabulate ground truth against predictions, showing only the first
    # 50 rows to keep the page short.
    st.write('Ground Truth vs Predicted')
    table_columns = {'Ground Truth': all_labels, 'Predicted': all_predictions}
    results = pd.DataFrame(table_columns)
    st.table(results.head(50))
# Visualize some test results
def visualize_text_predictions(iterator, net, num_samples=10):
    """Collect up to ``num_samples`` ``(text, label, prediction)`` triples.

    Args:
        iterator: Iterable of ``(texts, labels, lengths)`` batches; the lengths
            element is ignored.
        net: Model mapping a batch of token ids to one logit per example.
        num_samples: Maximum number of samples to return (default 10).

    Returns:
        A list of at most ``num_samples`` tuples of CPU tensors. Each
        ``prediction`` is the model's logit passed through a sigmoid and
        rounded to 0.0/1.0.
    """
    net.eval()
    samples = []
    if num_samples <= 0:
        return samples
    # Run inputs on the device holding the model's parameters. The loader yields
    # CPU tensors, while the model may live on the GPU.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None
    with torch.no_grad():
        for texts, labels, _ in iterator:
            inputs = texts if target_device is None else texts.to(target_device)
            predictions = torch.round(torch.sigmoid(net(inputs).squeeze(1)))
            samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
            # Stop as soon as enough samples are collected, so no extra batch is fetched.
            if len(samples) >= num_samples:
                break
    return samples[:num_samples]
# Show a handful of decoded test reviews with their labels and predictions.
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
    samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
    st.write('Ground Truth vs Predicted for Sample Texts')
    # Fetch the index-to-token table once. get_itos() returns the whole
    # vocabulary as a list, so calling it per token was extremely slow.
    itos = vocab.get_itos()
    pad_index = vocab["<pad>"]
    for i, (text, true_label, predicted) in enumerate(samples):
        st.write(f'Sample {i+1}')
        # Skip the padding added by the batch collate so only the review shows.
        st.text(' '.join(itos[token] for token in text.tolist() if token != pad_index))
        st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')