|
import torch |
|
import numpy as np |
|
import math |
|
from prettytable import PrettyTable |
|
|
|
def count_parameters(model):
    """Print a summary of a model's trainable parameters and return their count.

    Parameters with ``requires_grad`` set to False are ignored entirely.
    Only parameters holding more than 100000 elements get their own table
    row, but every trainable parameter contributes to the returned total.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs (presumably a ``torch.nn.Module``).

    Returns:
        The total number of trainable parameter elements.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    running_total = 0
    trainable = (
        (label, tensor)
        for label, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for label, tensor in trainable:
        n_elems = tensor.numel()
        running_total += n_elems
        # Keep the printed table short: list only the large tensors.
        if n_elems > 100000:
            summary.add_row([label, n_elems])
    print(summary)
    print('total params: %.2f M' % (running_total/1000000.0))
    return running_total
|
|
|
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Sinusoidal positional embedding of a batch of 2D coordinates.

    The C output channels are laid out as four equally sized groups:
    ``[sin(x*w), cos(x*w), sin(y*w), cos(y*w)]`` where ``w`` holds C // 4
    frequencies spaced geometrically from 1 down to 1 / temperature.

    Args:
        xy: tensor of shape (B, S, 2) holding x in ``[..., 0]`` and y in
            ``[..., 1]``.
        C: embedding width; must be a multiple of 4.
        temperature: base of the geometric frequency progression.
        dtype: accepted for interface compatibility only. The result is
            always cast to ``xy.dtype``, exactly as before.
        cat_coords: if True, the raw ``xy`` values are appended as two
            extra trailing channels.

    Returns:
        Tensor of shape (B, S, C), or (B, S, C + 2) when ``cat_coords`` is
        True, with dtype ``xy.dtype``.

    Raises:
        AssertionError: if the last dimension of ``xy`` is not 2 or if C is
            not a multiple of 4.
    """
    device = xy.device
    # NOTE: the `dtype` argument has always been overridden by the input's
    # dtype; keep that behaviour so existing callers see the same outputs.
    out_dtype = xy.dtype
    B, S, D = xy.shape
    assert(D==2)
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    n_freqs = C // 4
    # With a single frequency (C == 4) the original divisor was 0, turning
    # the exponent into 0/0 = NaN and poisoning the whole embedding. Clamp
    # the divisor so the lone frequency is 1.
    denom = max(n_freqs - 1, 1)
    omega = torch.arange(n_freqs, device=device) / denom
    omega = 1. / (temperature ** omega)

    # (B*S, 1) coordinates times (1, C//4) frequencies -> (B*S, C//4) phases.
    y = xy[:, :, 1].flatten()[:, None] * omega[None, :]
    x = xy[:, :, 0].flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(out_dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)
    return pe
