instructblip / lavis /processors /ulip_processors.py
WhiteWolf21's picture
Initialization
be13417
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
"""
Adapted from ULIP codebase: https://github.com/salesforce/ULIP
"""
from lavis.common.registry import registry
from lavis.processors.blip_processors import BlipImageBaseProcessor
from omegaconf import OmegaConf
import torchvision.transforms as transforms
from lavis.models.ulip_models.utils.io import IO
import numpy as np
from PIL import Image
import torch
def pc_norm(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    Args:
        pc: array of shape (N, C). All C columns are centered, and the
            per-point Euclidean norm is taken over all C columns.

    Returns:
        Array of shape (N, C) with zero centroid whose largest point norm is 1.
        If every point coincides (largest norm 0), the cloud is only centered
        instead of being divided by zero, which would produce NaNs.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    radius = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if radius > 0:
        pc = pc / radius
    return pc
def random_sample(permutation, pc, num):
    """Pick ``num`` rows of ``pc`` in random order.

    Args:
        permutation: 1-D integer index array. It is shuffled IN PLACE, so the
            caller's array is reordered by this call.
        pc: array indexed along axis 0 by ``permutation``.
        num: number of leading shuffled indices to keep.

    Returns:
        ``pc`` rows selected by the first ``num`` shuffled indices.
    """
    np.random.shuffle(permutation)
    chosen = permutation[:num]
    return pc[chosen]
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly and closed by the ``with`` block, to avoid
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    ``convert`` runs before the file is closed, so the pixel data has already
    been read.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def pc_normalize(pc):
    """Center a point cloud at its centroid and scale its largest norm to 1.

    Args:
        pc: array of shape (N, C); norms are computed over all C columns.

    Returns:
        Normalized (N, C) array. A degenerate cloud whose points all coincide
        (largest norm 0) is only centered, avoiding a 0/0 that yields NaNs.
    """
    centered = pc - np.mean(pc, axis=0)
    max_norm = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if max_norm > 0:
        centered = centered / max_norm
    return centered
def farthest_point_sample(point, npoint):
    """Down-sample a point cloud with iterative farthest point sampling.

    Starting from a uniformly random seed point, each step selects the point
    whose squared distance (over the first three columns) to the nearest
    already-selected point is largest.

    Args:
        point: array of shape (N, D). Only columns 0-2 enter the distance
            computation, but all D columns are carried into the output.
        npoint: number of points to select.

    Returns:
        Array of shape (npoint, D) containing the selected rows of ``point``
        (the points themselves, not their indices).
    """
    num_points, _ = point.shape
    xyz = point[:, :3]
    selected = np.zeros(npoint, dtype=np.int64)
    # Squared distance from every point to its nearest selected point so far;
    # starts "infinitely" far so the first update always takes effect.
    min_dist = np.full(num_points, 1e10)
    farthest = np.random.randint(0, num_points)
    for i in range(npoint):
        selected[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        np.minimum(min_dist, dist, out=min_dist)
        farthest = np.argmax(min_dist)
    return point[selected]
def rotate_point_cloud(batch_data):
    """Rotate every cloud in a batch by its own random angle about axis 1.

    Axis 1 (y) is treated as the up direction: the rotation matrix keeps the
    second coordinate fixed. One angle in [0, 2*pi) is drawn per cloud.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds (a new array)
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        cloud = batch_data[idx, ...].reshape((-1, 3))
        out[idx, ...] = np.dot(cloud, rot)
    return out
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """Randomly overwrite points of each cloud with that cloud's first point.

    For each cloud a drop ratio is drawn uniformly from
    [0, max_dropout_ratio), and each point is dropped independently with
    roughly that probability. "Dropped" points are replaced by point 0 rather
    than removed, so the array shape is unchanged.

    Args:
        batch_pc: array of shape (B, N, C); modified IN PLACE.
        max_dropout_ratio: upper bound of the per-cloud drop ratio.

    Returns:
        The same ``batch_pc`` object.
    """
    num_points = batch_pc.shape[1]
    # Iterating an ndarray yields views, so writes land in batch_pc.
    for cloud in batch_pc:
        ratio = np.random.random() * max_dropout_ratio
        dropped = np.flatnonzero(np.random.random(num_points) <= ratio)
        if dropped.size:
            cloud[dropped, :] = cloud[0, :]
    return batch_pc
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Multiply each cloud in the batch by its own random scale factor.

    Factors are drawn uniformly from [scale_low, scale_high), one per cloud.

    Input:
        BxNxC array (must be 3-D); scaled IN PLACE
    Return:
        the same array object, scaled
    """
    batch_size, _, _ = batch_data.shape
    factors = np.random.uniform(scale_low, scale_high, batch_size)
    # Each `cloud` is a view into batch_data, so `*=` edits it in place.
    for cloud, factor in zip(batch_data, factors):
        cloud *= factor
    return batch_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """Translate each cloud in the batch by its own random 3-D offset.

    Offsets are drawn uniformly from [-shift_range, shift_range) per axis,
    one offset vector per cloud.

    Input:
        BxNx3 array; shifted IN PLACE
    Return:
        the same array object, shifted
    """
    batch_size, _, _ = batch_data.shape
    offsets = np.random.uniform(-shift_range, shift_range, (batch_size, 3))
    # Each `cloud` is a view into batch_data, so `+=` edits it in place.
    for cloud, offset in zip(batch_data, offsets):
        cloud += offset
    return batch_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise independently to every coordinate.

    Args:
        batch_data: array of shape (B, N, C); not modified.
        sigma: standard deviation of the noise before clipping.
        clip: noise is clipped to [-clip, clip]; must be positive.

    Returns:
        New (B, N, C) array ``batch_data + noise``.

    Raises:
        ValueError: if ``clip`` is not positive. An explicit check is used
            instead of ``assert``, which is stripped under ``python -O``.
    """
    B, N, C = batch_data.shape
    if not clip > 0:
        raise ValueError(f"clip must be positive, got {clip!r}")
    noise = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    return noise + batch_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small random rotation to every cloud in the batch.

    For each cloud, three angles are drawn from N(0, angle_sigma^2) and
    clipped to [-angle_clip, angle_clip]. The rotation is Rz @ Ry @ Rx built
    from those angles, and points (as row vectors) are right-multiplied by it.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds (a new array)
    """
    perturbed = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        cx, sx = np.cos(ax), np.sin(ax)
        cy, sy = np.cos(ay), np.sin(ay)
        cz, sz = np.cos(az), np.sin(az)
        rot_x = np.array([[1, 0, 0],
                          [0, cx, -sx],
                          [0, sx, cx]])
        rot_y = np.array([[cy, 0, sy],
                          [0, 1, 0],
                          [-sy, 0, cy]])
        rot_z = np.array([[cz, -sz, 0],
                          [sz, cz, 0],
                          [0, 0, 1]])
        combined = np.dot(rot_z, np.dot(rot_y, rot_x))
        cloud = batch_data[idx, ...].reshape((-1, 3))
        perturbed[idx, ...] = np.dot(cloud, combined)
    return perturbed
@registry.register_processor("ulip_pc")
class ULIPPCProcessor(BlipImageBaseProcessor):
    """Load, normalize, resample and optionally augment a point cloud for ULIP.

    Args:
        npoints: number of points in every returned cloud.
        augment: apply random dropout / scale / shift / rotation to ``data``.
        uniform: use farthest point sampling when the cloud has more than
            ``npoints`` points; otherwise points are sampled randomly.
        ssl: additionally return two independently, strongly augmented views.
        oversample: stored on the instance; not read by this class.
        use_height: append one channel holding coordinate 1 minus its minimum.
    """

    def __init__(
        self,
        npoints=8192,
        augment=False,
        uniform=True,
        ssl=False,
        oversample=False,
        use_height=False,
    ):
        super().__init__()
        self.npoints = npoints
        self.augment = augment
        self.uniform = uniform
        self.ssl = ssl
        self.oversample = oversample
        self.use_height = use_height
        # Kept for backward compatibility. Random sampling now builds its index
        # array from the actual cloud size instead of from ``npoints``.
        self.permutation = np.arange(self.npoints)

    @staticmethod
    def _load(pc_data_path):
        """Return the cloud as-is if it is an ndarray, else read it from disk."""
        if isinstance(pc_data_path, np.ndarray):
            return pc_data_path
        try:
            # NOTE(review): allow_pickle=True can execute code from a crafted
            # file; only use with trusted data.
            return np.load(pc_data_path, allow_pickle=True)['arr_0'].astype(np.float32)
        except Exception:
            # Not an .npz holding 'arr_0' (or unreadable by np.load):
            # fall back to the generic IO reader.
            return IO.get(pc_data_path).astype(np.float32)

    def _random_subsample(self, data):
        """Randomly pick ``npoints`` rows drawn from the whole cloud.

        Clouds with fewer than ``npoints`` points are padded by re-drawing
        existing points with replacement, so the output size is always fixed.
        """
        num_points = data.shape[0]
        data = random_sample(np.arange(num_points), data, self.npoints)
        shortfall = self.npoints - data.shape[0]
        if shortfall > 0:
            extra = np.random.randint(0, num_points, shortfall)
            data = np.concatenate([data, data[extra]], axis=0)
        return data

    def _sample(self, data):
        """Reduce the cloud to exactly ``npoints`` points."""
        if self.uniform and self.npoints < data.shape[0]:
            return farthest_point_sample(data, self.npoints)
        return self._random_subsample(data)

    @staticmethod
    def _augment(data):
        """Return a lightly augmented copy of ``data`` (input left untouched)."""
        # Dropout/scale/shift modify their argument in place, hence the copy.
        batch = random_point_dropout(data.copy()[None, ...])
        batch = random_scale_point_cloud(batch)
        batch = shift_point_cloud(batch)
        batch = rotate_perturbation_point_cloud(batch)
        batch = rotate_point_cloud(batch)
        return batch[0]  # [0] rather than squeeze() keeps the point axis when npoints == 1

    @staticmethod
    def _strong_augment(data):
        """Return a heavily augmented copy of ``data`` for SSL (input untouched)."""
        # A real copy is essential: a slice like data[:] is a view, and the
        # in-place helpers would otherwise corrupt ``data`` and the other view.
        batch = random_point_dropout(data.copy()[None, ...])
        batch = random_scale_point_cloud(batch, scale_low=0.5, scale_high=1.5)
        batch = shift_point_cloud(batch, shift_range=0.4)
        batch = rotate_perturbation_point_cloud(batch, angle_sigma=0.1, angle_clip=0.3)
        batch = rotate_point_cloud(batch)
        return batch[0]

    def __call__(self, pc_data_path):
        """Process one cloud given as an ndarray or a file path.

        Returns:
            A float tensor, or when ``ssl`` is set a dict with ``"data"`` (float
            tensor) plus ``"data_aug_1"`` / ``"data_aug_2"`` (NumPy float32
            arrays, without the height channel).
        """
        data = pc_norm(self._load(pc_data_path))
        data = self._sample(data)
        if self.augment:
            data = self._augment(data)
        if self.ssl:
            data_aug_1 = self._strong_augment(data)
            data_aug_2 = self._strong_augment(data)
        if self.use_height:
            self.gravity_dim = 1
            gravity = data[:, self.gravity_dim:self.gravity_dim + 1]
            data = np.concatenate((data, gravity - gravity.min()), axis=1)
        data = torch.from_numpy(data).float()
        if self.ssl:
            return {"data": data, "data_aug_1": data_aug_1, "data_aug_2": data_aug_2}
        return data

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node; missing keys use defaults."""
        if cfg is None:
            cfg = OmegaConf.create()
        return cls(
            npoints=cfg.get('npoints', 8192),
            augment=cfg.get('augment', False),
            uniform=cfg.get('uniform', True),
            ssl=cfg.get('ssl', False),
            oversample=cfg.get('oversample', False),
            use_height=cfg.get('use_height', False),
        )