# Crowd_Count / app.py
# (Hugging Face Space file; uploaded by avanish07 in commit a33f382, "Upload 3 files"; 3.22 kB)
import h5py
import gradio as gr
import scipy.io as io
import PIL.Image as Image
import numpy as np
from torchvision import transforms
import scipy
import json
from matplotlib import cm as CM
import torch.nn as nn
import torch
from torchvision import models
class CSRNet(nn.Module):
    """CSRNet crowd-counting network: VGG-16-style front end + dilated back end.

    ``forward`` maps an image batch to a single-channel density map; summing
    that map gives the predicted count.

    Args:
        load_weights: When False (the default), all layers are initialised by
            ``_initialize_weights`` and the front end is then overwritten with
            ImageNet-pretrained VGG-16 convolution weights (this downloads the
            VGG-16 checkpoint via torchvision). When True, both steps are
            skipped, for callers that load a full CSRNet checkpoint afterwards.
    """

    def __init__(self, load_weights=False):
        super().__init__()
        self.seen = 0  # not read anywhere in this file
        # Integers are conv output channels; 'M' is a 2x2/stride-2 max-pool
        # (see make_layers). This is the first 13 entries of the VGG-16 config.
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        # Dilated (rate 2) convs enlarge the receptive field without pooling.
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        # 1x1 conv collapses the 64 back-end channels into one density channel.
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            vgg = models.vgg16(pretrained=True)
            self._initialize_weights()
            self._copy_vgg16_frontend(vgg)

    def _copy_vgg16_frontend(self, vgg):
        """Overwrite ``self.frontend`` with the matching VGG-16 conv weights.

        ``self.frontend`` is built from the same layer list as the leading
        modules of ``vgg.features``, so their parameter keys (``"0.weight"``,
        ``"2.bias"``, ...) coincide. The full ``vgg.state_dict()`` keys carry a
        ``"features."`` prefix, so they must be taken from ``vgg.features``.
        Parameters are copied through ``load_state_dict`` rather than by
        assigning ``.data`` on ``state_dict()`` entries, which are detached
        tensors whose rebinding does not affect the module's parameters.
        """
        frontend_keys = self.frontend.state_dict().keys()
        pretrained = {
            key: tensor
            for key, tensor in vgg.features.state_dict().items()
            if key in frontend_keys
        }
        # strict=True (default) raises if any front-end parameter was left
        # uncovered, so a layout mismatch cannot silently keep random weights.
        self.frontend.load_state_dict(pretrained)

    def forward(self, x):
        """Return the density map for image batch ``x``.

        ``x`` must have 3 input channels (the first front-end conv is built
        with ``in_channels=3``). The three max-pools in the front end reduce
        the spatial size to roughly 1/8 of the input; the output has 1 channel.
        """
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Initialise conv weights ~ N(0, 0.01^2) and zero biases.

        The BatchNorm2d branch only matters if layers are built with
        ``batch_norm=True``; this class never does that.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` of 3x3 conv blocks from a layer spec list.

    Args:
        cfg: Sequence whose items are either an int (output channels of a
            3x3 conv followed by ReLU) or the string ``'M'`` (a 2x2 max-pool
            with stride 2).
        in_channels: Channel count fed into the first conv.
        batch_norm: Insert ``BatchNorm2d`` between each conv and its ReLU.
        dilation: Use dilation 2 instead of 1; padding always equals the
            dilation, so each conv keeps the spatial size.

    Returns:
        The assembled ``nn.Sequential``.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, spec, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        # The next conv consumes this conv's output channels.
        channels = spec
    return nn.Sequential(*modules)
# Load the CSRNet model.
# load_weights=True skips the pretrained VGG-16 download and copy done in
# CSRNet.__init__: load_state_dict below is strict, so every parameter is
# overwritten by the checkpoint anyway.
csrmodel = CSRNet(load_weights=True)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# host; nothing in this app moves the model to a GPU.
checkpoint = torch.load("model.pt", map_location="cpu")
csrmodel.load_state_dict(checkpoint)
# Switch to evaluation mode before serving predictions.
csrmodel.eval()
# Preprocessing pipeline applied to every input image before CSRNet sees it.
# NOTE(review): Resize forces a fixed 256x256 input regardless of the original
# size or aspect ratio; for large, dense crowds this may shrink heads enough to
# hurt the count. Confirm this is intended.
transform = transforms.Compose(
    [
        # Array/tensor -> PIL image.
        transforms.ToPILImage(),
        transforms.Resize(size=(256, 256)),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std (3 channels expected).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Define the prediction function
def predict_count(input_image):
    """Estimate how many people appear in ``input_image``.

    Args:
        input_image: An image accepted by ``transforms.ToPILImage`` (the first
            step of ``transform``); presumably the NumPy array handed over by
            the Gradio ``Image`` input.

    Returns:
        The predicted count as a non-negative int, rounded to the nearest
        whole person.
    """
    # Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W).
    image = transform(input_image).unsqueeze(0)
    # Inference only: disable autograd so no graph is built (saves memory/time).
    with torch.no_grad():
        density_map = csrmodel(image)
    # The count is the sum of the density map. Round instead of truncating
    # (12.9 -> 13, not 12), and clamp because a slightly negative sum is not a
    # meaningful count.
    return max(0, round(density_map.sum().item()))
# Define the input and output components for Gradio.
# gr.Image / gr.Textbox replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4.
# type="numpy" feeds predict_count an array, as transforms.ToPILImage needs;
# image_mode="RGB" guarantees the 3 channels that transform's Normalize expects.
input_interface = gr.Image(type="numpy", image_mode="RGB")
output_interface = gr.Textbox()
# Create the Gradio app
grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)
# Launch the app
grapp.launch()