# Crowd_Count / app.py
# Last commit by avanish07: 5f77b3b "Update app.py"
import h5py
import gradio as gr
import scipy.io as io
import matplotlib.pyplot as plt
import PIL.Image as Image
import numpy as np
from torchvision import transforms
import scipy
import json
from matplotlib import cm as CM
import torch.nn as nn
import torch
from torchvision import models
class CSRNet(nn.Module):
    """CSRNet crowd-counting network: VGG16-style frontend + dilated backend.

    The frontend is a VGG16-style stack of 10 3x3 convs with 3 2x2 max-pools.
    The backend is 6 dilation-2 3x3 convs. A final 1x1 conv produces a
    single-channel density map.

    Args:
        load_weights: When False (the default), initialise every conv with
            N(0, 0.01) and then copy torchvision's pretrained VGG16 weights
            into the frontend. Pass True when the full model state will be
            restored from a checkpoint afterwards; this skips both the VGG16
            download and the custom initialisation.
    """

    def __init__(self, load_weights=False):
        super().__init__()
        self.seen = 0  # not read anywhere in this file
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            self._load_vgg_frontend(mod)

    def forward(self, x):
        """Map a batch of 3-channel images to 1-channel density maps.

        The three max-pools in the frontend reduce spatial size by ~8x.
        """
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _load_vgg_frontend(self, vgg):
        """Copy the matching prefix of ``vgg.features`` into ``self.frontend``.

        Both Sequentials are built from the same layer pattern (conv, ReLU per
        channel entry; MaxPool per 'M'), so their state-dict keys
        ("0.weight", "2.bias", ...) line up index-for-index. The previous code
        compared against the full VGG state dict, whose keys carry a
        "features." prefix. Nothing ever matched, so no weights were copied.
        """
        vgg_state = vgg.features.state_dict()
        frontend_state = {key: vgg_state[key] for key in self.frontend.state_dict()}
        # load_state_dict copies in place into the real parameters; assigning
        # `.data` on tensors returned by state_dict() would not.
        self.frontend.load_state_dict(frontend_state)

    def _initialize_weights(self):
        """Init convs with N(0, 0.01) weights / zero bias; BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` of 3x3 conv blocks from a layer config.

    Args:
        cfg: Sequence whose entries are either an int (output channels of a
            new 3x3 conv followed by ReLU) or the string 'M' (a 2x2, stride-2
            max-pool).
        in_channels: Channel count fed into the first conv.
        batch_norm: Insert ``BatchNorm2d`` between each conv and its ReLU.
        dilation: Use dilation 2 instead of 1. Padding equals the dilation
            rate, so each conv preserves spatial size.

    Returns:
        The assembled ``nn.Sequential``.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for item in cfg:
        if item == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, item, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(item))
        modules.append(nn.ReLU(inplace=True))
        channels = item
    return nn.Sequential(*modules)
# Restore the trained CSRNet from "model.pt" onto the CPU.
# load_weights=True skips the VGG16 download/initialisation in the constructor;
# load_state_dict (strict by default) then overwrites every parameter.
checkpoint = torch.load("model.pt", map_location="cpu")
csrmodel = CSRNet(load_weights=True).to("cpu")
csrmodel.load_state_dict(checkpoint)
csrmodel.eval()  # switch to evaluation mode for inference
# Preprocessing applied to every uploaded image before it reaches CSRNet:
# array -> PIL image -> 256x256 (aspect ratio not preserved) -> CHW float
# tensor -> per-channel normalization with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Define the prediction function
def predict_count(input_image):
    """Estimate the crowd count in an image and render its density map.

    Args:
        input_image: Any image accepted by ``transforms.ToPILImage`` (e.g. the
            HxWxC numpy array Gradio's Image input provides). It is converted
            to a normalized 256x256 tensor by the module-level ``transform``.

    Returns:
        ``(predicted_count, density_map_color)``: the summed density map
        rounded to the nearest integer, and an (H, W, 4) float RGBA array from
        matplotlib's ``jet`` colormap for display.
    """
    image = transform(input_image).unsqueeze(0).cpu()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = csrmodel(image)
    # Round rather than truncate, so a sum of 12.9 reports 13, not 12.
    predicted_count = round(output.sum().item())
    density_map = output[0, 0].cpu().numpy()
    # The output conv has no ReLU, so values can be negative. Clip before
    # normalizing, and guard against an all-zero map (division by zero -> NaN).
    density_map = np.clip(density_map, 0, None)
    peak = density_map.max()
    if peak > 0:
        density_map = density_map / peak
    density_map_color = plt.cm.jet(density_map)
    return predicted_count, density_map_color
# Gradio UI: one image in; predicted count (Label) and colour density map (Image) out.
from gradio.components import Label, Image

# `gr.inputs.Image` was deprecated in Gradio 3 and removed in Gradio 4. The
# components API (already used for the outputs) works for inputs too; "numpy"
# is pinned explicitly because `transform` starts with ToPILImage.
input_interface = Image(label="Input Image", type="numpy")
output_interface = [
    Label(label="Predicted Count"),
    Image(label="Density Map", type="numpy"),
]
# Create the Gradio app with both interfaces
grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)
# Launch the app (runs at import time, as this script always has).
grapp.launch()