def centre_crop(x, target):
    """Centre-crop ``x`` along its last axis to match ``target``'s last axis.

    The same number of elements is removed from both ends of the last axis,
    so the length difference between ``x`` and ``target`` must be even.

    :param x: Input tensor to crop. ``None`` is passed through unchanged.
    :param target: Tensor whose last-axis length is the desired length.
        If ``None``, ``x`` is returned as is.
    :return: ``x`` itself when no cropping is needed, otherwise a contiguous
        cropped copy; ``None`` when ``x`` is ``None``.
    :raises AssertionError: If the length difference is odd, so a symmetric
        crop is impossible.
    :raises ArithmeticError: If ``x`` is shorter than ``target`` along the
        last axis.
    """
    if x is None:
        return None
    if target is None:
        return x

    diff = x.shape[-1] - target.shape[-1]

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; otherwise an odd difference would silently yield
    # an output one element longer than the target.
    if diff % 2 != 0:
        raise AssertionError(
            f"Cannot centre-crop: length difference {diff} is not even "
            f"(input {x.shape[-1]}, target {target.shape[-1]})"
        )

    if diff == 0:
        return x
    if diff < 0:
        raise ArithmeticError(
            f"Cannot centre-crop: input length {x.shape[-1]} is smaller "
            f"than target length {target.shape[-1]}"
        )

    crop = diff // 2
    # Ellipsis keeps the crop on the last axis (the one ``diff`` was measured
    # on) regardless of how many leading dimensions ``x`` has.
    return x[..., crop:-crop].contiguous()