Spaces:
Build error
Build error
ahmedghani
committed on
Commit
•
18fc351
1
Parent(s):
e8f5478
adding model files
Browse files- README.md +1 -1
- Resnet101.py +98 -0
- app.py +75 -0
- requirements.txt +5 -0
- resnet101_ckpt.pth +3 -0
- samples/car1.jpg +0 -0
- samples/car2.jpeg +0 -0
- samples/car3.jpg +0 -0
- samples/car4.jpg +0 -0
- samples/car5.jpg +0 -0
- samples/cat1.jpg +0 -0
- samples/cat2.jpg +0 -0
- samples/cat3.jpeg +0 -0
- samples/cat4.png +0 -0
- samples/cat5.jpg +0 -0
- samples/dog1.jpeg +0 -0
- samples/dog2.jpg +0 -0
- samples/dog3.jpg +0 -0
- samples/dog4.jpg +0 -0
- samples/dog5.jpg +0 -0
- samples/horse1.jpg +0 -0
- samples/horse2.jpg +0 -0
- samples/horse3.jpeg +0 -0
- samples/horse4.jpg +0 -0
- samples/horse5.jpg +0 -0
- samples/not-found.jpg +0 -0
- train.py +125 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Image Classification On CIFAR10
|
3 |
-
emoji:
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Image Classification On CIFAR10
|
3 |
+
emoji: 📷
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
Resnet101.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
|
7 |
+
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # The skip path is the identity unless the main path changes the
        # spatial size or the channel count; then a 1x1 conv + BN projects
        # the input to the matching shape.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        features = F.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        features = features + self.shortcut(x)
        return F.relu(features)
|
30 |
+
|
31 |
+
|
32 |
+
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    Args:
        in_planes: Number of input channels.
        planes: Width of the first two convolutions; the block outputs
            ``expansion * planes`` channels.
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity skip unless the main path changes resolution or width;
        # in that case a 1x1 conv + BN projects the input to match.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        features = F.relu(self.bn1(self.conv1(x)))
        features = F.relu(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))
        features = features + self.shortcut(x)
        return F.relu(features)
