eaglelandsonce commited on
Commit
03020ab
·
verified ·
1 Parent(s): 68c5fd5

Create 19_ResNet.py

Browse files
Files changed (1) hide show
  1. pages/19_ResNet.py +163 -0
pages/19_ResNet.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install necessary packages
2
+ # Ensure you have PyTorch, torchvision, and Streamlit installed
3
+ # You can install them using pip if you haven't already:
4
+ # pip install torch torchvision streamlit
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from torchvision import datasets, models, transforms
10
+ from torch.utils.data import DataLoader
11
+ import numpy as np
12
+ import time
13
+ import os
14
+ import copy
15
+ import streamlit as st
16
+ from PIL import Image
17
+ import matplotlib.pyplot as plt
18
+ import torchvision.transforms as T
19
+
20
+ # Data transformations
21
+ data_transforms = {
22
+ 'train': transforms.Compose([
23
+ transforms.RandomResizedCrop(224),
24
+ transforms.RandomHorizontalFlip(),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
27
+ ]),
28
+ 'val': transforms.Compose([
29
+ transforms.Resize(256),
30
+ transforms.CenterCrop(224),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33
+ ]),
34
+ }
35
+
36
+ # Load datasets
37
+ data_dir = 'path/to/data'
38
+ image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
39
+ for x in ['train', 'val']}
40
+ dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
41
+ for x in ['train', 'val']}
42
+ dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
43
+ class_names = image_datasets['train'].classes
44
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
45
+
46
+ # Load the pre-trained model
47
+ model_ft = models.resnet18(pretrained=True)
48
+ num_ftrs = model_ft.fc.in_features
49
+ model_ft.fc = nn.Linear(num_ftrs, len(class_names))
50
+ model_ft = model_ft.to(device)
51
+
52
+ # Define loss function and optimizer
53
+ criterion = nn.CrossEntropyLoss()
54
+ optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
55
+ scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
56
+
57
+ # Training and evaluation functions
58
+ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
59
+ since = time.time()
60
+
61
+ best_model_wts = copy.deepcopy(model.state_dict())
62
+ best_acc = 0.0
63
+
64
+ for epoch in range(num_epochs):
65
+ print('Epoch {}/{}'.format(epoch, num_epochs - 1))
66
+ print('-' * 10)
67
+
68
+ for phase in ['train', 'val']:
69
+ if phase == 'train':
70
+ model.train()
71
+ else:
72
+ model.eval()
73
+
74
+ running_loss = 0.0
75
+ running_corrects = 0
76
+
77
+ for inputs, labels in dataloaders[phase]:
78
+ inputs = inputs.to(device)
79
+ labels = labels.to(device)
80
+
81
+ optimizer.zero_grad()
82
+
83
+ with torch.set_grad_enabled(phase == 'train'):
84
+ outputs = model(inputs)
85
+ _, preds = torch.max(outputs, 1)
86
+ loss = criterion(outputs, labels)
87
+
88
+ if phase == 'train':
89
+ loss.backward()
90
+ optimizer.step()
91
+
92
+ running_loss += loss.item() * inputs.size(0)
93
+ running_corrects += torch.sum(preds == labels.data)
94
+
95
+ if phase == 'train':
96
+ scheduler.step()
97
+
98
+ epoch_loss = running_loss / dataset_sizes[phase]
99
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
100
+
101
+ print('{} Loss: {:.4f} Acc: {:.4f}'.format(
102
+ phase, epoch_loss, epoch_acc))
103
+
104
+ if phase == 'val' and epoch_acc > best_acc:
105
+ best_acc = epoch_acc
106
+ best_model_wts = copy.deepcopy(model.state_dict())
107
+
108
+ print()
109
+
110
+ time_elapsed = time.time() - since
111
+ print('Training complete in {:.0f}m {:.0f}s'.format(
112
+ time_elapsed // 60, time_elapsed % 60))
113
+ print('Best val Acc: {:4f}'.format(best_acc))
114
+
115
+ model.load_state_dict(best_model_wts)
116
+ return model
117
+
118
+ # Train the model
119
+ model_ft = train_model(model_ft, criterion, optimizer_ft, scheduler, num_epochs=25)
120
+
121
+ # Save the trained model
122
+ torch.save(model_ft.state_dict(), 'model_ft.pth')
123
+
124
+ # Streamlit Interface
125
+ st.title("Image Classification with Fine-tuned ResNet")
126
+
127
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
128
+
129
+ if uploaded_file is not None:
130
+ image = Image.open(uploaded_file)
131
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
132
+ st.write("")
133
+ st.write("Classifying...")
134
+
135
+ model_ft = models.resnet18(pretrained=True)
136
+ num_ftrs = model_ft.fc.in_features
137
+ model_ft.fc = nn.Linear(num_ftrs, len(class_names))
138
+ model_ft.load_state_dict(torch.load('model_ft.pth'))
139
+ model_ft = model_ft.to(device)
140
+ model_ft.eval()
141
+
142
+ preprocess = T.Compose([
143
+ T.Resize(256),
144
+ T.CenterCrop(224),
145
+ T.ToTensor(),
146
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
147
+ ])
148
+
149
+ img = preprocess(image).unsqueeze(0)
150
+ img = img.to(device)
151
+
152
+ with torch.no_grad():
153
+ outputs = model_ft(img)
154
+ _, preds = torch.max(outputs, 1)
155
+ predicted_class = class_names[preds[0]]
156
+
157
+ st.write(f"Predicted Class: {predicted_class}")
158
+
159
+ # Plotting the image with matplotlib
160
+ fig, ax = plt.subplots()
161
+ ax.imshow(image)
162
+ ax.set_title(f"Predicted: {predicted_class}")
163
+ st.pyplot(fig)