import streamlit as st import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from PIL import Image import io # Set page config st.set_page_config(page_title="CIFAR-10 Classifier", layout="wide", initial_sidebar_state="expanded") # Custom CSS for dark theme st.markdown(""" """, unsafe_allow_html=True) # Model definition class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(3, 32, 3, padding=1) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(64 * 8 * 8, 512) self.fc2 = nn.Linear(512, 10) def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = x.view(-1, 64 * 8 * 8) x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # Function to train the model @st.cache_resource def train_model(): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) model = SimpleCNN() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(5): # Train for 5 epochs for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() return model # Function to load or train the model @st.cache_resource def get_model(): try: model = SimpleCNN() model.load_state_dict(torch.load('cifar10_model.pth')) model.eval() except: model = train_model() torch.save(model.state_dict(), 'cifar10_model.pth') return model # Sidebar st.sidebar.title("Upload Image") uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) # Main content st.markdown("
Created with Streamlit and PyTorch
", unsafe_allow_html=True)