""" Network utils | |
- poolfeat: aggregate superpixel features from pixel features | |
- upfeat: reconstruction pixel features from superpixel features | |
- quantize: quantization features given a codebook | |
""" | |
import torch | |
def poolfeat(input, prob, avg=True):
    """ A function to aggregate superpixel features from pixel features
    Args:
        input (tensor): input feature tensor.
        prob (tensor): one-hot superpixel segmentation.
        avg (bool, optional): average (True) or sum (False) the pixel features to get superpixel features.
    Returns:
        cluster_feat (tensor): the superpixel features.
    Shape:
        input: (B, C, H, W)
        prob: (B, N, H, W)
        cluster_feat: (B, N, C)
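    Example (a minimal sketch; sizes and values are illustrative, not from the original tests):
        >>> feat = torch.arange(8.).view(1, 2, 2, 2)    # B=1, C=2, H=W=2
        >>> prob = torch.zeros(1, 2, 2, 2)              # N=2 superpixels
        >>> prob[0, 0, 0] = 1.0; prob[0, 1, 1] = 1.0    # row 0 -> sp 0, row 1 -> sp 1
        >>> poolfeat(feat, prob).shape                  # averaged feature per superpixel
        torch.Size([1, 2, 2])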
""" | |
B, C, H, W = input.shape | |
B, N, H, W = prob.shape | |
prob_flat = prob.view(B, N, -1) | |
input_flat = input.view(B, C, -1) | |
cluster_feat = torch.matmul(prob_flat, input_flat.permute(0, 2, 1)) | |
if avg: | |
cluster_sum = torch.sum(prob_flat, dim = -1).view(B, N , 1) | |
cluster_feat = cluster_feat / (cluster_sum + 1e-8) | |
return cluster_feat | |
def upfeat(input, prob):
    """ A function to reconstruct pixel features from superpixel features
    Args:
        input (tensor): superpixel feature tensor.
        prob (tensor): one-hot superpixel segmentation.
    Returns:
        reconstr_feat (tensor): the pixel features.
    Shape:
        input: (B, N, C)
        prob: (B, N, H, W)
        reconstr_feat: (B, C, H, W)
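    Example (a minimal sketch; sizes and values are illustrative, not from the original tests):
        >>> sp_feat = torch.tensor([[[1., 2.], [3., 4.]]])  # B=1, N=2, C=2
        >>> prob = torch.zeros(1, 2, 2, 2)
        >>> prob[0, 0, 0] = 1.0; prob[0, 1, 1] = 1.0        # row 0 -> sp 0, row 1 -> sp 1
        >>> upfeat(sp_feat, prob).shape                     # each pixel gets its superpixel's feature
        torch.Size([1, 2, 2, 2])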
""" | |
B, N, H, W = prob.shape | |
prob_flat = prob.view(B, N, -1) | |
reconstr_feat = torch.matmul(prob_flat.permute(0, 2, 1), input) | |
reconstr_feat = reconstr_feat.view(B, H, W, -1).permute(0, 3, 1, 2) | |
return reconstr_feat | |
def quantize(z, embedding, beta=0.25):
    """
    Takes the encoder output z and maps each feature to the index of its
    closest embedding vector e_j, returning the quantized features.
    Args:
        z (tensor): features from the encoder network
        embedding (tensor): codebook
        beta (scalar, optional): commit loss weight
    Returns:
        z_q: quantized features
        loss: vq loss + commit loss * beta
        (min_encodings, min_encoding_indices, d): the quantization assignment
            as a one-hot vector, the assignment indices, and the distance matrix
    Shape:
        z: (B, N, C)
        embedding: (B, K, C)
        z_q: (B, N, C)
        min_encodings: (B, N, K)
        min_encoding_indices: (B, N, 1)
        d: (B, N, K)
    Note:
        Adapted from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
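    Example (a minimal sketch; sizes are illustrative, not from the original tests):
        >>> z = torch.zeros(1, 4, 8)
        >>> codebook = torch.randn(1, 16, 8)
        >>> z_q, loss, (one_hot, idx, d) = quantize(z, codebook)
        >>> z_q.shape, one_hot.shape, idx.shape
        (torch.Size([1, 4, 8]), torch.Size([1, 4, 16]), torch.Size([1, 4, 1]))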
""" | |
# B, 256, 32 | |
if embedding.shape[0] == 1: | |
d = torch.sum(z ** 2, dim=2, keepdim=True) + torch.sum(embedding**2, dim=2) - 2 * torch.matmul(z, embedding.transpose(1, 2)) | |
else: | |
ds = [] | |
for i in range(embedding.shape[0]): | |
z_i = z[i:i+1] | |
embedding_i = embedding[i:i+1] | |
ds.append(torch.sum(z_i ** 2, dim=2, keepdim=True) + torch.sum(embedding_i**2, dim=2) - 2 * torch.matmul(z_i, embedding_i.transpose(1, 2))) | |
d = torch.cat(ds) | |
    # find closest encodings
    min_encoding_indices = torch.argmin(d, dim=2).unsqueeze(2)  # (B, N, 1)
    # one-hot quantization assignment, (B, N, K)
    n_e = embedding.shape[1]
    min_encodings = torch.zeros(z.shape[0], z.shape[1], n_e).to(z)
    min_encodings.scatter_(2, min_encoding_indices, 1)
    # get quantized latent vectors by selecting codebook entries via the
    # one-hot matrix (TODO: this could be replaced by a direct index lookup,
    # e.g. gathering embedding rows at min_encoding_indices)
    z_q = torch.matmul(min_encodings, embedding).view(z.shape)
    # compute loss for embedding: vq term plus commitment term weighted by beta
    loss = torch.mean((z_q.detach() - z) ** 2) + beta * torch.mean((z_q - z.detach()) ** 2)
    # preserve gradients through the quantization (straight-through estimator)
    z_q = z + (z_q - z).detach()
    return z_q, loss, (min_encodings, min_encoding_indices, d)
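# Minimal smoke test: a sketch with illustrative sizes (B, C, H, W, N, K are
# assumptions for demonstration, not values from the original repository).
if __name__ == "__main__":
    B, C, H, W, N, K = 2, 16, 8, 8, 4, 32
    feat = torch.randn(B, C, H, W)
    # build a hard one-hot assignment of each pixel to a random superpixel
    assign = torch.randint(0, N, (B, 1, H, W))
    prob = torch.zeros(B, N, H, W).scatter_(1, assign, 1.0)
    sp_feat = poolfeat(feat, prob)   # (B, N, C)
    recon = upfeat(sp_feat, prob)    # (B, C, H, W)
    codebook = torch.randn(B, K, C)
    z_q, loss, (one_hot, idx, d) = quantize(sp_feat, codebook)
    print(sp_feat.shape, recon.shape, z_q.shape, float(loss))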