eaglelandsonce commited on
Commit
45282e7
·
verified ·
1 Parent(s): 925cb8a

Create 15_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_CNN.py +110 -0
pages/15_CNN.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.transforms as transforms
6
+ import torchvision
7
+ from torch.utils.data import DataLoader
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ # Define the CNN model
12
+ class Net(nn.Module):
13
+ def __init__(self):
14
+ super(Net, self).__init__()
15
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
16
+ self.pool = nn.MaxPool2d(2, 2)
17
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
18
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
19
+ self.fc1 = nn.Linear(128 * 4 * 4, 256)
20
+ self.fc2 = nn.Linear(256, 128)
21
+ self.fc3 = nn.Linear(128, 10)
22
+
23
+ def forward(self, x):
24
+ x = self.pool(F.relu(self.conv1(x)))
25
+ x = self.pool(F.relu(self.conv2(x)))
26
+ x = self.pool(F.relu(self.conv3(x)))
27
+ x = x.view(-1, 128 * 4 * 4)
28
+ x = F.relu(self.fc1(x))
29
+ x = F.relu(self.fc2(x))
30
+ x = self.fc3(x)
31
+ return x
32
+
33
+ # Load pre-trained model (if available)
34
+ def load_model():
35
+ net = Net()
36
+ try:
37
+ net.load_state_dict(torch.load('cnn_model.pth'))
38
+ net.eval()
39
+ st.write("Model loaded successfully")
40
+ except FileNotFoundError:
41
+ st.write("No pre-trained model found. Please train the model first.")
42
+ return net
43
+
44
+ # Function to predict the class of an image
45
+ def predict(image, model):
46
+ transform = transforms.Compose([
47
+ transforms.Resize((32, 32)),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
50
+ ])
51
+ image = transform(image).unsqueeze(0)
52
+ outputs = model(image)
53
+ _, predicted = torch.max(outputs.data, 1)
54
+ return predicted.item()
55
+
56
+ # Training function
57
+ def train_model():
58
+ transform = transforms.Compose([
59
+ transforms.ToTensor(),
60
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
61
+ ])
62
+
63
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
64
+ trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)
65
+
66
+ net = Net()
67
+ criterion = nn.CrossEntropyLoss()
68
+ optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
69
+
70
+ st.write("Training the model...")
71
+ for epoch in range(10):
72
+ running_loss = 0.0
73
+ for i, data in enumerate(trainloader, 0):
74
+ inputs, labels = data
75
+ optimizer.zero_grad()
76
+ outputs = net(inputs)
77
+ loss = criterion(outputs, labels)
78
+ loss.backward()
79
+ optimizer.step()
80
+
81
+ running_loss += loss.item()
82
+ if i % 100 == 99:
83
+ st.write(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
84
+ running_loss = 0.0
85
+
86
+ st.write('Finished Training')
87
+ torch.save(net.state_dict(), 'cnn_model.pth')
88
+ st.write('Model saved as cnn_model.pth')
89
+ return net
90
+
91
+ # Streamlit interface
92
+ st.title("CNN Image Classification with CIFAR-10")
93
+
94
+ mode = st.sidebar.selectbox("Mode", ["Train", "Predict"])
95
+
96
+ if mode == "Train":
97
+ if st.button("Train Model"):
98
+ model = train_model()
99
+
100
+ if mode == "Predict":
101
+ model = load_model()
102
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
103
+ if uploaded_file is not None:
104
+ image = Image.open(uploaded_file)
105
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
106
+ st.write("")
107
+ st.write("Classifying...")
108
+ class_idx = predict(image, model)
109
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
110
+ st.write(f'Prediction: {classes[class_idx]}')