|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import io |
|
|
|
|
|
st.set_page_config(page_title="CIFAR-10 Classifier", layout="wide", initial_sidebar_state="expanded") |
|
|
|
|
|
st.markdown(""" |
|
<style> |
|
.stApp { |
|
background-color: #0E1117; |
|
color: #FAFAFA; |
|
} |
|
.stButton>button { |
|
background-color: #4CAF50; |
|
color: white; |
|
} |
|
.stHeader { |
|
background-color: #262730; |
|
color: white; |
|
padding: 1rem; |
|
border-radius: 5px; |
|
margin-bottom: 1rem; |
|
} |
|
.stImage { |
|
background-color: #262730; |
|
padding: 10px; |
|
border-radius: 5px; |
|
} |
|
.stSuccess { |
|
background-color: #262730; |
|
color: #4CAF50; |
|
padding: 10px; |
|
border-radius: 5px; |
|
margin-top: 1rem; |
|
} |
|
.upload-box { |
|
border: 2px dashed #4CAF50; |
|
border-radius: 5px; |
|
padding: 20px; |
|
text-align: center; |
|
cursor: pointer; |
|
} |
|
</style> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
class SimpleCNN(nn.Module): |
|
def __init__(self): |
|
super(SimpleCNN, self).__init__() |
|
self.conv1 = nn.Conv2d(3, 32, 3, padding=1) |
|
self.conv2 = nn.Conv2d(32, 64, 3, padding=1) |
|
self.pool = nn.MaxPool2d(2, 2) |
|
self.fc1 = nn.Linear(64 * 8 * 8, 512) |
|
self.fc2 = nn.Linear(512, 10) |
|
|
|
def forward(self, x): |
|
x = self.pool(torch.relu(self.conv1(x))) |
|
x = self.pool(torch.relu(self.conv2(x))) |
|
x = x.view(-1, 64 * 8 * 8) |
|
x = torch.relu(self.fc1(x)) |
|
x = self.fc2(x) |
|
return x |
|
|
|
|
|
@st.cache_resource |
|
def train_model(): |
|
transform = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
]) |
|
|
|
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) |
|
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) |
|
|
|
model = SimpleCNN() |
|
criterion = nn.CrossEntropyLoss() |
|
optimizer = optim.Adam(model.parameters(), lr=0.001) |
|
|
|
for epoch in range(5): |
|
for i, data in enumerate(trainloader, 0): |
|
inputs, labels = data |
|
optimizer.zero_grad() |
|
outputs = model(inputs) |
|
loss = criterion(outputs, labels) |
|
loss.backward() |
|
optimizer.step() |
|
|
|
return model |
|
|
|
|
|
@st.cache_resource |
|
def get_model(): |
|
try: |
|
model = SimpleCNN() |
|
model.load_state_dict(torch.load('cifar10_model.pth')) |
|
model.eval() |
|
except: |
|
model = train_model() |
|
torch.save(model.state_dict(), 'cifar10_model.pth') |
|
return model |
|
|
|
|
|
st.sidebar.title("Upload Image") |
|
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) |
|
|
|
|
|
st.markdown("<h1 class='stHeader'>CIFAR-10 Image Classification</h1>", unsafe_allow_html=True) |
|
|
|
|
|
col1, col2, col3 = st.columns([1,2,1]) |
|
|
|
|
|
|
|
|
|
if uploaded_file is not None: |
|
image = Image.open(uploaded_file) |
|
col1, col2, col3 = st.columns([1,2,1]) |
|
with col2: |
|
st.markdown("<div class='stImage'>", unsafe_allow_html=True) |
|
st.image(image, caption='Uploaded Image', use_column_width=True) |
|
st.markdown("</div>", unsafe_allow_html=True) |
|
|
|
|
|
model = get_model() |
|
transform = transforms.Compose([ |
|
transforms.Resize((32, 32)), |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
]) |
|
input_tensor = transform(image).unsqueeze(0) |
|
|
|
with torch.no_grad(): |
|
output = model(input_tensor) |
|
probabilities = torch.nn.functional.softmax(output[0], dim=0) |
|
|
|
|
|
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') |
|
_, predicted = torch.max(output, 1) |
|
st.sidebar.markdown("<div class='stSuccess'>", unsafe_allow_html=True) |
|
st.sidebar.write(f"Best Prediction: {classes[predicted.item()]}") |
|
st.sidebar.markdown("</div>", unsafe_allow_html=True) |
|
|
|
st.sidebar.write("Prediction Probabilities:") |
|
for i, prob in enumerate(probabilities): |
|
st.sidebar.write(f"{classes[i]}: {prob.item():.2%}") |
|
|
|
|
|
st.markdown("---") |
|
st.markdown("<p style='text-align: center; color: #666;'>Created with Streamlit and PyTorch</p>", unsafe_allow_html=True) |