Spaces:
Sleeping
Sleeping
# The model used in the paper Distribution Matching for Crowd Counting. | |
# Code adapted from https://github.com/cvlab-stonybrook/DM-Count/blob/master/models.py | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
from torch.hub import load_state_dict_from_url | |
from typing import Optional | |
from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls | |
from ..utils import _init_weights | |
class VGG(nn.Module):
    """VGG feature extractor for crowd counting (DM-Count).

    Wraps a VGG convolutional backbone (``features``) with a two-layer
    regression head. The original density layer is intentionally omitted:
    the 128-channel output of this model is not final and is processed
    further downstream.

    Args:
        features: VGG convolutional backbone (e.g. built by
            ``make_vgg_layers``); expected to emit 512 channels.
        reduction: Desired output downsampling factor relative to the
            input. Defaults to the encoder's native factor (16); a smaller
            value upsamples the feature map before the regression head.
    """

    def __init__(
        self,
        features: nn.Module,
        reduction: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.features = features
        self.reg_layer = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.reg_layer.apply(_init_weights)
        # The density layer from the original DM-Count code is removed, as the
        # output from this model is not final and will be further processed.
        # self.density_layer = nn.Sequential(nn.Conv2d(128, 1, 1), nn.ReLU())
        self.encoder_reduction = 16  # native downsampling factor of the VGG encoder
        self.reduction = self.encoder_reduction if reduction is None else reduction
        self.channels = 128  # channel count of the output feature map

    def forward(self, x: Tensor) -> Tensor:
        """Extract 128-channel features, resampled to ``self.reduction``.

        Args:
            x: Input image batch.

        Returns:
            Feature map downsampled by ``self.reduction`` relative to ``x``.
        """
        x = self.features(x)
        if self.encoder_reduction != self.reduction:
            # align_corners=False is the F.interpolate default for bilinear
            # mode; stated explicitly to avoid the deprecation warning that
            # older torch versions emit when it is left unset.
            x = F.interpolate(
                x,
                scale_factor=self.encoder_reduction / self.reduction,
                mode="bilinear",
                align_corners=False,
            )
        x = self.reg_layer(x)
        return x
def _load_weights(model: VGG, url: str) -> VGG:
    """Download a checkpoint from *url* and load it non-strictly into *model*.

    Loading is non-strict because the regression head is not part of the
    ImageNet checkpoint; mismatched keys are reported on stdout instead of
    raising.

    Args:
        model: The VGG instance to populate.
        url: Location of the pre-trained state dict.

    Returns:
        The same *model*, with the downloaded weights loaded.
    """
    missing_keys, unexpected_keys = model.load_state_dict(
        load_state_dict_from_url(url), strict=False
    )
    print("Loading pre-trained weights")
    if missing_keys:
        print(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys: {unexpected_keys}")
    return model
def _make_vgg(cfg: str, url_key: str, reduction: int, batch_norm: bool = False) -> VGG:
    """Build a pre-trained VGG backbone; shared body of the factories below.

    Args:
        cfg: Key into ``vgg_cfgs`` ("A", "B", "D" or "E").
        url_key: Key into ``vgg_urls`` selecting the checkpoint to load.
        reduction: Output downsampling factor forwarded to ``VGG``.
        batch_norm: Whether to insert batch-norm layers in the backbone.

    Returns:
        A ``VGG`` model with the pre-trained weights loaded.
    """
    if batch_norm:
        features = make_vgg_layers(vgg_cfgs[cfg], batch_norm=True)
    else:
        features = make_vgg_layers(vgg_cfgs[cfg])
    return _load_weights(VGG(features, reduction=reduction), vgg_urls[url_key])


def vgg11(reduction: int = 8) -> VGG:
    """VGG-11 backbone (config "A") with pre-trained weights."""
    return _make_vgg("A", "vgg11", reduction)


def vgg11_bn(reduction: int = 8) -> VGG:
    """VGG-11 backbone (config "A") with batch norm and pre-trained weights."""
    return _make_vgg("A", "vgg11_bn", reduction, batch_norm=True)


def vgg13(reduction: int = 8) -> VGG:
    """VGG-13 backbone (config "B") with pre-trained weights."""
    return _make_vgg("B", "vgg13", reduction)


def vgg13_bn(reduction: int = 8) -> VGG:
    """VGG-13 backbone (config "B") with batch norm and pre-trained weights."""
    return _make_vgg("B", "vgg13_bn", reduction, batch_norm=True)


def vgg16(reduction: int = 8) -> VGG:
    """VGG-16 backbone (config "D") with pre-trained weights."""
    return _make_vgg("D", "vgg16", reduction)


def vgg16_bn(reduction: int = 8) -> VGG:
    """VGG-16 backbone (config "D") with batch norm and pre-trained weights."""
    return _make_vgg("D", "vgg16_bn", reduction, batch_norm=True)


def vgg19(reduction: int = 8) -> VGG:
    """VGG-19 backbone (config "E") with pre-trained weights."""
    return _make_vgg("E", "vgg19", reduction)


def vgg19_bn(reduction: int = 8) -> VGG:
    """VGG-19 backbone (config "E") with batch norm and pre-trained weights."""
    return _make_vgg("E", "vgg19_bn", reduction, batch_norm=True)