eaglelandsonce commited on
Commit
5368558
·
verified ·
1 Parent(s): 11b52a7

Delete pages/20_ResNet2.py

Browse files
Files changed (1) hide show
  1. pages/20_ResNet2.py +0 -128
pages/20_ResNet2.py DELETED
@@ -1,128 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import matplotlib.pyplot as plt
6
- import torchvision.transforms as transforms
7
- from torchvision.datasets import CIFAR10
8
- from torch.utils.data import DataLoader
9
-
10
- # Define the ResNet model
11
- class BasicBlock(nn.Module):
12
- expansion = 1
13
-
14
- def __init__(self, in_planes, planes, stride=1):
15
- super(BasicBlock, self).__init__()
16
- self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
17
- self.bn1 = nn.BatchNorm2d(planes)
18
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
19
- self.bn2 = nn.BatchNorm2d(planes)
20
-
21
- self.shortcut = nn.Sequential()
22
- if stride != 1 or in_planes != self.expansion * planes:
23
- self.shortcut = nn.Sequential(
24
- nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
25
- nn.BatchNorm2d(self.expansion * planes)
26
- )
27
-
28
- def forward(self, x):
29
- identity = x
30
- out = F.relu(self.bn1(self.conv1(x)))
31
- out = self.bn2(self.conv2(out))
32
- out += self.shortcut(identity)
33
- out = F.relu(out)
34
- return out
35
-
36
- class ResNet(nn.Module):
37
- def __init__(self, block, num_blocks, num_classes=10):
38
- super(ResNet, self).__init__()
39
- self.in_planes = 64
40
-
41
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
42
- self.bn1 = nn.BatchNorm2d(64)
43
- self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
44
- self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
45
- self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
46
- self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
47
- self.linear = nn.Linear(512 * block.expansion, num_classes)
48
-
49
- def _make_layer(self, block, planes, num_blocks, stride):
50
- strides = [stride] + [1] * (num_blocks - 1)
51
- layers = []
52
- for stride in strides:
53
- layers.append(block(self.in_planes, planes, stride))
54
- self.in_planes = planes * block.expansion
55
- return nn.Sequential(*layers)
56
-
57
- def forward(self, x):
58
- out = F.relu(self.bn1(self.conv1(x)))
59
- out = self.layer1(out)
60
- out = self.layer2(out)
61
- out = self.layer3(out)
62
- out = self.layer4(out)
63
- out = F.avg_pool2d(out, 4)
64
- out = out.view(out.size(0), -1)
65
- out = self.linear(out)
66
- return out
67
-
68
- def ResNet18():
69
- return ResNet(BasicBlock, [2, 2, 2, 2])
70
-
71
- # Define a function to load CIFAR-10 dataset
72
- def load_data():
73
- transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
74
- train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
75
- train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=2)
76
- return train_loader
77
-
78
- # Streamlit Interface
79
- st.title('ResNet with Streamlit')
80
- st.write("This is an example of integrating a ResNet model with Streamlit.")
81
-
82
- # Load data button
83
- if st.button('Load Data'):
84
- st.write("Loading CIFAR-10 data...")
85
- train_loader = load_data()
86
- st.write("Data loaded successfully!")
87
-
88
- # Initialize and test the model
89
- if st.button('Initialize and Test ResNet18'):
90
- net = ResNet18()
91
- sample_input = torch.randn(1, 3, 32, 32)
92
- output = net(sample_input)
93
- st.write("Output size: ", output.size())
94
-
95
- # Train the model (for demonstration, we'll just do one epoch)
96
- if st.button('Train ResNet18'):
97
- st.write("Training ResNet18 on CIFAR-10...")
98
- net = ResNet18()
99
- train_loader = load_data()
100
- criterion = nn.CrossEntropyLoss()
101
- optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
102
-
103
- net.train()
104
- for epoch in range(1): # Single epoch for demonstration
105
- running_loss = 0.0
106
- for i, data in enumerate(train_loader, 0):
107
- inputs, labels = data
108
- optimizer.zero_grad()
109
- outputs = net(inputs)
110
- loss = criterion(outputs, labels)
111
- loss.backward()
112
- optimizer.step()
113
- running_loss += loss.item()
114
- if i % 100 == 99: # Print every 100 mini-batches
115
- st.write(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
116
- running_loss = 0.0
117
-
118
- st.write("Training complete!")
119
-
120
- # Plotting example (dummy plot for demonstration)
121
- if st.button('Show Plot'):
122
- st.write("Displaying a sample plot...")
123
- fig, ax = plt.subplots()
124
- ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
125
- st.pyplot(fig)
126
-
127
- # To run the Streamlit app, use the command below in your terminal:
128
- # streamlit run your_script_name.py