avanish07 committed on
Commit
a33f382
·
1 Parent(s): f5297ff

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +103 -0
  2. model.pt +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import gradio as gr
3
+ import scipy.io as io
4
+ import PIL.Image as Image
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ import scipy
8
+ import json
9
+ from matplotlib import cm as CM
10
+ import torch.nn as nn
11
+ import torch
12
+ from torchvision import models
13
+
14
+
15
class CSRNet(nn.Module):
    """CSRNet: a VGG-16 style convolutional frontend followed by a dilated backend.

    The network maps a 3-channel image batch to a single-channel output map.
    The frontend contains three stride-2 max-pool layers and every 3x3
    convolution is padded to preserve spatial size, so the output map is
    roughly 1/8 of the input resolution in each dimension.

    Args:
        load_weights: When ``False`` (default) the convolutions are randomly
            initialised and the frontend is then filled with torchvision's
            pretrained VGG-16 convolution weights (this triggers a download on
            first use). When ``True`` nothing is initialised here; the caller
            is expected to load a full checkpoint with ``load_state_dict``.
    """

    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        # Layer specs consumed by make_layers: ints are conv output channels,
        # 'M' is a 2x2 max-pool.
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            # NOTE(review): `pretrained=True` is the legacy torchvision API and
            # emits a deprecation warning on recent releases; kept for
            # compatibility with older torchvision versions.
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            self._copy_vgg_frontend(mod)

    def _copy_vgg_frontend(self, vgg):
        """Copy the matching convolution weights of ``vgg.features`` into the frontend.

        torchvision names the VGG parameters ``features.<idx>.weight`` whereas
        the frontend ``nn.Sequential`` uses ``<idx>.weight``, so the prefix has
        to be stripped before the keys can be matched. The copy goes through
        ``load_state_dict`` because the tensors returned by ``state_dict()``
        are detached views: rebinding their ``.data`` does not touch the
        module's parameters.
        """
        prefix = "features."
        frontend_keys = self.frontend.state_dict().keys()
        pretrained = {}
        for key, tensor in vgg.state_dict().items():
            if not key.startswith(prefix):
                continue
            short_key = key[len(prefix):]
            if short_key in frontend_keys:
                pretrained[short_key] = tensor
        self.frontend.load_state_dict(pretrained)

    def forward(self, x):
        """Run frontend, backend and the 1x1 output convolution on ``x``."""
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Initialise conv weights from N(0, 0.01) and batch-norm layers to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
47
+
48
+
49
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` stack from a layer specification.

    Args:
        cfg: Iterable of layer specs. The string ``'M'`` adds a 2x2 max-pool
            with stride 2; any other entry is taken as the output channel
            count of a 3x3 convolution followed by ReLU.
        in_channels: Channel count of the tensor fed to the first convolution.
        batch_norm: If true, insert ``BatchNorm2d`` between each convolution
            and its ReLU.
        dilation: If true, convolutions use dilation 2 (with padding 2);
            otherwise dilation 1 (with padding 1). Either way the padding
            equals the dilation, which preserves spatial size for a 3x3 kernel.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, spec, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = spec
    return nn.Sequential(*modules)
66
+
67
+
68
# Load the CSRNet model.
# load_weights=True skips the random initialisation and the pretrained VGG-16
# download inside CSRNet.__init__: the strict load_state_dict call below
# overwrites every parameter anyway, so that work would be thrown away.
csrmodel = CSRNet(load_weights=True)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored
# on a CPU-only host instead of raising at deserialisation time.
checkpoint = torch.load("model.pt", map_location="cpu")
csrmodel.load_state_dict(checkpoint)
# Inference only: switch layers such as batch-norm/dropout to eval behaviour.
csrmodel.eval()

# Set the transformation for image preprocessing:
# array/tensor -> PIL image -> fixed 256x256 -> float tensor -> per-channel
# normalisation with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
81
+
82
+ # Define the prediction function
83
def predict_count(input_image):
    """Return the count predicted by the CSRNet model for one image.

    Args:
        input_image: The image handed over by the Gradio image input. It is
            passed straight to ``transform``, whose first step is
            ``ToPILImage`` (presumably an HxWx3 uint8 array - confirm against
            the Gradio component configuration).

    Returns:
        int: The sum of the model's output map, truncated toward zero.
    """
    # Preprocess the input image and add the batch dimension.
    image = transform(input_image).unsqueeze(0)

    # Perform the forward pass without autograd bookkeeping: no gradients are
    # needed for inference, and skipping the graph saves memory per request.
    with torch.no_grad():
        output = csrmodel(image)

    # Calculate the predicted count.
    predicted_count = int(output.sum().item())

    return predicted_count
94
+
95
# Define the input and output interfaces for Gradio.
# The legacy gr.inputs / gr.outputs namespaces are deprecated and are no longer
# available in recent Gradio releases; the top-level components replace them.
# type="numpy" is spelled out because the preprocessing pipeline starts with
# ToPILImage, which expects an array or tensor.
input_interface = gr.Image(type="numpy")
output_interface = gr.Textbox()

# Create the Gradio app
grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)

# Launch the app only when executed as a script, so that importing this module
# (e.g. for testing) does not start a web server.
if __name__ == "__main__":
    grapp.launch()
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8657ef16df4513ea38577f5e8e82f587dfe23e98f76656c63e2b3536c892766a
3
+ size 65059836
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ h5py
2
+ scipy
3
+ Pillow
4
+ numpy
5
+ matplotlib
6
+ torch
7
+ torchvision