# Spaces:
# Build error
import os
import subprocess
import sys
import zipfile

import cv2
import gdown
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
def setup_env(path='Variations-of-SFANet-for-Crowd-Counting'):
    """Clone the SFANet repository if needed and make its modules importable.

    The repository is cloned from GitHub into ``path`` only when that
    directory does not exist yet. On every call, including when the clone is
    already on disk (e.g. after the app process restarts), the directory is
    registered on ``sys.path``. That lets ``from models import ...`` resolve
    to the cloned package.

    Args:
        path: Local directory for the clone. Also used as the repository
            name under the ``Pongpisit-Thanasutives`` GitHub account.

    Returns:
        The ``path`` argument, unchanged.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` fails (its output is
            captured and available on the exception).
    """
    if not os.path.exists(path):
        subprocess.run(
            [
                'git',
                'clone',
                f'https://github.com/Pongpisit-Thanasutives/{path}.git',
                path,
            ],
            capture_output=True,
            check=True,
        )
        # Overwrite the freshly cloned models/__init__.py with an empty file.
        # NOTE(review): presumably so importing one model module does not
        # pull in the package's other imports -- confirm against the repo.
        with open(os.path.join(path, 'models', '__init__.py'), 'w') as f:
            f.write('')
    # Register the clone regardless of whether it was just created. Use an
    # absolute path so a later cwd change doesn't break imports. Insert at
    # the front so an unrelated installed "models" package can't shadow it.
    # Skip if already present so Streamlit reruns don't grow sys.path.
    repo_dir = os.path.abspath(path)
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    return path
def get_model(path, weights):
    """Build the M-SFANet (UCF-QNRF variant) model and load its weights.

    Args:
        path: Directory of the cloned SFANet repository (as returned by
            ``setup_env``). It is added to ``sys.path`` when missing, so the
            ``models`` package resolves even if ``setup_env`` skipped
            registration or ran in a different process.
        weights: Path to a checkpoint holding a ``state_dict`` for
            ``M_SFANet_UCF_QNRF.Model``.

    Returns:
        The model with weights loaded on CPU, switched to evaluation mode.
    """
    repo_dir = os.path.abspath(path)
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    # Imported lazily: the package only exists after the repository is cloned.
    from models import M_SFANet_UCF_QNRF

    model = M_SFANet_UCF_QNRF.Model()
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    state_dict = torch.load(weights, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model.eval()
def download_weights(
    url='https://drive.google.com/uc?id=1fGuH4o0hKbgdP1kaj9rbjX2HUL1IH0oo',
    out="Paper's_weights_UCF_QNRF.zip",
):
    """Download and extract the UCF-QNRF checkpoint unless it already exists.

    Args:
        url: Google Drive URL of the zip archive containing the weights.
        out: Local file name to save the downloaded archive to.

    Returns:
        Path of the extracted ``.pth`` checkpoint, relative to the current
        working directory.

    Raises:
        RuntimeError: If the archive was not downloaded, or if it does not
            contain the expected checkpoint.
        zipfile.BadZipFile: If the downloaded file is not a valid zip.
    """
    # os.path.exists does no globbing, so the '*' is part of the literal file
    # name here, not a wildcard.
    weights = "Paper's_weights_UCF_QNRF/best_M-SFANet*_UCF_QNRF.pth"
    if os.path.exists(weights):
        return weights
    gdown.download(url, out)
    if not os.path.exists(out):
        raise RuntimeError(f'Failed to download weights archive from {url}')
    # Extract with the stdlib rather than the external `unzip` tool, which
    # may not be installed and can prompt before overwriting existing files.
    with zipfile.ZipFile(out) as archive:
        archive.extractall()
    if not os.path.exists(weights):
        raise RuntimeError(f'{weights!r} not found after extracting {out!r}')
    return weights
def transform_image(img):
    """Turn a PIL image into a normalized single-image batch for the model.

    Each side is resized to the nearest multiple of 16 (never below 16) with
    bicubic interpolation. The result is converted to a float tensor and
    normalized per channel with the mean/std values used by torchvision's
    ImageNet-pretrained models.

    Args:
        img: A PIL image. Assumed to be RGB: ``main`` calls
            ``.convert('RGB')`` before passing it here.

    Returns:
        A tensor of shape ``(1, C, H, W)``, with H and W multiples of 16.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    width, height = img.size
    # Snap to a multiple of 16, but keep at least one multiple so images
    # smaller than 8 px don't round down to a zero-sized (invalid) resize.
    height = max(16, round(height / 16) * 16)
    width = max(16, round(width / 16) * 16)
    # cv2.resize's third positional parameter is `dst`, not the
    # interpolation flag, so the flag must be passed by keyword.
    resized = cv2.resize(
        np.array(img), (width, height), interpolation=cv2.INTER_CUBIC)
    return trans(Image.fromarray(resized))[None, :]
def main():
    """Streamlit entry point: load the model and estimate a crowd count.

    Prepares the repository and weights, then shows an uploader. Without an
    upload, an example image (``crowd.jpg``) is displayed. With one, the
    image, its predicted density map, and the summed count are shown.
    """
    st.write("Demo of [Encoder-Decoder Based Convolutional Neural Networks with Multi-Scale-Aware Modules for Crowd Counting](https://arxiv.org/abs/2003.05586)")  # noqa
    path = setup_env()
    weights = download_weights()
    model = get_model(path, weights)
    image_file = st.file_uploader(
        "Upload image", type=['png', 'jpg', 'jpeg'])
    if image_file is None:
        st.write("Example image to use that you can drag and drop:")
        st.image(Image.open('crowd.jpg').convert('RGB'))
        return

    image = Image.open(image_file).convert('RGB')
    st.image(image)
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        density_map = model(transform_image(image))
    # Take the first (only) batch item and move channels last for display.
    density_map_img = density_map.numpy()[0].transpose(1, 2, 0)
    # st.image expects floats in [0, 1]. Clip any negatives, and only rescale
    # when there is a positive peak, so an all-zero map doesn't divide by zero.
    density_map_img = np.clip(density_map_img, 0, None)
    peak = density_map_img.max()
    if peak > 0:
        density_map_img = density_map_img / peak
    st.image(density_map_img)
    st.write("Estimated count: ", torch.sum(density_map).item())
# Run only when executed as a script (`streamlit run` executes the file as
# __main__), so importing this module doesn't clone repos or download weights.
if __name__ == '__main__':
    main()