File size: 4,304 Bytes
ce28db8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import io
# Configure the browser tab title, page layout and initial sidebar state.
st.set_page_config(
    page_title="CIFAR-10 Classifier",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Dark-theme CSS overrides, injected as raw HTML (hence unsafe_allow_html=True).
# The .stHeader and .stSuccess classes are referenced by markup emitted further down.
_DARK_THEME_CSS = """
<style>
.stApp {
background-color: #0E1117;
color: #FAFAFA;
}
.stButton>button {
background-color: #4CAF50;
color: white;
}
.stHeader {
background-color: #262730;
color: white;
padding: 1rem;
border-radius: 5px;
}
.stImage {
background-color: #262730;
padding: 10px;
border-radius: 5px;
}
.stSuccess {
background-color: #262730;
color: #4CAF50;
padding: 10px;
border-radius: 5px;
}
</style>
"""
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)
# Model definition
class SimpleCNN(nn.Module):
    """Small CNN classifier for 3-channel 32x32 images with 10 output classes.

    Two stages of 3x3 conv (padding=1) + ReLU + 2x2 max-pool reduce a 32x32
    input to 64 feature maps of 8x8; two fully connected layers then map those
    features to 10 unnormalized class scores (logits).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 8 * 8 positions: what two 2x poolings leave of a 32x32 input.
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 3, 32, 32)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten each sample while keeping the batch dimension. Unlike
        # view(-1, 4096), a wrongly sized image cannot be silently folded into
        # extra "batch" rows; fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Function to train the model
@st.cache_resource
def train_model(epochs=5, batch_size=64, lr=0.001):
    """Train a fresh SimpleCNN on the CIFAR-10 training split.

    The dataset is downloaded into ./data if not already present. The result
    is cached by Streamlit, keyed on the arguments.

    Args:
        epochs: Number of full passes over the training set.
        batch_size: Mini-batch size for the shuffled DataLoader.
        lr: Learning rate for the Adam optimizer.

    Returns:
        The trained SimpleCNN, switched to eval mode for inference.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Scale [0, 1] pixels to [-1, 1]; inference preprocessing must use the same values.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for _ in range(epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

    # The cached model is only used for prediction, so hand it back in eval mode.
    model.eval()
    return model
# Function to load or train the model
@st.cache_resource
def get_model():
    """Return a SimpleCNN ready for inference (cached by Streamlit).

    Loads weights from 'cifar10_model.pth' when that file exists and can be
    loaded into the current architecture. Otherwise it trains a new model via
    train_model() and saves the weights to that path for later runs. The
    returned model is always in eval mode.
    """
    checkpoint_path = 'cifar10_model.pth'
    model = SimpleCNN()
    try:
        # map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only point it at checkpoints this app wrote.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    except FileNotFoundError:
        # First run: no saved weights yet.
        model = None
    except Exception:
        # Corrupt or architecture-mismatched checkpoint: record why, then fall back to
        # retraining rather than failing the app.
        import logging
        logging.getLogger(__name__).exception(
            "Could not load weights from %s; retraining the model", checkpoint_path
        )
        model = None

    if model is None:
        model = train_model()
        torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    return model
# Streamlit app
# Label names in CIFAR-10 class-index order (0-9), used to decode the model's argmax.
CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

st.markdown("<h1 class='stHeader'>CIFAR-10 Image Classification</h1>", unsafe_allow_html=True)
st.write("Upload an image to classify it into one of the CIFAR-10 categories.")

# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

image = None
if uploaded_file is not None:
    try:
        # convert('RGB') turns RGBA PNGs and grayscale/palette images into the 3 channels
        # that conv1 and the 3-value Normalize below require. It also forces a full decode,
        # so truncated files fail here instead of during prediction.
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")

if image is not None:
    # Display uploaded image
    # NOTE(review): each st.markdown call renders as its own element, so this <div> likely
    # does not actually wrap the image below; confirm in the browser.
    st.markdown("<div class='stImage'>", unsafe_allow_html=True)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.markdown("</div>", unsafe_allow_html=True)

    # Predict button
    if st.button('Classify Image'):
        # get_model() is cached, but its first call may have to train the model, which is slow.
        with st.spinner("Loading model..."):
            model = get_model()

        # Preprocess image: normalization must match the one used in train_model().
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        input_tensor = transform(image).unsqueeze(0)  # add batch dim -> (1, 3, 32, 32)

        # Make prediction
        with torch.no_grad():
            logits = model(input_tensor)
        predicted_index = logits.argmax(dim=1).item()

        # Display result
        st.markdown(
            f"<div class='stSuccess'>Prediction: {CIFAR10_CLASSES[predicted_index]}</div>",
            unsafe_allow_html=True,
        )

# Footer
st.markdown("---")
st.markdown("<p style='text-align: center; color: #666;'>Created with Streamlit and PyTorch</p>", unsafe_allow_html=True)