File size: 5,534 Bytes
4f6b78d |
|
import torch
import numpy as np
import math
from prettytable import PrettyTable
def count_parameters(model):
    """Print a table of the large trainable parameters and return the total count.

    Parameters with ``requires_grad == False`` are skipped entirely. Every
    remaining parameter contributes to the returned total, but only those
    with more than 100000 elements get their own row in the printed table.

    Args:
        model: object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs (e.g. a ``torch.nn.Module``).

    Returns:
        Total number of elements across all trainable parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    # Lazily pair each trainable parameter's name with its element count.
    trainable_sizes = (
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    )
    for name, size in trainable_sizes:
        total_params += size
        # Keep the table readable: list only the big tensors.
        if size > 100000:
            table.add_row([name, size])
    print(table)
    print('total params: %.2f M' % (total_params/1000000.0))
    return total_params
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Build a sine/cosine positional embedding from 2D coordinates.

    Each coordinate is multiplied by ``C // 4`` frequencies spaced
    geometrically between 1 and ``1 / temperature``. The result is laid out
    along the channel axis as ``[sin(x*w), cos(x*w), sin(y*w), cos(y*w)]``.

    Args:
        xy: tensor of shape (B, S, 2); ``xy[..., 0]`` is x, ``xy[..., 1]`` is y.
        C: embedding width; must be a multiple of 4.
        temperature: base controlling the lowest frequency.
        dtype: accepted for backward compatibility but ignored; the output
            always takes ``xy.dtype``.
        cat_coords: if True, append the raw ``xy`` to the embedding.

    Returns:
        Tensor of shape (B, S, C), or (B, S, C + 2) when ``cat_coords`` is set.

    Raises:
        AssertionError: if the last dimension of ``xy`` is not 2 or ``C`` is
            not a multiple of 4.
    """
    device = xy.device
    # NOTE: the output dtype follows the input; the `dtype` argument is unused.
    dtype = xy.dtype
    B, S, D = xy.shape
    assert(D==2)
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    num_freqs = C // 4
    # With a single frequency (C == 4) the original denominator was 0, which
    # produced 0/0 = NaN. Clamp it so the lone frequency is temperature**0 = 1.
    omega = torch.arange(num_freqs, device=device) / max(num_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    # (B*S, 1) * (1, num_freqs) -> (B*S, num_freqs)
    x = xy[:, :, 0].flatten()[:, None] * omega[None, :]
    y = xy[:, :, 1].flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,S,C+2
    return pe
class SimplePool():
    """Fixed-capacity first-in-first-out pool of items.

    ``version`` selects the backend used when items are combined:
    ``'pt'`` uses torch, ``'np'`` uses numpy. Items live in the plain list
    ``self.items``, oldest first.
    """

    def __init__(self, pool_size, version='pt'):
        """Create an empty pool.

        Args:
            pool_size: maximum number of items kept by ``update``.
            version: ``'pt'`` or ``'np'``.

        Raises:
            AssertionError: if ``version`` is neither ``'pt'`` nor ``'np'``.
        """
        if version not in ('pt', 'np'):
            message = 'version = %s; please choose pt or np' % version
            print(message)
            # Raised explicitly (rather than `assert False`) so the check
            # still fires when Python runs with -O.
            raise AssertionError(message)
        self.pool_size = pool_size
        self.version = version
        self.items = []

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.items)

    def mean(self, min_size=1):
        """Return the sum of all stored values divided by the item count.

        Args:
            min_size: minimum number of stored items required; the string
                ``'half'`` means half of ``pool_size``.

        Returns:
            A numpy scalar (``'np'``) or 0-d tensor (``'pt'``); NaN of the
            matching kind when the pool is empty or holds fewer than the
            required number of items.
        """
        if min_size == 'half':
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = min_size
        count = len(self.items)
        # An empty pool has no mean regardless of the threshold; this also
        # avoids dividing by zero when min_size <= 0.
        enough = count > 0 and count >= pool_size_thresh

        if self.version == 'np':
            if not enough:
                return np.nan
            return np.sum(self.items) / float(count)

        if self.version == 'pt':
            if not enough:
                # torch.from_numpy only accepts ndarrays, not the float np.nan.
                return torch.tensor(float('nan'))
            # torch.sum needs a tensor, not a list; as_tensor also admits
            # plain Python numbers stored in the pool.
            stacked = torch.stack([torch.as_tensor(item) for item in self.items])
            return torch.sum(stacked) / float(count)

    def sample(self, with_replacement=True):
        """Return one uniformly chosen item.

        When ``with_replacement`` is False the item is removed from the pool.
        Raises ValueError (from ``np.random.randint``) if the pool is empty.
        """
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        return self.items.pop(idx)

    def fetch(self, num=None):
        """Return stored items stacked into one array/tensor.

        Args:
            num: if given, return ``num`` items drawn uniformly at random
                (indices may repeat); otherwise return everything.

        Raises:
            AssertionError: if ``num`` exceeds the number of stored items.
        """
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)

        if num is None:
            return item_array

        # there better be some items
        assert(len(self.items) >= num)

        # Only reachable when asserts are disabled (python -O): if there are
        # not that many elements just return however many there are.
        if len(self.items) < num:
            return item_array

        idxs = np.random.randint(len(self.items), size=num)
        return item_array[idxs]

    def is_full(self):
        """Return True when the pool holds exactly ``pool_size`` items."""
        return len(self.items) == self.pool_size

    def empty(self):
        """Discard all stored items."""
        self.items = []

    def update(self, items):
        """Append ``items`` one by one, evicting the oldest when full.

        Returns:
            The internal list of stored items (not a copy).
        """
        for item in items:
            if len(self.items) >= self.pool_size:
                # the pool is full: pop from the front
                self.items.pop(0)
            # add to the back
            self.items.append(item)
        return self.items
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Iteratively pick the point farthest from everything chosen so far.

    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
        include_ends: if True, the first pick is forced to index 0 and the
            second to index N-1 for every batch element
        deterministic: if True, the starting index is 0 instead of random
    Return:
        inds: sampled pointcloud index, [B, npoint]

    When npoint > N, Gaussian noise is added to the running distances after
    every pick, so the surplus picks are randomized.
    """
    device = xyz.device
    B, N, C = xyz.shape
    points = xyz.float()

    inds = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Squared distance from each point to its nearest already-chosen point.
    nearest_sq = torch.ones(B, N).to(device) * 1e10

    # An upper bound of 1 makes randint always yield index 0.
    start_bound = 1 if deterministic else N
    farthest = torch.randint(0, start_bound, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)

    need_extra = npoint > N
    for step in range(npoint):
        if include_ends and step == 0:
            farthest = 0
        elif include_ends and step == 1:
            farthest = N - 1

        inds[:, step] = farthest
        centroid = points[batch_indices, farthest, :].view(B, 1, C)
        sq_dist = ((points - centroid) ** 2).sum(-1)
        # Keep the smaller of the running and the new distance per point.
        nearest_sq = torch.where(sq_dist < nearest_sq, sq_dist, nearest_sq)
        farthest = nearest_sq.max(dim=-1).indices

        if need_extra:
            # if we need more samples, make them random
            nearest_sq += torch.randn_like(nearest_sq)
    return inds
def farthest_point_sample_py(xyz, npoint):
    """
    NumPy farthest point sampling for a single point cloud.

    Input:
        xyz: pointcloud data, [N, C]
        npoint: number of samples
    Return:
        inds: sampled point indices, int32 array of shape [npoint]

    The first index is drawn uniformly at random. When npoint > N, Gaussian
    noise is added to the running distances after every pick, so the surplus
    picks are randomized.
    """
    N, C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    # Squared distance from each point to its nearest already-chosen point.
    nearest_sq = np.full(N, 1e10)
    farthest = np.random.randint(0, N, dtype=np.int32)
    need_extra = npoint > N

    for step in range(npoint):
        inds[step] = farthest
        centroid = xyz[farthest, :].reshape(1, C)
        sq_dist = ((xyz - centroid) ** 2).sum(axis=-1)
        # Keep the smaller of the running and the new distance per point.
        nearest_sq = np.where(sq_dist < nearest_sq, sq_dist, nearest_sq)
        farthest = np.argmax(nearest_sq, -1)

        if need_extra:
            # if we need more samples, make them random
            nearest_sq += np.random.randn(*nearest_sq.shape)
    return inds
|