import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import pandas as pd
# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(self.embedding(x))
        # Use the module's own attributes and the input's device rather than
        # globals defined later in the script
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)
        # Classify from the hidden state at the last time step
        out = self.fc(out[:, -1, :])
        return out
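# Illustrative shape walk-through for one forward pass (sizes assumed for the
# example, e.g. batch_size=64 and seq_len=50; batch_first=True throughout):
#   x        : (64, 50)               padded token indices
#   embedded : (64, 50, embed_size)   after embedding + dropout
#   h0       : (n_layers, 64, hidden_size)
#   rnn out  : (64, 50, hidden_size)  hidden state at every time step
#   logits   : (64, output_size)      fc applied to the last time step only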
# Create a custom collate function to pad sequences
def collate_batch(batch):
    label_list, text_list = [], []
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
    labels = torch.tensor(label_list, dtype=torch.int64)
    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
    return texts, labels
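# Illustrative example with hypothetical token ids: sequences [5, 7] and
# [2, 9, 4] are padded to the batch maximum, giving
#   [[5, 7, <pad>], [2, 9, 4]]  -> texts of shape (2, 3)
# so every sequence in a batch shares one length and stacks into one tensor.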
# Function to load the data
def load_data():
    tokenizer = get_tokenizer("basic_english")
    train_iter = AG_NEWS(split='train')
    test_iter = AG_NEWS(split='test')

    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])

    global text_pipeline, label_pipeline
    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x) - 1  # AG_NEWS labels are 1-4; shift to 0-3
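    # Illustrative examples (the token ids are hypothetical; actual ids
    # depend on the vocab built above):
    #   text_pipeline("here is an example")  -> [475, 21, 30, 5297]
    #   label_pipeline("3")                  -> 2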
    # Create DataLoaders (80/20 train/validation split)
    train_dataset = list(train_iter)
    test_dataset = list(test_iter)
    train_size = int(0.8 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

    BATCH_SIZE = 64
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    # No need to shuffle the held-out validation and test sets
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    return vocab, train_loader, valid_loader, test_loader
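# Example of what one training batch looks like (shapes are illustrative;
# the padded length varies per batch):
#   texts, labels = next(iter(train_loader))
#   texts.shape   -> torch.Size([64, 47])
#   labels.shape  -> torch.Size([64])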
# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for texts, labels in iterator:
            texts, labels = texts.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = net(texts)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(iterator)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values
# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    epoch_loss = 0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for texts, labels in iterator:
            texts, labels = texts.to(device), labels.to(device)
            predictions = net(texts)
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()
            _, predicted = torch.max(predictions, 1)
            correct += (predicted == labels).sum().item()
            total += len(labels)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    accuracy = 100 * correct / total
    st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions
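# Note: torch.max(predictions, 1) returns (values, indices) along the class
# dimension; the indices are the predicted classes. Illustrative example:
#   torch.max(torch.tensor([[0.1, 2.0, -1.0, 0.3]]), 1)[1]  -> tensor([1])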
# Load data (note: this runs on every Streamlit rerun; st.cache_resource
# could be used to avoid rebuilding the vocab and loaders each time)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab, train_loader, valid_loader, test_loader = load_data()

# Streamlit interface
st.title("RNN for Text Classification on AG News Dataset")
st.write("""
This application demonstrates how to build and train a Recurrent Neural Network (RNN)
for text classification on the AG News dataset. You can adjust hyperparameters,
visualize sample data, and see the model's performance.
""")
# Sidebar for input parameters
st.sidebar.header('Model Hyperparameters')
embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)
# Create the network
vocab_size = len(vocab)
output_size = 4  # Number of classes in AG_NEWS
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']  # AG_NEWS classes, in label order
net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
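# Note: nn.CrossEntropyLoss applies log-softmax internally, so the model
# outputs raw logits of shape (batch, 4) and the targets are integer class
# indices in 0-3; no softmax layer is needed in the network itself.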
# Add vertical space (plain newlines collapse in markdown, so write empty elements)
for _ in range(3):
    st.write("")
# Train the network
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, train_loader, optimizer, criterion, epochs)

    # Plot the loss values
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), loss_values, marker='o')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    st.pyplot(plt)

    # Store the trained model in the session state
    st.session_state['trained_model'] = net
# Test the network
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion)
    st.write(f'Test Accuracy: {accuracy:.2f}%')

    # Display results in a table, mapping 0-3 labels back to class names
    st.write('Ground Truth vs Predicted')
    results = pd.DataFrame({
        'Ground Truth': [class_names[label] for label in all_labels],
        'Predicted': [class_names[label] for label in all_predictions]
    })
    st.table(results.head(50))  # Display first 50 results for brevity
# Visualize some test results
def visualize_text_predictions(iterator, net):
    net.eval()
    samples = []
    with torch.no_grad():
        for texts, labels in iterator:
            texts = texts.to(device)  # move inputs to the model's device
            predictions = torch.max(net(texts), 1)[1]
            samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu()))
            if len(samples) >= 10:
                break
    return samples[:10]
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
    samples = visualize_text_predictions(test_loader, st.session_state['trained_model'])
    st.write('Ground Truth vs Predicted for Sample Texts')
    itos = vocab.get_itos()
    pad_index = vocab["<pad>"]
    for i, (text, true_label, predicted) in enumerate(samples):
        st.write(f'Sample {i + 1}')
        # Skip padding tokens when reconstructing the text
        st.text(' '.join(itos[token] for token in text.tolist() if token != pad_index))
        st.write(f'Ground Truth: {class_names[true_label.item()]}, Predicted: {class_names[predicted.item()]}')
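# To run the app locally (assuming this script is saved as app.py):
#   streamlit run app.py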