ahmedghani commited on
Commit
18fc351
1 Parent(s): e8f5478

adding model files

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Image Classification On CIFAR10
3
- emoji: 🔥
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
 
1
  ---
2
  title: Image Classification On CIFAR10
3
+ emoji: 📷
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
Resnet101.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+
7
+ class BasicBlock(nn.Module):
8
+ expansion = 1
9
+
10
+ def __init__(self, in_planes, planes, stride=1):
11
+ super(BasicBlock, self).__init__()
12
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
13
+ self.bn1 = nn.BatchNorm2d(planes)
14
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
15
+ self.bn2 = nn.BatchNorm2d(planes)
16
+
17
+ self.shortcut = nn.Sequential()
18
+ if stride != 1 or in_planes != self.expansion*planes:
19
+ self.shortcut = nn.Sequential(
20
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
21
+ nn.BatchNorm2d(self.expansion*planes)
22
+ )
23
+
24
+ def forward(self, x):
25
+ out = F.relu(self.bn1(self.conv1(x)))
26
+ out = self.bn2(self.conv2(out))
27
+ out += self.shortcut(x)
28
+ out = F.relu(out)
29
+ return out
30
+
31
+
32
+ class Bottleneck(nn.Module):
33
+ expansion = 4
34
+
35
+ def __init__(self, in_planes, planes, stride=1):
36
+ super(Bottleneck, self).__init__()
37
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
38
+ self.bn1 = nn.BatchNorm2d(planes)
39
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
40
+ self.bn2 = nn.BatchNorm2d(planes)
41
+ self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
42
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
43
+
44
+ self.shortcut = nn.Sequential()
45
+ if stride != 1 or in_planes != self.expansion*planes:
46
+ self.shortcut = nn.Sequential(
47
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
48
+ nn.BatchNorm2d(self.expansion*planes)
49
+ )
50
+
51
+ def forward(self, x):
52
+ out = F.relu(self.bn1(self.conv1(x)))
53
+ out = F.relu(self.bn2(self.conv2(out)))
54
+ out = self.bn3(self.conv3(out))
55
+ out += self.shortcut(x)
56
+ out = F.relu(out)
57
+ return out
58
+
59
+
60
+ class ResNet(nn.Module):
61
+ def __init__(self, block, num_blocks, num_classes=10):
62
+ super(ResNet, self).__init__()
63
+ self.in_planes = 64
64
+
65
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
66
+ self.bn1 = nn.BatchNorm2d(64)
67
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
68
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
69
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
70
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
71
+ self.linear = nn.Linear(512*block.expansion, num_classes)
72
+
73
+ if block == BasicBlock:
74
+ self.name = "resnet" + str(sum(num_blocks) * 2 + 2)
75
+ else:
76
+ self.name = "resnet" + str(sum(num_blocks) * 3 + 2)
77
+
78
+ def _make_layer(self, block, planes, num_blocks, stride):
79
+ strides = [stride] + [1]*(num_blocks-1)
80
+ layers = []
81
+ for stride in strides:
82
+ layers.append(block(self.in_planes, planes, stride))
83
+ self.in_planes = planes * block.expansion
84
+ return nn.Sequential(*layers)
85
+
86
+ def forward(self, x):
87
+ out = F.relu(self.bn1(self.conv1(x)))
88
+ out = self.layer1(out)
89
+ out = self.layer2(out)
90
+ out = self.layer3(out)
91
+ out = self.layer4(out)
92
+ out = F.avg_pool2d(out, 4)
93
+ out = out.view(out.size(0), -1)
94
+ out = self.linear(out)
95
+ return out
96
+
97
+ def ResNet101():
98
+ return ResNet(Bottleneck, [3,4,23,3])
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Resnet101 import *
2
+ import gradio as gr
3
+ from PIL import Image
4
+
5
+ print("Loading Resnet101 model...")
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ model = torch.load("resnet101_ckpt.pth", map_location=device)
8
+ net = ResNet101()
9
+ net.to(device)
10
+ net = torch.nn.DataParallel(net)
11
+ net.load_state_dict(model['net'])
12
+
13
+ print("Model loaded")
14
+ print("Device: ", device)
15
+
16
+ # Define a transform to convert the image to tensor
17
+ transform = transforms.Compose([
18
+ transforms.Resize([32, 32]),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
21
+ ])
22
+
23
+ def predict_image(image):
24
+
25
+ # Convert the image to PyTorch tensor
26
+ img_tensor = transform(Image.fromarray(image))
27
+ img_tensor.to(device)
28
+ with torch.no_grad():
29
+ outputs = net(img_tensor[None, ...])
30
+ _, predicted = outputs.max(1)
31
+ classes = ['plane', 'car', 'bird', 'cat', 'deer',
32
+ 'dog', 'frog', 'horse', 'ship', 'truck']
33
+ res = classes[predicted[0].item()]
34
+ print("Predicted class: ", res)
35
+ if res == 'car':
36
+ return Image.open("samples/car2.jpeg"), Image.open("samples/car3.jpg"), Image.open("samples/car4.jpg"), Image.open("samples/car5.jpg")
37
+ elif res == 'cat':
38
+ return Image.open("samples/cat2.jpg"), Image.open("samples/cat3.jpeg"), Image.open("samples/cat4.png"), Image.open("samples/cat5.jpg")
39
+ elif res == 'dog':
40
+ return Image.open("samples/dog2.jpg"), Image.open("samples/dog3.jpg"), Image.open("samples/dog4.jpg"), Image.open("samples/dog5.jpg")
41
+ elif res == 'horse':
42
+ return Image.open("samples/horse2.jpg"), Image.open("samples/horse3.jpeg"), Image.open("samples/horse4.jpg"), Image.open("samples/horse5.jpg")
43
+ else:
44
+ return Image.open("samples/not-found.jpg"), Image.open("samples/not-found.jpg"), Image.open("samples/not-found.jpg"), Image.open("samples/not-found.jpg")
45
+
46
+ def set_example_image(example: list) -> dict:
47
+ return gr.Image.update(value=example[0])
48
+
49
+ demo = gr.Blocks()
50
+ with demo:
51
+ gr.Markdown('''
52
+ <center>
53
+ <h1>Image Classification trained on Resnet101</h1>
54
+ <p>
55
+ Image classification model trained on Resnet101. The dataset used is the CIFAR-10 dataset.
56
+ It will detect 4 classes of images: car, cat, dog and horse. Then it will show you 4 images of the same class.
57
+ </p>
58
+ </center>
59
+ ''')
60
+
61
+ with gr.Row():
62
+ input_image = gr.Image(label="Input image")
63
+ with gr.Row():
64
+ output_imgs = [gr.Image(label='Closest Image 1', type='numpy', interactive=False),
65
+ gr.Image(label='Closest Image 2', type='numpy', interactive=False),
66
+ gr.Image(label='Closest Image 3', type='numpy', interactive=False),
67
+ gr.Image(label='Closest Image 4', type='numpy', interactive=False)]
68
+ button = gr.Button("Classifier")
69
+ with gr.Row():
70
+ example_images = gr.Dataset(components=[input_image],
71
+ samples=[["samples/cat1.jpg"], ["samples/car1.jpg"], ["samples/dog1.jpeg"], ["samples/horse1.jpg"]])
72
+ example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)
73
+ button.click(predict_image, inputs=input_image, outputs=output_imgs)
74
+
75
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ opencv-python
5
+ gradio
resnet101_ckpt.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57f0c7486996c89c17d88249ab1a3620da4affb29f9dfa86917bf96028a8b0bc
3
+ size 170593359
samples/car1.jpg ADDED
samples/car2.jpeg ADDED
samples/car3.jpg ADDED
samples/car4.jpg ADDED
samples/car5.jpg ADDED
samples/cat1.jpg ADDED
samples/cat2.jpg ADDED
samples/cat3.jpeg ADDED
samples/cat4.png ADDED
samples/cat5.jpg ADDED
samples/dog1.jpeg ADDED
samples/dog2.jpg ADDED
samples/dog3.jpg ADDED
samples/dog4.jpg ADDED
samples/dog5.jpg ADDED
samples/horse1.jpg ADDED
samples/horse2.jpg ADDED
samples/horse3.jpeg ADDED
samples/horse4.jpg ADDED
samples/horse5.jpg ADDED
samples/not-found.jpg ADDED
train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''Train CIFAR10 with PyTorch.'''
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torch.nn.functional as F
6
+ import torch.backends.cudnn as cudnn
7
+
8
+ import torchvision
9
+ import torchvision.transforms as transforms
10
+
11
+ import os
12
+ from Resnet101 import *
13
+
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ best_acc = 0 # best test accuracy
16
+ start_epoch = 0 # start from epoch 0 or last checkpoint epoch
17
+ end_epoch = 300
18
+ resume = False
19
+
20
+ # Data
21
+ print('==> Preparing data..')
22
+ transform_train = transforms.Compose([
23
+ transforms.RandomCrop(32, padding=4),
24
+ transforms.RandomHorizontalFlip(),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
27
+ ])
28
+
29
+ transform_test = transforms.Compose([
30
+ transforms.ToTensor(),
31
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
32
+ ])
33
+
34
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
35
+ trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
36
+
37
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
38
+ testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
39
+
40
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
41
+
42
+ # Model
43
+ print('==> Building model..')
44
+ net = ResNet101()
45
+ net_name = net.name
46
+ save_path = './checkpoint/{0}_ckpt.pth'.format(net.name)
47
+ net = net.to(device)
48
+ if device == 'cuda':
49
+ net = torch.nn.DataParallel(net)
50
+ cudnn.benchmark = True
51
+
52
+ if resume:
53
+ # Load best checkpoint trained last time.
54
+ print('==> Resuming from checkpoint..')
55
+ assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
56
+ checkpoint = torch.load(save_path)
57
+ net.load_state_dict(checkpoint['net'])
58
+ best_acc = checkpoint['acc']
59
+ start_epoch = checkpoint['epoch']
60
+
61
+ criterion = nn.CrossEntropyLoss()
62
+ optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
63
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=70, gamma=0.1)
64
+
65
+ # Training
66
+ def train(epoch):
67
+ print('\nEpoch: %d' % epoch)
68
+ net.train()
69
+ train_loss = 0
70
+ correct = 0
71
+ total = 0
72
+ for batch_idx, (inputs, targets) in enumerate(trainloader):
73
+ inputs, targets = inputs.to(device), targets.to(device)
74
+ optimizer.zero_grad()
75
+ outputs = net(inputs)
76
+ loss = criterion(outputs, targets)
77
+ loss.backward()
78
+ optimizer.step()
79
+
80
+ train_loss += loss.item()
81
+ _, predicted = outputs.max(1)
82
+ total += targets.size(0)
83
+ correct += predicted.eq(targets).sum().item()
84
+ print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
85
+
86
+ def test(epoch):
87
+ global best_acc
88
+ net.eval()
89
+ test_loss = 0
90
+ correct = 0
91
+ total = 0
92
+ with torch.no_grad():
93
+ for batch_idx, (inputs, targets) in enumerate(testloader):
94
+ inputs, targets = inputs.to(device), targets.to(device)
95
+ outputs = net(inputs)
96
+ loss = criterion(outputs, targets)
97
+
98
+ test_loss += loss.item()
99
+ _, predicted = outputs.max(1)
100
+ total += targets.size(0)
101
+ correct += predicted.eq(targets).sum().item()
102
+
103
+ print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
104
+
105
+ # Save checkpoint.
106
+ acc = 100.*correct/total
107
+ if acc > best_acc:
108
+ print('Saving ' + net_name + ' ..')
109
+ state = {
110
+ 'net': net.state_dict(),
111
+ 'acc': acc,
112
+ 'epoch': epoch,
113
+ }
114
+ if not os.path.isdir('checkpoint'):
115
+ os.mkdir('checkpoint')
116
+ torch.save(state, save_path)
117
+ best_acc = acc
118
+
119
+
120
+ for epoch in range(start_epoch, end_epoch):
121
+ train(epoch)
122
+ test(epoch)
123
+ scheduler.step()
124
+
125
+ print("\nTesting best accuracy:", best_acc)