|
|
|
class SimplePool():
    """Fixed-capacity first-in-first-out pool of items.

    ``version`` selects the array library used when items are aggregated:
    ``'pt'`` for torch tensors or ``'np'`` for numpy values. Items live in
    the plain list ``self.items``, oldest first.
    """

    def __init__(self, pool_size, version='pt'):
        """Create an empty pool.

        Args:
            pool_size: maximum number of items retained by ``update``.
            version: ``'pt'`` or ``'np'``.

        Raises:
            AssertionError: if ``version`` is neither ``'pt'`` nor ``'np'``.
        """
        self.pool_size = pool_size
        self.version = version
        self.items = []

        if version not in ('pt', 'np'):
            # The message previously lacked its format argument and printed
            # a literal "%s". Raise explicitly (same exception type as the
            # old `assert(False)`) so the check survives `python -O`.
            message = 'version = %s; please choose pt or np' % (version,)
            print(message)
            raise AssertionError(message)

    def __len__(self):
        """Return the number of items currently held."""
        return len(self.items)

    def mean(self, min_size=1):
        """Return the sum of all stored values divided by the item count.

        Args:
            min_size: minimum number of items required for a result, or the
                string ``'half'`` to require ``pool_size / 2`` items.

        Returns:
            The mean (a numpy scalar for ``'np'``, a 0-d tensor for
            ``'pt'``), or NaN of the matching kind when the pool is empty or
            holds fewer than the required number of items.
        """
        if min_size == 'half':
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = min_size

        count = len(self.items)
        # An empty pool has no mean, whatever threshold was requested; this
        # also avoids a 0/0 division when min_size is 0.
        enough = count > 0 and count >= pool_size_thresh

        if self.version == 'np':
            if not enough:
                return np.nan
            return np.sum(self.items) / float(count)
        if self.version == 'pt':
            if not enough:
                # torch.from_numpy only accepts ndarrays, so build the NaN
                # tensor directly.
                return torch.tensor(float('nan'))
            # torch.sum cannot reduce a Python list; stack the items first.
            stacked = torch.stack([torch.as_tensor(item) for item in self.items])
            return torch.sum(stacked) / float(count)

    def sample(self, with_replacement=True):
        """Return one uniformly chosen item.

        Args:
            with_replacement: if False, the chosen item is also removed
                from the pool.
        """
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        return self.items.pop(idx)

    def fetch(self, num=None):
        """Return the stored items stacked along a new leading axis.

        Args:
            num: if given, return ``num`` items drawn uniformly at random,
                with repetition possible, instead of the whole pool.

        Returns:
            A torch tensor (``'pt'``) or numpy array (``'np'``).

        Raises:
            AssertionError: if ``num`` exceeds the number of stored items.
        """
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)

        if num is None:
            return item_array

        # There must be at least `num` items to draw from.
        assert(len(self.items) >= num)

        # Only reachable when assertions are disabled (python -O): hand back
        # however many items there are.
        if len(self.items) < num:
            return item_array

        idxs = np.random.randint(len(self.items), size=num)
        return item_array[idxs]

    def is_full(self):
        """Return True when the pool holds exactly ``pool_size`` items."""
        return len(self.items) == self.pool_size

    def empty(self):
        """Discard every stored item."""
        self.items = []

    def update(self, items):
        """Append each item, evicting the oldest entry once at capacity.

        Args:
            items: iterable of new items, added in iteration order.

        Returns:
            The internal item list (not a copy).
        """
        for item in items:
            if len(self.items) >= self.pool_size:
                # Full: drop the oldest item to make room for the new one.
                self.items.pop(0)
            self.items.append(item)
        return self.items
|
|
|
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Greedy farthest-point sampling over a batch of point sets.

    Each step records the current pick, lowers every point's running
    squared distance to the nearest pick so far, and then selects the
    point whose running distance is largest.

    Input:
        xyz: point coordinates, [B, N, C]; converted to float internally
        npoint: number of indices to pick per batch element
        include_ends: if True, the first pick is index 0 and the second
            pick is index N-1
        deterministic: if True, the initial pick is index 0 instead of a
            random index
    Return:
        inds: int64 indices along the N axis, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    points = xyz.float()

    inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
    nearest = torch.full((B, N), 1e10, device=device)
    batch_ids = torch.arange(B, dtype=torch.long, device=device)

    # Drawn on the CPU and then moved; an upper bound of 1 always yields 0.
    upper = 1 if deterministic else N
    farthest = torch.randint(0, upper, (B,), dtype=torch.long).to(device)

    # More samples requested than points available: the running distances
    # are perturbed with Gaussian noise after every pick.
    oversample = npoint > N

    for step in range(npoint):
        if include_ends and step == 0:
            farthest = 0
        elif include_ends and step == 1:
            farthest = N - 1
        inds[:, step] = farthest

        anchor = points[batch_ids, farthest, :].view(B, 1, C)
        sq_dist = ((points - anchor) ** 2).sum(dim=-1)
        closer = sq_dist < nearest
        nearest = torch.where(closer, sq_dist, nearest)
        _, farthest = torch.max(nearest, dim=-1)

        if oversample:
            nearest += torch.randn_like(nearest)
    return inds
|
|
|
def farthest_point_sample_py(xyz, npoint):
    """
    NumPy farthest-point sampling for a single point set.

    Starts from a random index, then repeatedly picks the point whose
    squared distance to its nearest already-picked point is largest.

    Input:
        xyz: point coordinates, [N, C]
        npoint: number of indices to pick
    Return:
        inds: int32 indices along the N axis, [npoint]
    """
    N, _ = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    nearest = np.full(N, 1e10)
    current = np.random.randint(0, N, dtype=np.int32)

    # More samples requested than points available: the running distances
    # are perturbed with Gaussian noise after every pick.
    oversample = npoint > N

    for step in range(npoint):
        inds[step] = current
        offsets = xyz - xyz[current]
        sq_dist = (offsets ** 2).sum(axis=-1)
        # Lower the running distance wherever the new pick is closer.
        np.copyto(nearest, sq_dist, where=sq_dist < nearest)
        current = nearest.argmax()
        if oversample:
            nearest += np.random.randn(N)
    return inds
|
|
|
|