|
58 |
+
|
59 |
+
|
60 |
+
class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem, four residual stages and a linear head.

    Args:
        block: Residual block class. It is instantiated as
            ``block(in_planes, planes, stride)`` and must expose an integer
            ``expansion`` class attribute.
        num_blocks: Sequence of four ints giving the number of blocks in
            each stage.
        num_classes: Number of output logits.

    Attributes:
        name: ``"resnet<depth>"``, where depth is ``sum(num_blocks)`` times
            2 (for ``BasicBlock``) or 3 (for any other block), plus 2.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        convs_per_block = 2 if block == BasicBlock else 3
        self.name = "resnet" + str(sum(num_blocks) * convs_per_block + 2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the rest use stride 1.
        ``self.in_planes`` is advanced to the stage's output width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. A fixed 4x4 window only reduces the map to
        # 1x1 for 32x32 inputs; adaptive pooling gives the same result there
        # and keeps the feature size correct for every other resolution.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        return self.linear(out)
|
96 |
+
|
97 |
+
def ResNet101():
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3, 4, 23, 3."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths)
|
app.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Resnet101 import *
|
2 |
+
import gradio as gr
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
print("Loading Resnet101 model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `model` is the checkpoint dict; its 'net' entry is the state dict to load.
model = torch.load("resnet101_ckpt.pth", map_location=device)
net = ResNet101()
net.to(device)
# The network is wrapped before loading.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model,
# so its keys carry the 'module.' prefix -- confirm against the training run.
net = torch.nn.DataParallel(net)
net.load_state_dict(model['net'])
# This process only serves predictions: switch BatchNorm to its stored running
# statistics instead of per-batch statistics (which would also be overwritten
# on every request while the module is in training mode).
net.eval()

print("Model loaded")
print("Device: ", device)

# Preprocessing applied to every uploaded image: resize to 32x32, convert to a
# tensor, then normalize each channel with the given mean / std.
transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
|
22 |
+
|
23 |
+
def predict_image(image):
    """Classify ``image`` and return four sample pictures of the predicted class.

    Args:
        image: The uploaded picture as a NumPy array (anything accepted by
            ``PIL.Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        A tuple of four ``PIL.Image`` objects. For the classes 'car', 'cat',
        'dog' and 'horse' these are the bundled samples of that class; for
        every other prediction, or a missing input, the "not found"
        placeholder is returned four times.
    """
    fallback = ("samples/not-found.jpg",) * 4
    sample_paths = {
        'car': ("samples/car2.jpeg", "samples/car3.jpg", "samples/car4.jpg", "samples/car5.jpg"),
        'cat': ("samples/cat2.jpg", "samples/cat3.jpeg", "samples/cat4.png", "samples/cat5.jpg"),
        'dog': ("samples/dog2.jpg", "samples/dog3.jpg", "samples/dog4.jpg", "samples/dog5.jpg"),
        'horse': ("samples/horse2.jpg", "samples/horse3.jpeg", "samples/horse4.jpg", "samples/horse5.jpg"),
    }

    # The button can be clicked before any image is provided.
    if image is None:
        return tuple(Image.open(path) for path in fallback)

    # Inference must run in eval mode so BatchNorm uses its running statistics.
    if net.training:
        net.eval()

    # Force three channels: the normalization step is defined for RGB only.
    img_tensor = transform(Image.fromarray(image).convert("RGB"))
    # Tensor.to() returns a new tensor rather than moving in place, so the
    # result has to be kept. unsqueeze(0) adds the batch dimension.
    batch = img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = net(batch)
    _, predicted = outputs.max(1)

    classes = ['plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    res = classes[predicted[0].item()]
    print("Predicted class: ", res)

    return tuple(Image.open(path) for path in sample_paths.get(res, fallback))
|
45 |
+
|
46 |
+
def set_example_image(example: list) -> dict:
    """Return an update that sets the target component's value to ``example[0]``.

    Args:
        example: The clicked dataset row; its first element is used as the
            new value.

    Returns:
        The update dictionary produced by ``gr.update``.
    """
    # gr.update is the component-agnostic form of the per-component
    # ``gr.Image.update`` helper, which newer Gradio releases no longer ship.
    return gr.update(value=example[0])
|
48 |
+
|
49 |
+
# Gradio UI: one input image, four result images, a trigger button and a row
# of clickable example pictures.
with gr.Blocks() as demo:
    gr.Markdown('''
<center>
<h1>Image Classification trained on Resnet101</h1>
<p>
Image classification model trained on Resnet101. The dataset used is the CIFAR-10 dataset.
It will detect 4 classes of images: car, cat, dog and horse. Then it will show you 4 images of the same class.
</p>
</center>
''')

    with gr.Row():
        input_image = gr.Image(label="Input image")

    with gr.Row():
        # Four read-only slots, labelled 'Closest Image 1' .. 'Closest Image 4'.
        output_imgs = [
            gr.Image(label='Closest Image %d' % slot, type='numpy', interactive=False)
            for slot in range(1, 5)
        ]

    button = gr.Button("Classifier")

    with gr.Row():
        example_images = gr.Dataset(
            components=[input_image],
            samples=[
                ["samples/cat1.jpg"],
                ["samples/car1.jpg"],
                ["samples/dog1.jpeg"],
                ["samples/horse1.jpg"],
            ],
        )

    # Clicking an example copies it into the input image; the button runs the
    # classifier and fills the four output slots.
    example_images.click(
        fn=set_example_image,
        inputs=example_images,
        outputs=example_images.components,
    )
    button.click(predict_image, inputs=input_image, outputs=output_imgs)

demo.launch(debug=True)
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
torchaudio
|
4 |
+
opencv-python
|
5 |
+
gradio
|
resnet101_ckpt.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57f0c7486996c89c17d88249ab1a3620da4affb29f9dfa86917bf96028a8b0bc
|
3 |
+
size 170593359
|
samples/car1.jpg
ADDED
samples/car2.jpeg
ADDED
samples/car3.jpg
ADDED
samples/car4.jpg
ADDED
samples/car5.jpg
ADDED
samples/cat1.jpg
ADDED
samples/cat2.jpg
ADDED
samples/cat3.jpeg
ADDED
samples/cat4.png
ADDED
samples/cat5.jpg
ADDED
samples/dog1.jpeg
ADDED
samples/dog2.jpg
ADDED
samples/dog3.jpg
ADDED
samples/dog4.jpg
ADDED
samples/dog5.jpg
ADDED
samples/horse1.jpg
ADDED
samples/horse2.jpg
ADDED
samples/horse3.jpeg
ADDED
samples/horse4.jpg
ADDED
samples/horse5.jpg
ADDED
samples/not-found.jpg
ADDED
train.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Train CIFAR10 with PyTorch.'''
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.backends.cudnn as cudnn
|
7 |
+
|
8 |
+
import torchvision
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
|
11 |
+
import os
|
12 |
+
from Resnet101 import *
|
13 |
+
|
14 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
end_epoch = 300
resume = False

# Data
print('==> Preparing data..')
# Training images are augmented with a padded random crop and a random
# horizontal flip before normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test images are only converted and normalized.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = ResNet101()
# Read the name before the optional DataParallel wrap below, which does not
# expose the wrapped model's custom attributes directly.
net_name = net.name
save_path = './checkpoint/{0}_ckpt.pth'.format(net.name)
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if resume:
    # Load best checkpoint trained last time.
    print('==> Resuming from checkpoint..')
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if not os.path.isdir('checkpoint'):
        raise FileNotFoundError('Error: no checkpoint directory found!')
    # map_location remaps tensors that were saved from a different device.
    checkpoint = torch.load(save_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# The learning rate is multiplied by 0.1 every 70 scheduler steps.
# NOTE(review): optimizer and scheduler are always created fresh, so a resumed
# run restarts the learning-rate schedule from lr=0.1 -- confirm this is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=70, gamma=0.1)
|
64 |
+
|
65 |
+
# Training
|
66 |
+
def train(epoch):
    """Run one optimisation pass over ``trainloader`` and print a summary.

    Args:
        epoch: Epoch index; only used in the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    running_loss, n_correct, n_seen = 0, 0, 0
    for step, (images, labels) in enumerate(trainloader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # max(1) yields (values, indices); the indices are the predicted classes.
        guesses = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += guesses.eq(labels).sum().item()

    # NOTE(review): summary is printed once per epoch -- confirm it was not
    # meant to be emitted after every batch.
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (running_loss/step, 100.*n_correct/n_seen, n_correct, n_seen))
|
85 |
+
|
86 |
+
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the model when accuracy improves.

    Prints the mean loss and accuracy. If the accuracy beats the module-level
    ``best_acc``, the model state, accuracy and epoch are saved to
    ``save_path`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch index stored alongside the checkpoint.
    """
    global best_acc
    net.eval()

    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader, start=1):
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)

            running_loss += criterion(logits, labels).item()
            # max(1) yields (values, indices); the indices are the predicted classes.
            guesses = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += guesses.eq(labels).sum().item()

    # NOTE(review): summary is printed once per evaluation -- confirm it was
    # not meant to be emitted after every batch.
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (running_loss/step, 100.*n_correct/n_seen, n_correct, n_seen))

    # Save checkpoint.
    acc = 100.*n_correct/n_seen
    if not acc > best_acc:
        return

    print('Saving ' + net_name + ' ..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, save_path)
    best_acc = acc
|
118 |
+
|
119 |
+
|
120 |
+
# The data loaders use worker processes (num_workers=2). On platforms that
# start workers by re-importing this script, the loop must sit behind the
# __main__ guard so that workers do not start training themselves.
if __name__ == '__main__':
    for epoch in range(start_epoch, end_epoch):
        train(epoch)
        test(epoch)
        # One scheduler step per epoch, after the optimizer has been stepped.
        scheduler.step()

    print("\nTesting best accuracy:", best_acc)
|