Spaces:
Sleeping
Sleeping
import h5py | |
import gradio as gr | |
import scipy.io as io | |
import matplotlib.pyplot as plt | |
import PIL.Image as Image | |
import numpy as np | |
from torchvision import transforms | |
import scipy | |
import json | |
from matplotlib import cm as CM | |
import torch.nn as nn | |
import torch | |
from torchvision import models | |
class CSRNet(nn.Module):
    """CSRNet crowd-counting network.

    The network has three parts:

    * A VGG-16-style convolutional front end (``frontend_feat``).
    * A back end of dilated 3x3 convolutions (``backend_feat``).
    * A 1x1 convolution that collapses the 64 back-end channels into a
      single-channel density map.

    Args:
        load_weights: If false (the default), all layers are initialised with
            ``_initialize_weights``. The front end is then overwritten with
            the ImageNet-pretrained convolution weights of torchvision's
            VGG-16, which torchvision downloads. If true, both steps are
            skipped; the caller is expected to load a full state dict
            afterwards.
    """

    def __init__(self, load_weights=False):
        super().__init__()
        self.seen = 0  # not read anywhere in this file
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # Why the key handling below is needed:
            # - ``frontend_feat`` reproduces the first modules of VGG-16's
            #   ``features`` with the same layer order and indices.
            # - torchvision's full state dict prefixes keys with "features.",
            #   but ``mod.features.state_dict()`` yields "0.weight", "0.bias",
            #   ... which match the front end's own keys.
            # - Keys past the front end (e.g. "24.weight") are filtered out.
            # - ``load_state_dict`` writes into the live parameters, whereas
            #   rebinding ``.data`` on tensors returned by ``state_dict()``
            #   leaves the parameters untouched.
            frontend_keys = self.frontend.state_dict().keys()
            pretrained = {
                k: v for k, v in mod.features.state_dict().items() if k in frontend_keys
            }
            self.frontend.load_state_dict(pretrained, strict=False)

    def forward(self, x):
        """Map an image batch to a one-channel density map.

        The front end's three 2x2 max-pools reduce height and width by 8x.
        The padded 3x3 and dilated convolutions preserve spatial size.
        """
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Initialise the network's layers.

        * Conv weights are drawn from N(0, 0.01^2) and conv biases are zeroed.
        * BatchNorm scale is set to 1 and shift to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` from a VGG-style layer configuration.

    Args:
        cfg: Iterable of entries. Each entry is either ``'M'``, meaning a
            2x2 max-pool with stride 2, or an int giving the output channel
            count of a 3x3 convolution followed by ReLU.
        in_channels: Channel count fed into the first convolution.
        batch_norm: If true, insert ``BatchNorm2d`` between each convolution
            and its ReLU.
        dilation: If true, convolutions use dilation 2 and padding 2;
            otherwise both are 1.

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # Padding equal to the dilation rate keeps H and W unchanged for a
        # 3x3 kernel with stride 1.
        modules.append(
            nn.Conv2d(channels, entry, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        channels = entry
    return nn.Sequential(*modules)
# Load the CSRNet model. load_weights=True makes the constructor skip its own
# initialisation; the weights come from the local "model.pt" state dict.
def _load_csrnet_for_inference():
    """Return a CPU-resident CSRNet in eval mode with weights from ``model.pt``."""
    net = CSRNet(load_weights=True).cpu()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13 consider passing weights_only=True.
    weights = torch.load("model.pt", map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    net.eval()
    return net


csrmodel = _load_csrnet_for_inference()
# Set the transformation for image preprocessing. The steps are:
#   1. array -> PIL image
#   2. resize to 256x256
#   3. convert to a float tensor in [0, 1]
#   4. normalise with the usual ImageNet per-channel mean/std
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# Define the prediction function
def predict_count(input_image):
    """Estimate the number of people in an image and render its density map.

    Args:
        input_image: Image from the Gradio input component. It is passed to
            the module-level ``transform``, whose first step is
            ``ToPILImage``. Presumably this is an ``H x W x 3`` uint8 numpy
            array (TODO confirm against the component's ``type``).

    Returns:
        A ``(predicted_count, density_map_color)`` tuple:

        * ``predicted_count`` is the sum of the predicted density map,
          rounded to the nearest integer.
        * ``density_map_color`` is the density map, normalised to its
          positive peak and coloured with ``plt.cm.jet`` as an RGBA array.
    """
    image = transform(input_image).unsqueeze(0).cpu()
    # Inference only: don't record an autograd graph.
    with torch.no_grad():
        output = csrmodel(image)
    # The integrated density is a real number. Round rather than truncate
    # (int() would report 12.9 as 12).
    predicted_count = round(output.sum().item())
    # Batch of one with a single output channel -> 2-D map.
    density_map = output[0, 0].cpu().numpy()
    # The output layer has no activation, so values can be <= 0.
    # - Negative values already map to the bottom of the colormap, so
    #   clipping them to 0 keeps the same look.
    # - The positive-peak guard avoids 0/0 (NaN) when the map is all zeros.
    # - It also avoids a sign flip when the map has no positive values.
    density_map = np.clip(density_map, 0, None)
    peak = density_map.max()
    if peak > 0:
        density_map = density_map / peak
    density_map_color = plt.cm.jet(density_map)
    return predicted_count, density_map_color
# Build the Gradio UI: one image in; a count label and a coloured density map out.
# Components are referenced through the top-level ``gr`` namespace. This
# avoids the deprecated ``gr.inputs`` API (removed in Gradio 4) and leaves the
# module-level ``Image`` name (the PIL alias) unshadowed.
input_interface = gr.Image(label="Input Image")
output_interface = [
    gr.Label(label="Predicted Count"),
    gr.Image(label="Density Map", type="numpy"),
]
# Create the Gradio app with both interfaces
grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)
# Launch the app
grapp.launch()