hieupt's picture
Upload crop.py
5378323 verified
raw
history blame contribute delete
589 Bytes
def centre_crop(x, target):
    """
    Center-crop the input tensor along its last axis so that its last-axis
    length matches that of the target tensor.

    The same number of elements is removed from both ends of the last axis,
    so the length difference between ``x`` and ``target`` must be even.
    All leading axes are left untouched.

    :param x: Input tensor, or None
    :param target: Tensor whose last-axis length is used as the target
        length, or None to skip cropping
    :return: None if ``x`` is None; ``x`` itself (not a copy) if ``target``
        is None or no cropping is needed; otherwise a contiguous cropped
        tensor
    :raises AssertionError: if the last-axis length difference is odd, so a
        symmetric crop is impossible
    :raises ArithmeticError: if ``x`` is shorter than ``target`` along the
        last axis
    """
    if x is None:
        return None
    if target is None:
        return x

    diff = x.shape[-1] - target.shape[-1]
    # Raised explicitly instead of using an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would let an odd difference through
    # and silently return a tensor of the wrong length. The exception type is
    # kept so existing callers catching AssertionError still work.
    if diff % 2 != 0:
        raise AssertionError(
            f"cannot centre-crop: last-axis length difference {diff} is odd "
            f"(input {x.shape[-1]}, target {target.shape[-1]})"
        )

    crop = diff // 2
    if crop == 0:
        return x
    if crop < 0:
        raise ArithmeticError(
            f"cannot centre-crop: input last-axis length {x.shape[-1]} is "
            f"smaller than target length {target.shape[-1]}"
        )

    # Ellipsis selects all leading axes, so the crop always applies to the
    # last axis (the one ``diff`` was computed from) regardless of rank.
    return x[..., crop:-crop].contiguous()