File size: 3,521 Bytes
a33f382
 
 
607ec5d
a33f382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f326c04
963d428
a33f382
 
 
 
 
 
 
 
 
 
 
 
 
f326c04
a33f382
 
3b4985d
 
6bd3919
3b4985d
5f77b3b
 
6bd3919
 
 
 
 
a33f382
6bd3919
 
a33f382
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import h5py
import gradio as gr
import scipy.io as io
import matplotlib.pyplot as plt
import PIL.Image as Image
import numpy as np
from torchvision import transforms
import scipy
import json
from matplotlib import cm as CM
import torch.nn as nn
import torch
from torchvision import models


class CSRNet(nn.Module):
    """CSRNet density-map regressor.

    The network is three sequential stages:

    * ``frontend`` - ten 3x3 convolutions with three 2x2 max-pools, laid out
      like the first ten convolutions of VGG-16;
    * ``backend`` - six 3x3 convolutions with dilation 2;
    * ``output_layer`` - a 1x1 convolution reducing 64 channels to a single
      density channel.

    Because of the three stride-2 pools, the output map is 1/8 of the input
    height and width (rounded down).

    Args:
        load_weights: When ``False`` (the default) the convolutions are
            randomly initialised and the frontend is then overwritten with
            torchvision's pretrained VGG-16 weights. When ``True`` nothing is
            initialised here; the caller is expected to load a checkpoint
            with ``load_state_dict``.
    """

    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        # Counter kept for checkpoint compatibility; it is not read in this class.
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # The full VGG state dict prefixes its keys ("features.0.weight"),
            # whereas the frontend Sequential uses bare indices ("0.weight").
            # Reading the state dict of ``mod.features`` gives matching keys,
            # and ``load_state_dict`` copies the values into the real
            # parameters (assigning ``.data`` on a state-dict tensor would
            # only rebind a detached view and leave the model unchanged).
            vgg_features = mod.features.state_dict()
            self.frontend.load_state_dict(
                {name: vgg_features[name] for name in self.frontend.state_dict()}
            )

    def forward(self, x):
        """Return the single-channel density map predicted for batch ``x``."""
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Initialise conv weights from N(0, 0.01), biases to 0, batch-norm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Assemble a VGG-style convolutional stack from a layer specification.

    Args:
        cfg: Sequence describing the stack. The string ``'M'`` inserts a 2x2
            max-pool with stride 2; any other entry is the number of output
            channels of a 3x3 convolution followed by an in-place ReLU.
        in_channels: Number of channels fed to the first convolution.
        batch_norm: If true, a ``BatchNorm2d`` is placed between each
            convolution and its ReLU.
        dilation: If true, every convolution uses dilation 2 (and padding 2);
            otherwise dilation 1 (and padding 1).

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    # Padding equals the dilation rate so a 3x3 kernel keeps the spatial size.
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, spec, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = spec
    return nn.Sequential(*modules)


# Load the CSRNet model
# load_weights=True skips the random initialisation and the VGG-16 download in
# CSRNet.__init__; every parameter comes from the checkpoint loaded below.
csrmodel = CSRNet(load_weights=True).cpu()
# "model.pt" is resolved relative to the current working directory and is
# mapped onto the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file, so it must only ever point at a
# trusted checkpoint. The file is passed straight to load_state_dict, so it is
# presumably a bare state dict rather than a wrapper dict - confirm.
checkpoint = torch.load("model.pt", map_location=torch.device('cpu'))
csrmodel.load_state_dict(checkpoint)
# Switch to evaluation mode for inference.
csrmodel.eval()

# Set the transformation for image preprocessing
# Pipeline: array/tensor -> PIL image -> fixed 256x256 resize -> float tensor
# -> per-channel normalisation. The three-element mean/std imply a 3-channel
# input; the values look like the usual ImageNet statistics - TODO confirm
# they match what the checkpoint was trained with.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the prediction function
def predict_count(input_image):
    """Estimate the crowd count in an image and render its density map.

    Args:
        input_image: Image accepted by the module-level ``transform`` pipeline
            (whose first step is ``ToPILImage``).

    Returns:
        A ``(predicted_count, density_map_color)`` tuple. ``predicted_count``
        is the sum of the predicted density map truncated to an ``int``.
        ``density_map_color`` is the density map scaled by its maximum and
        passed through matplotlib's ``jet`` colormap.
    """
    image = transform(input_image).unsqueeze(0).cpu()

    # Inference only: disabling autograd avoids building a graph that is never
    # used and keeps memory flat across requests.
    with torch.no_grad():
        output = csrmodel(image).cpu()

    predicted_count = int(output.sum().numpy())
    density_map = output.numpy().reshape(output.shape[2], output.shape[3])

    # Scale by the peak only when there is a positive peak; dividing by a zero
    # maximum would fill the map with NaN/inf and produce a garbage image.
    peak = density_map.max()
    if peak > 0:
        scaled_map = density_map / peak
    else:
        scaled_map = np.zeros_like(density_map)
    density_map_color = plt.cm.jet(scaled_map)

    return predicted_count, density_map_color
# Gradio components: one uploaded image in, a count label and a heat-map out.
# The components are referenced through the ``gr`` namespace so that the name
# ``Image`` keeps referring to ``PIL.Image`` imported at the top of the file,
# and so that input and output use the same (non-deprecated) component API
# instead of mixing in the legacy ``gr.inputs`` namespace.
input_interface = gr.Image(label="Input Image")
output_interface = [
    gr.Label(label="Predicted Count"),
    gr.Image(label="Density Map", type="numpy"),
]

# Create the Gradio app with both interfaces
grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)

# Launch the app
grapp.launch()