import pdb
from os import path

import torch
import torch.distributed as dist
import torch.autograd as autograd
import torch.cuda.comm as comm
from torch.autograd.function import once_differentiable
from torch.utils.cpp_extension import load

_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
_backend = load(name="inplace_abn",
                extra_cflags=["-O3"],
                sources=[path.join(_src_path, f) for f in [
                    "inplace_abn.cpp",
                    "inplace_abn_cpu.cpp",
                    "inplace_abn_cuda.cu",
                    "inplace_abn_cuda_half.cu"
                ]],
                extra_cuda_cflags=["--expt-extended-lambda"])
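# Note: load() JIT-compiles the C++/CUDA extension the first time this module is
# imported, which requires a working compiler toolchain (and nvcc for the .cu
# sources); later imports reuse the cached build.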
# Activation names
ACT_RELU = "relu"
ACT_LEAKY_RELU = "leaky_relu"
ACT_ELU = "elu"
ACT_NONE = "none"
def _check(fn, *args, **kwargs):
    success = fn(*args, **kwargs)
    if not success:
        raise RuntimeError("CUDA Error encountered in {}".format(fn))


def _broadcast_shape(x):
    # Shape that broadcasts a per-channel vector against x, e.g. [1, C, 1, 1] for NCHW input.
    out_size = []
    for i, s in enumerate(x.size()):
        if i != 1:
            out_size.append(1)
        else:
            out_size.append(s)
    return out_size


def _reduce(x):
    # Sum over every dimension except the channel dimension (dim 1).
    if len(x.size()) == 2:
        return x.sum(dim=0)
    else:
        n, c = x.size()[0:2]
        return x.contiguous().view((n, c, -1)).sum(2).sum(0)


def _count_samples(x):
    # Number of elements that contribute to each channel's statistics.
    count = 1
    for i, s in enumerate(x.size()):
        if i != 1:
            count *= s
    return count
def _act_forward(ctx, x):
    # Apply the fused activation in place. ACT_RELU is declared above but has no
    # branch here, so it falls through as a no-op like ACT_NONE.
    if ctx.activation == ACT_LEAKY_RELU:
        _backend.leaky_relu_forward(x, ctx.slope)
    elif ctx.activation == ACT_ELU:
        _backend.elu_forward(x)
    elif ctx.activation == ACT_NONE:
        pass


def _act_backward(ctx, x, dx):
    # Undo the in-place activation stored in x and adjust the gradient dx to match.
    if ctx.activation == ACT_LEAKY_RELU:
        _backend.leaky_relu_backward(x, dx, ctx.slope)
    elif ctx.activation == ACT_ELU:
        _backend.elu_backward(x, dx)
    elif ctx.activation == ACT_NONE:
        pass
class InPlaceABN(autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        count = _count_samples(x)
        x = x.contiguous()
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var
    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
        else:
            # TODO: implement simplified CUDA backward for inference mode
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)

        # dweight = eydz * weight.sign() if ctx.affine else None
        dweight = eydz if ctx.affine else None
        if dweight is not None:
            dweight[weight < 0] *= -1
        dbias = edz if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None
class InPlaceABNSync(autograd.Function):
    @classmethod
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01,
                equal_batches=True):
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # count = _count_samples(x)
        batch_size = x.new_tensor([x.shape[0]], dtype=torch.long)

        x = x.contiguous()
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)
            if ctx.world_size > 1:
                # Get global batch size
                if equal_batches:
                    batch_size *= ctx.world_size
                else:
                    dist.all_reduce(batch_size, dist.ReduceOp.SUM)

                ctx.factor = x.shape[0] / float(batch_size.item())

                mean_all = mean.clone() * ctx.factor
                dist.all_reduce(mean_all, dist.ReduceOp.SUM)

                var_all = (var + (mean - mean_all) ** 2) * ctx.factor
                dist.all_reduce(var_all, dist.ReduceOp.SUM)

                mean = mean_all
                var = var_all

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1]
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var
    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
            edz_local = edz.clone()
            eydz_local = eydz.clone()

            if ctx.world_size > 1:
                edz *= ctx.factor
                dist.all_reduce(edz, dist.ReduceOp.SUM)

                eydz *= ctx.factor
                dist.all_reduce(eydz, dist.ReduceOp.SUM)
        else:
            edz_local = edz = dz.new_zeros(dz.size(1))
            eydz_local = eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)

        # dweight = eydz_local * weight.sign() if ctx.affine else None
        dweight = eydz_local if ctx.affine else None
        if dweight is not None:
            dweight[weight < 0] *= -1
        dbias = edz_local if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None
inplace_abn = InPlaceABN.apply
inplace_abn_sync = InPlaceABNSync.apply

__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
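# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the extension above compiled successfully; shapes and hyper-parameters
# are arbitrary. autograd.Function.apply only takes positional arguments, so the
# defaults are spelled out explicitly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n, c, h, w = 4, 16, 8, 8

    x = torch.randn(n, c, h, w, device=device)                 # activations, overwritten in place
    weight = torch.rand(c, device=device, requires_grad=True)  # BN scale (gamma)
    bias = torch.zeros(c, device=device, requires_grad=True)   # BN shift (beta)
    running_mean = torch.zeros(c, device=device)
    running_var = torch.ones(c, device=device)

    # Fused BN + leaky ReLU, training mode, momentum 0.1, eps 1e-5, slope 0.01.
    y, running_mean, running_var = inplace_abn(
        x, weight, bias, running_mean, running_var,
        True, 0.1, 1e-5, ACT_LEAKY_RELU, 0.01)

    y.sum().backward()
    print(y.shape, weight.grad.shape, bias.grad.shape)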