Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""BatchNorm (BN) utility functions and custom batch-size BN implementations""" | |
from functools import partial | |
import torch | |
import torch.nn as nn | |
from pytorchvideo.layers.batch_norm import ( | |
NaiveSyncBatchNorm1d, | |
NaiveSyncBatchNorm3d, | |
) # noqa | |
def get_norm(cfg):
    """
    Pick the 3D normalization layer constructor requested by the config.

    Args:
        cfg (CfgNode): model building configs, details are in the comments of
            the config file. Reads ``cfg.BN.NORM_TYPE`` and, depending on its
            value, ``cfg.BN.NUM_SPLITS``, ``cfg.BN.NUM_SYNC_DEVICES`` and
            ``cfg.BN.GLOBAL_SYNC``.
    Returns:
        nn.Module: the normalization layer class, or a ``functools.partial``
            of it with the config-derived keyword arguments already bound.
    Raises:
        NotImplementedError: if ``cfg.BN.NORM_TYPE`` is not a supported type.
    """
    bn_cfg = cfg.BN
    norm_type = bn_cfg.NORM_TYPE

    # "sync_batchnorm_apex" also yields plain BatchNorm3d here; presumably the
    # apex conversion happens elsewhere -- confirm against the caller.
    if norm_type in {"batchnorm", "sync_batchnorm_apex"}:
        return nn.BatchNorm3d

    if norm_type == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=bn_cfg.NUM_SPLITS)

    if norm_type == "sync_batchnorm":
        sync_kwargs = dict(
            num_sync_devices=bn_cfg.NUM_SYNC_DEVICES,
            global_sync=bn_cfg.GLOBAL_SYNC,
        )
        return partial(NaiveSyncBatchNorm3d, **sync_kwargs)

    raise NotImplementedError("Norm type {} is not supported".format(norm_type))
class SubBatchNorm3d(nn.Module):
    """
    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits, and runs BN on
    each of them separately (so that the stats are computed on each subset of
    examples (1/N of batch) independently). During evaluation, it aggregates
    the stats from all splits into one BN.

    Split ``j`` holds the samples at batch indices ``j, j + N, j + 2N, ...``
    (an interleaved split), because the batch dimension is folded into the
    channel dimension with a reshape during training.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits; in training mode the batch
                size must be divisible by it.
            args (dict): keyword arguments forwarded to ``nn.BatchNorm3d``;
                must include ``num_features``. If ``affine`` is true (the
                default), one shared weight/bias pair is kept on this module
                and the inner BN layers are built with ``affine=False``.
        """
        super(SubBatchNorm3d, self).__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]
        # Keep only one set of weight and bias, shared by all splits.
        if args.get("affine", True):
            self.affine = True
            args["affine"] = False
            # Honour optional factory kwargs so the shared parameters match
            # the device/dtype of the inner BN layers (None -> torch defaults).
            factory_kwargs = {
                "device": args.get("device"),
                "dtype": args.get("dtype"),
            }
            self.weight = torch.nn.Parameter(
                torch.ones(num_features, **factory_kwargs)
            )
            self.bias = torch.nn.Parameter(
                torch.zeros(num_features, **factory_kwargs)
            )
        else:
            self.affine = False
        # Used in eval mode, after aggregate_stats() has filled its stats.
        self.bn = nn.BatchNorm3d(**args)
        # Used in training mode: one channel group per split.
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Aggregate per-split statistics into one set of statistics.

        The aggregated variance is the mean of the per-split variances plus
        the (biased) variance of the per-split means, i.e. the law of total
        variance for equally sized splits.

        Args:
            means (tensor): per-split mean values, viewable as (n, C).
            stds (tensor): per-split variances (despite the name), viewable
                as (n, C).
            n (int): number of sets of means and stds.
        Returns:
            tuple(tensor, tensor): detached aggregated mean and variance, each
                of shape (C,).
        """
        split_means = means.view(n, -1)
        mean = split_means.sum(0) / n
        var = (
            stds.view(n, -1).sum(0) / n
            + ((split_means - mean) ** 2).sum(0) / n
        )
        return mean.detach(), var.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean, and running_var. Call this before eval.

        Copies the split-aggregated running statistics of ``split_bn`` into
        ``bn``. It does nothing when ``split_bn`` does not track running stats.
        """
        if self.split_bn.track_running_stats:
            (
                self.bn.running_mean.data,
                self.bn.running_var.data,
            ) = self._get_aggregated_mean_std(
                self.split_bn.running_mean,
                self.split_bn.running_var,
                self.num_splits,
            )

    def forward(self, x):
        """
        Args:
            x (tensor): input of shape (N, C, T, H, W); in training mode N
                must be divisible by ``num_splits``.
        Returns:
            tensor: normalized output with the same shape as ``x``.
        Raises:
            RuntimeError: in training mode, if N is not divisible by
                ``num_splits``.
        """
        if self.training:
            n, c, t, h, w = x.shape
            if n % self.num_splits != 0:
                raise RuntimeError(
                    "SubBatchNorm3d: batch size {} is not divisible by "
                    "num_splits {}".format(n, self.num_splits)
                )
            # Fold the splits into the channel dim so split_bn computes
            # independent stats per split. reshape (unlike view) also accepts
            # non-contiguous inputs such as permuted / channels_last_3d ones.
            x = x.reshape(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.reshape(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            # (C, 1, 1, 1) broadcasts over the trailing (C, T, H, W) dims.
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x