# crowd-counting / app.py
# Daniel Nouri
# Demo of Crowd Counting work by Thanasutives et al.
# 752eb09
import os
import subprocess
import sys
import cv2
import gdown
from PIL import Image
import numpy as np
import streamlit as st
import torch
from torchvision import transforms
def setup_env(path='Variations-of-SFANet-for-Crowd-Counting'):
    """Make the upstream SFANet repository available for import.

    When ``path`` does not exist yet, the repository of the same name is
    cloned from GitHub into it and an empty ``models/__init__.py`` is
    written into the fresh checkout. In every case ``path`` ends up on
    ``sys.path`` exactly once.

    Args:
        path: Name of the GitHub repository and of the local directory it
            is cloned into.

    Returns:
        ``path``, unchanged.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a
            non-zero status.
    """
    if not os.path.exists(path):
        subprocess.run(
            [
                'git',
                'clone',
                f'https://github.com/Pongpisit-Thanasutives/{path}.git',
                path,
            ],
            capture_output=True,
            check=True,
        )
        # Empty marker file so that ``models`` is a regular package.
        with open(os.path.join(path, 'models', '__init__.py'), 'w') as f:
            f.write('')

    # This has to happen for an already existing checkout as well: a new
    # process that finds the directory on disk still needs it on sys.path
    # before ``models`` can be imported. The membership test keeps repeated
    # calls from piling up duplicate entries.
    if path not in sys.path:
        sys.path.append(path)

    return path
def get_model(path, weights):
    """Instantiate ``M_SFANet_UCF_QNRF.Model`` and load a checkpoint on CPU.

    Args:
        path: Not read by this function; kept as part of the signature.
        weights: Checkpoint location handed to ``torch.load``.

    Returns:
        Whatever ``.eval()`` returns for the model after its state dict
        has been loaded.
    """
    # Function-level import: ``models`` is not imported at the top of this
    # file (presumably because it only becomes importable after the
    # repository checkout is on sys.path -- confirm against the caller).
    from models import M_SFANet_UCF_QNRF

    net = M_SFANet_UCF_QNRF.Model()
    cpu = torch.device('cpu')
    state_dict = torch.load(weights, map_location=cpu)
    net.load_state_dict(state_dict)
    return net.eval()
def download_weights(
    url='https://drive.google.com/uc?id=1fGuH4o0hKbgdP1kaj9rbjX2HUL1IH0oo',
    out="Paper's_weights_UCF_QNRF.zip",
):
    """Fetch and unpack the pretrained weights unless already present.

    Args:
        url: Download location passed to ``gdown.download``.
        out: File name the downloaded zip archive is stored under.

    Returns:
        The relative path of the weights file. This path is fixed and does
        not depend on ``url`` or ``out``.

    Raises:
        RuntimeError: If ``gdown.download`` reports failure by returning
            ``None``.
        zipfile.BadZipFile: If the downloaded file is not a valid archive.
        FileNotFoundError: If the archive was unpacked but the expected
            weights file is still missing.
    """
    # Local import keeps this block self-contained; zipfile is stdlib and
    # removes the need for an external ``unzip`` binary.
    import zipfile

    # The ``*`` is a literal character of the file name, not a wildcard:
    # the string is used as-is for the existence check and as return value.
    weights = "Paper's_weights_UCF_QNRF/best_M-SFANet*_UCF_QNRF.pth"
    if os.path.exists(weights):
        return weights

    if gdown.download(url, out) is None:
        raise RuntimeError(f'Could not download weights from {url}')

    # Unpack into the current working directory, where ``weights`` is
    # looked up.
    with zipfile.ZipFile(out) as archive:
        archive.extractall()

    # Fail here with a clear message instead of later inside torch.load.
    if not os.path.exists(weights):
        raise FileNotFoundError(
            f'{weights!r} not found after extracting {out!r}')
    return weights
def transform_image(img):
    """Resize and normalise an image into a single-item batch tensor.

    Both sides are snapped to the nearest multiple of 16 (at least 16),
    the image is resized with bicubic interpolation, converted to a tensor,
    normalised per channel and given a leading axis of length one.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` and
            convertible with ``np.array``; presumably a three-channel PIL
            image, given the three-value mean/std below -- confirm against
            the caller.

    Returns:
        The normalised tensor indexed with ``[None, :]``, i.e. with one
        extra leading dimension.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    width, height = img.size
    # round() can yield 0 for sides of 8 px or less, which is not a valid
    # resize target; clamp to the smallest non-zero multiple of 16.
    height = max(16, round(height / 16) * 16)
    width = max(16, round(width / 16) * 16)

    # The interpolation flag must be passed by keyword: the third
    # positional parameter of cv2.resize is ``dst``, so a positional
    # cv2.INTER_CUBIC never selects cubic interpolation.
    resized = cv2.resize(
        np.array(img), (width, height), interpolation=cv2.INTER_CUBIC)
    return to_tensor(Image.fromarray(resized))[None, :]
def main():
    """Render the demo page.

    Prepares the repository checkout, the weights and the model, then
    shows a file uploader. Without an upload an example image
    (``crowd.jpg``) is displayed; with an upload the image, the density
    map scaled by its maximum and the summed density ("Estimated count")
    are displayed.
    """
    st.write("Demo of [Encoder-Decoder Based Convolutional Neural Networks with Multi-Scale-Aware Modules for Crowd Counting](https://arxiv.org/abs/2003.05586)") # noqa
    path = setup_env()
    weights = download_weights()
    model = get_model(path, weights)

    image_file = st.file_uploader(
        "Upload image", type=['png', 'jpg', 'jpeg'])
    if image_file is None:
        st.write("Example image to use that you can drag and drop:")
        st.image(Image.open('crowd.jpg').convert('RGB'))
        return

    image = Image.open(image_file).convert('RGB')
    st.image(image)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        density_map = model(transform_image(image))

    # Take the first batch item and move its leading axis to the end
    # before handing the array to st.image.
    density_map_img = density_map.detach().numpy()[0].transpose(1, 2, 0)
    peak = density_map_img.max()
    # Avoid dividing by zero (NaN image) when the map has no positive peak.
    if peak > 0:
        density_map_img = density_map_img / peak
    st.image(density_map_img)
    st.write("Estimated count: ", torch.sum(density_map).item())
# Run the app as soon as this script is executed.
main()