rajsecrets0 commited on
Commit
ce28db8
·
verified ·
1 Parent(s): e35167f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ import io
9
+
10
+ # Set page config
11
+ st.set_page_config(page_title="CIFAR-10 Classifier", layout="centered", initial_sidebar_state="collapsed")
12
+
13
+ # Custom CSS for dark theme
14
+ st.markdown("""
15
+ <style>
16
+ .stApp {
17
+ background-color: #0E1117;
18
+ color: #FAFAFA;
19
+ }
20
+ .stButton>button {
21
+ background-color: #4CAF50;
22
+ color: white;
23
+ }
24
+ .stHeader {
25
+ background-color: #262730;
26
+ color: white;
27
+ padding: 1rem;
28
+ border-radius: 5px;
29
+ }
30
+ .stImage {
31
+ background-color: #262730;
32
+ padding: 10px;
33
+ border-radius: 5px;
34
+ }
35
+ .stSuccess {
36
+ background-color: #262730;
37
+ color: #4CAF50;
38
+ padding: 10px;
39
+ border-radius: 5px;
40
+ }
41
+ </style>
42
+ """, unsafe_allow_html=True)
43
+
44
+ # Model definition
45
+ class SimpleCNN(nn.Module):
46
+ def __init__(self):
47
+ super(SimpleCNN, self).__init__()
48
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
49
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
50
+ self.pool = nn.MaxPool2d(2, 2)
51
+ self.fc1 = nn.Linear(64 * 8 * 8, 512)
52
+ self.fc2 = nn.Linear(512, 10)
53
+
54
+ def forward(self, x):
55
+ x = self.pool(torch.relu(self.conv1(x)))
56
+ x = self.pool(torch.relu(self.conv2(x)))
57
+ x = x.view(-1, 64 * 8 * 8)
58
+ x = torch.relu(self.fc1(x))
59
+ x = self.fc2(x)
60
+ return x
61
+
62
+ # Function to train the model
63
+ @st.cache_resource
64
+ def train_model():
65
+ transform = transforms.Compose([
66
+ transforms.ToTensor(),
67
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
68
+ ])
69
+
70
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
71
+ trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
72
+
73
+ model = SimpleCNN()
74
+ criterion = nn.CrossEntropyLoss()
75
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
76
+
77
+ for epoch in range(5): # Train for 5 epochs
78
+ for i, data in enumerate(trainloader, 0):
79
+ inputs, labels = data
80
+ optimizer.zero_grad()
81
+ outputs = model(inputs)
82
+ loss = criterion(outputs, labels)
83
+ loss.backward()
84
+ optimizer.step()
85
+
86
+ return model
87
+
88
+ # Function to load or train the model
89
+ @st.cache_resource
90
+ def get_model():
91
+ try:
92
+ model = SimpleCNN()
93
+ model.load_state_dict(torch.load('cifar10_model.pth'))
94
+ model.eval()
95
+ except:
96
+ model = train_model()
97
+ torch.save(model.state_dict(), 'cifar10_model.pth')
98
+ return model
99
+
100
+ # Streamlit app
101
+ st.markdown("<h1 class='stHeader'>CIFAR-10 Image Classification</h1>", unsafe_allow_html=True)
102
+ st.write("Upload an image to classify it into one of the CIFAR-10 categories.")
103
+
104
+ # File uploader
105
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
106
+
107
+ if uploaded_file is not None:
108
+ # Display uploaded image
109
+ image = Image.open(uploaded_file)
110
+ st.markdown("<div class='stImage'>", unsafe_allow_html=True)
111
+ st.image(image, caption='Uploaded Image', use_column_width=True)
112
+ st.markdown("</div>", unsafe_allow_html=True)
113
+
114
+ # Predict button
115
+ if st.button('Classify Image'):
116
+ # Load model
117
+ model = get_model()
118
+
119
+ # Preprocess image
120
+ transform = transforms.Compose([
121
+ transforms.Resize((32, 32)),
122
+ transforms.ToTensor(),
123
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
124
+ ])
125
+ input_tensor = transform(image).unsqueeze(0)
126
+
127
+ # Make prediction
128
+ with torch.no_grad():
129
+ output = model(input_tensor)
130
+ _, predicted = torch.max(output, 1)
131
+
132
+ # Display result
133
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
134
+ st.markdown(f"<div class='stSuccess'>Prediction: {classes[predicted.item()]}</div>", unsafe_allow_html=True)
135
+
136
+ # Footer
137
+ st.markdown("---")
138
+ st.markdown("<p style='text-align: center; color: #666;'>Created with Streamlit and PyTorch</p>", unsafe_allow_html=True)