import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import IMDB
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(self.embedding(x))
        # Build the initial hidden state from the model's own attributes rather than globals
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)
        # Classify using the output of the last time step
        out = self.fc(out[:, -1, :])
        return out
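# Rough sketch of the tensor shapes flowing through the model (for orientation only;
# the exact sizes depend on the hyperparameters chosen in the sidebar):
#   input      : LongTensor of token ids, shape (batch, seq_len)
#   embedding  : (batch, seq_len, embed_size)
#   rnn output : (batch, seq_len, hidden_size); only the last time step is used
#   fc output  : (batch, 1) raw logit, consumed by BCEWithLogitsLoss / sigmoid downstream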
# Create a custom collate function to pad sequences
def collate_batch(batch):
    texts, labels = zip(*batch)
    text_lengths = [len(text) for text in texts]
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
    return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths
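# Illustration of what collate_batch produces (not executed): pad_sequence right-pads every
# text in the batch to the length of the longest one using the "<pad>" index, e.g.
#   [[5, 8, 2], [7, 4]]  ->  [[5, 8, 2], [7, 4, pad]]
# Note that it reads the module-level `vocab`, which is only assigned once load_data() returns,
# so the DataLoaders must not be iterated before that point.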
# Function to load the data
def load_data():
    tokenizer = get_tokenizer("basic_english")
    # Materialise the splits as lists so they can be traversed more than once
    # (once for building the vocabulary, once for numericalising the texts)
    train_iter, test_iter = IMDB(split=('train', 'test'))
    train_data, test_data = list(train_iter), list(test_iter)

    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])
    # Define the text and label processing pipelines
    text_pipeline = lambda x: vocab(tokenizer(x))
    # Older torchtext releases yield string labels ('pos'/'neg'); newer ones yield integers (2 = positive)
    label_pipeline = lambda x: 1 if x in ('pos', 2) else 0

    # Process the data into tensors
    def process_data(data_iter):
        texts, labels = [], []
        for label, text in data_iter:
            texts.append(torch.tensor(text_pipeline(text), dtype=torch.long))
            labels.append(label_pipeline(label))
        return texts, torch.tensor(labels, dtype=torch.float)

    train_texts, train_labels = process_data(train_data)
    test_texts, test_labels = process_data(test_data)
    # Create DataLoaders
    train_dataset = list(zip(train_texts, train_labels))
    test_dataset = list(zip(test_texts, test_labels))
    train_size = int(0.8 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
    BATCH_SIZE = 64
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    return vocab, train_loader, valid_loader, test_loader
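# Each loader yields (padded_texts, labels, text_lengths) batches via collate_batch; the
# lengths are currently unused by the model but are kept in case a packed-sequence variant
# (e.g. nn.utils.rnn.pack_padded_sequence) is added later.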
# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for texts, labels, _ in iterator:
            texts, labels = texts.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(iterator)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values
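# One training step per batch: zero the gradients, compute the BCE-with-logits loss on the raw
# model outputs, backpropagate, and update the weights with Adam; the per-epoch average loss
# is what gets plotted after the "Train Network" button is pressed.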
# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    epoch_loss = 0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for texts, labels, _ in iterator:
            texts, labels = texts.to(device), labels.to(device)
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()
            rounded_preds = torch.round(torch.sigmoid(predictions))
            correct += (rounded_preds == labels).sum().item()
            total += len(labels)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(rounded_preds.cpu().numpy())
    accuracy = 100 * correct / total
    st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions
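# Predictions are thresholded at 0.5: torch.sigmoid maps each logit to a probability and
# torch.round turns it into the predicted class (1 = positive review, 0 = negative review),
# which is then compared against the ground-truth labels to compute accuracy.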
# Load the data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Display a loading message with some vertical space
st.markdown("<div style='margin-top: 50px;'><b>Loading data...</b></div>", unsafe_allow_html=True)
vocab, train_loader, valid_loader, test_loader = load_data()
# Streamlit interface
st.title("RNN for Text Classification on IMDb Dataset")
st.write("""
This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the IMDb dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance.
""")
# Sidebar for input parameters
st.sidebar.header('Model Hyperparameters')
embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)
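# Every interaction with a slider triggers a Streamlit rerun of the whole script, so the model
# and optimizer below are rebuilt with the freshly selected hyperparameters on each rerun.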
# Create the network
vocab_size = len(vocab)
output_size = 1
net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
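# BCEWithLogitsLoss combines the sigmoid and binary cross-entropy in one numerically stable op,
# which is why the model's forward returns raw logits and sigmoid is only applied explicitly
# when converting outputs to class predictions during evaluation.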
# Add vertical space (blank lines are collapsed by st.write, so use explicit <br> tags)
st.markdown("<br>" * 10, unsafe_allow_html=True)
# Train the network
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, train_loader, optimizer, criterion, epochs)
    # Plot the loss values
    fig = plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), loss_values, marker='o')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    st.pyplot(fig)
    # Store the trained model in the session state
    st.session_state['trained_model'] = net
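# st.session_state persists the trained model across Streamlit reruns, so the "Test Network"
# and "Show Test Results" buttons below can reuse it without retraining.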
# Test the network
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion)
    st.write(f'Test Accuracy: {accuracy:.2f}%')
    # Display results in a table
    st.write('Ground Truth vs Predicted')
    results = pd.DataFrame({
        'Ground Truth': all_labels,
        'Predicted': all_predictions
    })
    st.table(results.head(50))  # Display first 50 results for brevity
# Visualize some test results
def visualize_text_predictions(iterator, net):
    net.eval()
    samples = []
    with torch.no_grad():
        for texts, labels, _ in iterator:
            # Move the batch to the same device as the model before running inference
            texts = texts.to(device)
            predictions = torch.round(torch.sigmoid(net(texts).squeeze(1)))
            samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
            if len(samples) >= 10:
                break
    return samples[:10]
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
    samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
    st.write('Ground Truth vs Predicted for Sample Texts')
    for i, (text, true_label, predicted) in enumerate(samples):
        st.write(f'Sample {i+1}')
        st.text(' '.join([vocab.get_itos()[token] for token in text]))
        st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')