import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data, datasets
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
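
# Fixing the seed keeps weight initialization (and dropout masks) reproducible across reruns
torch.manual_seed(1234)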
# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(self.embedding(x))
        # Initial hidden state: one zero vector per layer per batch element
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # classify from the last time step's hidden state
        return out
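
# The model consumes a LongTensor of token ids shaped (batch, seq_len) and
# returns raw logits shaped (batch, output_size); no sigmoid is applied here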
# Function to load the data
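@st.cache_resource  # cache the dataset/vocab across reruns (assumes Streamlit >= 1.18)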
def load_data():
    # batch_first=True so batches arrive as (batch, seq_len), matching the RNN above
    TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm', batch_first=True)
    LABEL = data.LabelField(dtype=torch.float)
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
    train_data, valid_data = train_data.split(split_ratio=0.8)
    MAX_VOCAB_SIZE = 25_000
    TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
    LABEL.build_vocab(train_data)
    BATCH_SIZE = 64
    # BucketIterator batches examples of similar length together to minimize padding
    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=BATCH_SIZE,
        device=device)
    return TEXT, LABEL, train_iterator, valid_iterator, test_iterator
# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for batch in iterator:
            optimizer.zero_grad()
            predictions = net(batch.text).squeeze(1)  # (batch, 1) -> (batch,), matching label shape
            loss = criterion(predictions, batch.label)
            loss.backward()
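            # Gradient clipping is a common safeguard against exploding gradients in vanilla RNNs
            nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)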
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(iterator)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values
# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    epoch_loss = 0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for batch in iterator:
            predictions = net(batch.text).squeeze(1)
            loss = criterion(predictions, batch.label)
            epoch_loss += loss.item()
            # The model outputs logits; sigmoid + round yields hard 0/1 predictions
            rounded_preds = torch.round(torch.sigmoid(predictions))
            correct += (rounded_preds == batch.label).sum().item()
            total += len(batch.label)
            all_labels.extend(batch.label.cpu().numpy())
            all_predictions.extend(rounded_preds.cpu().numpy())
    accuracy = 100 * correct / total
    st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions
# Load the data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TEXT, LABEL, train_iterator, valid_iterator, test_iterator = load_data()

# Streamlit interface
st.title("RNN for Text Classification on IMDb Dataset")
st.write("""
This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the IMDb dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance.
""")
# Sidebar for input parameters
st.sidebar.header('Model Hyperparameters')
embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100)
hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256)
n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2)
dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1)
learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)
# Create the network
vocab_size = len(TEXT.vocab)
output_size = 1
net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device)
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, so the model returns raw logits
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
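
# The vocabulary was built with 100-d GloVe vectors; when the embedding slider
# matches that dimensionality, copying them in gives the model a warm start
if embed_size == 100:
    net.embedding.weight.data.copy_(TEXT.vocab.vectors)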
# Add vertical space; bare newlines collapse in markdown, so emit <br> tags instead
st.markdown('<br>' * 10, unsafe_allow_html=True)
# Train the network
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, train_iterator, optimizer, criterion, epochs)
    # Plot the loss values on an explicit figure; newer Streamlit versions
    # reject the deprecated global-figure form of st.pyplot
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, epochs + 1), loss_values, marker='o')
    ax.set_title('Training Loss Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    st.pyplot(fig)
    # Store the trained model in the session state
    st.session_state['trained_model'] = net
# Test the network
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
    accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_iterator, criterion)
    st.write(f'Test Accuracy: {accuracy:.2f}%')
    # Display results in a table
    st.write('Ground Truth vs Predicted')
    results = pd.DataFrame({
        'Ground Truth': all_labels,
        'Predicted': all_predictions
    })
    st.table(results.head(50))  # Display first 50 results for brevity
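    # A confusion-matrix heatmap is one natural use of the seaborn import above;
    # pd.crosstab counts each (ground truth, prediction) pair
    cm = pd.crosstab(pd.Series(all_labels, name='Ground Truth'),
                     pd.Series(all_predictions, name='Predicted'))
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', ax=ax)
    st.pyplot(fig)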
# Visualize some test results
def visualize_text_predictions(iterator, net):
    net.eval()
    samples = []
    with torch.no_grad():
        for batch in iterator:
            predictions = torch.round(torch.sigmoid(net(batch.text).squeeze(1)))
            # With batch_first=True, each row of batch.text holds one example's token ids
            samples.extend(zip(batch.text.cpu(), batch.label.cpu(), predictions.cpu()))
            if len(samples) >= 10:
                break
    return samples[:10]
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
    samples = visualize_text_predictions(test_iterator, st.session_state['trained_model'])
    st.write('Ground Truth vs Predicted for Sample Texts')
    for i, (text, true_label, predicted) in enumerate(samples):
        st.write(f'Sample {i+1}')
        # Drop <pad> tokens so the recovered review text reads naturally
        st.text(' '.join(TEXT.vocab.itos[token] for token in text if TEXT.vocab.itos[token] != '<pad>'))
        st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')
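
# Run with: streamlit run app.py (assuming this file is saved as app.py).
# The torchtext.legacy API ships with torchtext 0.9-0.11, and the spaCy tokenizer
# needs: python -m spacy download en_core_web_sm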