File size: 5,534 Bytes
4f6b78d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import torch
import numpy as np
import math
from prettytable import PrettyTable
def count_parameters(model):
    """Print and return the number of trainable parameter elements in `model`.

    Only parameters with ``requires_grad`` set are counted. Every such
    parameter contributes to the total, but only those with more than
    100000 elements get their own row in the printed table (presumably to
    keep the table short).

    Returns:
        int: total number of trainable elements.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    grand_total = 0
    trainable = (
        (label, tensor)
        for label, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for label, tensor in trainable:
        n_elem = tensor.numel()
        grand_total += n_elem
        if n_elem > 100000:
            summary.add_row([label, n_elem])
    print(summary)
    print('total params: %.2f M' % (grand_total / 1000000.0))
    return grand_total
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Sinusoidal positional embedding of continuous 2D (x, y) coordinates.

    Args:
        xy: tensor of shape (B, S, 2); ``xy[..., 0]`` is x and ``xy[..., 1]`` is y.
        C: embedding width; must be a multiple of 4. It is split into
            C // 4 frequencies each for sin(x), cos(x), sin(y), cos(y).
        temperature: base of the geometric frequency progression.
        dtype: NOTE(review): currently ignored -- the output is cast to
            ``xy.dtype``. Kept only for signature compatibility.
        cat_coords: if True, append the raw coordinates to the embedding.

    Returns:
        Tensor of shape (B, S, C), or (B, S, C + 2) when ``cat_coords``,
        with dtype ``xy.dtype``.
    """
    device = xy.device
    # Deliberately overrides the `dtype` argument (pre-existing behavior
    # that callers may rely on).
    dtype = xy.dtype
    B, S, D = xy.shape
    assert D == 2
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    n_freq = C // 4
    # Exponents evenly spaced over [0, 1]. With C == 4 there is a single
    # frequency and (n_freq - 1) == 0 would give 0/0 = NaN for every output;
    # clamping the denominator to 1 yields the single exponent 0 instead.
    omega = torch.arange(n_freq, device=device) / max(n_freq - 1, 1)
    omega = 1. / (temperature ** omega)
    x = xy[:, :, 0].flatten()[:, None] * omega[None, :]
    y = xy[:, :, 1].flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,S,C+2
    return pe
class SimplePool():
def __init__(self, pool_size, version='pt'):
self.pool_size = pool_size
self.version = version
self.items = []
if not (version=='pt' or version=='np'):
print('version = %s; please choose pt or np')
assert(False) # please choose pt or np
def __len__(self):
return len(self.items)
def mean(self, min_size=1):
if min_size=='half':
pool_size_thresh = self.pool_size/2
else:
pool_size_thresh = min_size
if self.version=='np':
if len(self.items) >= pool_size_thresh:
return np.sum(self.items)/float(len(self.items))
else:
return np.nan
if self.version=='pt':
if len(self.items) >= pool_size_thresh:
return torch.sum(self.items)/float(len(self.items))
else:
return torch.from_numpy(np.nan)
def sample(self, with_replacement=True):
idx = np.random.randint(len(self.items))
if with_replacement:
return self.items[idx]
else:
return self.items.pop(idx)
def fetch(self, num=None):
if self.version=='pt':
item_array = torch.stack(self.items)
elif self.version=='np':
item_array = np.stack(self.items)
if num is not None:
# there better be some items
assert(len(self.items) >= num)
# if there are not that many elements just return however many there are
if len(self.items) < num:
return item_array
else:
idxs = np.random.randint(len(self.items), size=num)
return item_array[idxs]
else:
return item_array
def is_full(self):
full = len(self.items)==self.pool_size
return full
def empty(self):
self.items = []
def update(self, items):
for item in items:
if len(self.items) < self.pool_size:
# the pool is not full, so let's add this in
self.items.append(item)
else:
# the pool is full
# pop from the front
self.items.pop(0)
# add to the back
self.items.append(item)
return self.items
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Batched iterative farthest point sampling (torch).

    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples; if npoint > N, Gaussian noise is added to
            the running distances so the surplus picks are randomized
        include_ends: if True, the first sample is forced to index 0 and the
            second (when npoint >= 2) to index N-1
        deterministic: if True, start from index 0 instead of a random index
    Return:
        inds: sampled pointcloud index, [B, npoint] (torch.long)
    """
    device = xyz.device
    B, N, C = xyz.shape
    points = xyz.float()
    selected = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest already-selected sample.
    min_sq_dist = torch.full((B, N), 1e10, device=device)
    # Both modes go through the same torch.randint call (range [0, 1) when
    # deterministic), drawn on CPU and then moved, as in the original code.
    start_high = 1 if deterministic else N
    farthest = torch.randint(0, start_high, (B,), dtype=torch.long).to(device)
    rows = torch.arange(B, dtype=torch.long, device=device)
    jitter = npoint > N  # more samples requested than points available
    for step in range(npoint):
        if include_ends and step < 2:
            farthest = 0 if step == 0 else N - 1
        selected[:, step] = farthest
        centroid = points[rows, farthest, :].view(B, 1, C)
        sq_dist = ((points - centroid) ** 2).sum(-1)
        closer = sq_dist < min_sq_dist
        min_sq_dist[closer] = sq_dist[closer]
        farthest = min_sq_dist.max(-1)[1]
        if jitter:
            # if we need more samples, make them random
            min_sq_dist += torch.randn_like(min_sq_dist)
    return selected
def farthest_point_sample_py(xyz, npoint):
    """Farthest point sampling on a single point cloud with numpy.

    Args:
        xyz: array of shape (N, C).
        npoint: number of indices to draw; if larger than N, Gaussian noise is
            added to the running distances so the surplus picks are random.

    Returns:
        np.int32 array of shape (npoint,) with the selected row indices.
        The start index comes from numpy's global RNG.
    """
    N, C = xyz.shape
    picks = np.zeros(npoint, dtype=np.int32)
    # Squared distance from every point to its nearest already-picked point.
    nearest = np.full(N, 1e10)
    current = np.random.randint(0, N, dtype=np.int32)
    surplus = npoint > N
    for k in range(npoint):
        picks[k] = current
        anchor = xyz[current, :].reshape(1, C)
        gap = np.sum((xyz - anchor) ** 2, axis=-1)
        improved = gap < nearest
        nearest[improved] = gap[improved]
        current = nearest.argmax()
        if surplus:
            # if we need more samples, make them random
            nearest += np.random.randn(N)
    return picks
|