import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Streamlit interface
st.title("CNN for Image Classification using CIFAR-10")
st.write("""
This application demonstrates how to build and train a Convolutional Neural Network (CNN)
for image classification on the CIFAR-10 dataset. You can adjust hyperparameters,
visualize sample images, and inspect the model's performance.
""")

# Hyperparameters
num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
batch_size = st.sidebar.slider("Batch size", 10, 200, 100, step=10)
learning_rate = st.sidebar.slider("Learning rate", 0.0001, 0.01, 0.001, step=0.0001, format="%.4f")
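# Note: Streamlit re-runs this script from top to bottom whenever a sidebar value
# changes, so the dataset and model below are rebuilt on every interaction.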

# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
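# ToTensor() scales pixels to [0, 1]; normalizing with mean=0.5 and std=0.5
# per channel then maps them to [-1, 1].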
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Display some sample images
st.write("## Sample Images from CIFAR-10 Dataset")
class_names = train_dataset.classes
sample_images, sample_labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 6, figsize=(15, 5))
for i in range(6):
    # Undo the normalization ([-1, 1] -> [0, 1]) so imshow renders the colors correctly
    img = sample_images[i] / 2 + 0.5
    axes[i].imshow(np.transpose(img.numpy(), (1, 2, 0)))
    axes[i].set_title(f'Label: {class_names[sample_labels[i].item()]}')
    axes[i].axis('off')
st.pyplot(fig)

# Class distribution
st.write("## Class Distribution in CIFAR-10 Dataset")
# Count labels over the full training set (not just one batch) so the bars match the ten classes
class_counts = np.bincount(np.array(train_dataset.targets), minlength=len(class_names))
fig, ax = plt.subplots()
sns.barplot(x=class_names, y=class_counts, ax=ax)
ax.set_ylabel('Count')
ax.set_title('Class Distribution')
st.pyplot(fig)

# Define a Convolutional Neural Network
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2))
        # Automatically determine the size of the flattened features after convolution and pooling
        self._to_linear = None
        self.convs(torch.randn(1, 3, 32, 32))
        self.fc1 = nn.Linear(self._to_linear, 600)
        self.drop = nn.Dropout(0.25)  # plain Dropout: the features are already flattened (2-D) here
        self.fc2 = nn.Linear(600, 100)
        self.fc3 = nn.Linear(100, 10)

    def convs(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        if self._to_linear is None:
            self._to_linear = x.view(x.size(0), -1).shape[1]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))  # non-linearity between the fully connected layers
        x = self.drop(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; CrossEntropyLoss applies softmax internally
        return x
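# Shape check: a 32x32 input -> layer1 (3x3 conv, pad 1, then 2x2 pool) -> 16x16;
# layer2 (3x3 conv, no padding -> 14x14, then 2x2 pool) -> 7x7,
# so _to_linear = 64 * 7 * 7 = 3136.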

model = CNN().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
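# nn.CrossEntropyLoss expects raw logits (it applies log-softmax internally),
# which is why the model's forward pass does not end with a softmax.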

# Button to start training
if st.button("Start Training"):
    # Lists to store losses and accuracy
    train_losses = []
    test_losses = []
    test_accuracies = []

    # Progress bar
    progress_bar = st.progress(0)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        train_loss = 0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        train_loss /= total_step
        train_losses.append(train_loss)
        st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

        # Test the model
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            all_labels = []
            all_predictions = []
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
            test_loss /= len(test_loader)
            test_losses.append(test_loss)
            accuracy = 100 * correct / total
            test_accuracies.append(accuracy)
            st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
        model.train()

        # Update progress bar
        progress_bar.progress((epoch + 1) / num_epochs)

    # Plotting the loss and accuracy
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    ax1.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
    ax1.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Test Loss')
    ax1.legend()
    ax2.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Test Accuracy')
    ax2.legend()
    st.pyplot(fig)

    # Confusion matrix (from the final epoch's predictions)
    cm = confusion_matrix(all_labels, all_predictions)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names,
                yticklabels=class_names, cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    st.pyplot(fig)

    # Save the model checkpoint
    torch.save(model.state_dict(), 'cnn_model.pth')
    st.write("Model training completed and saved as 'cnn_model.pth'")