import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
                                    normalize_kernel2d)


def l2_normalize(x):
    return F.normalize(x, p=2.0, dim=-1, eps=1e-6)


def reduce_max(x, dim, keepdim=True):
    return torch.max(x, dim=dim, keepdim=keepdim)[0]


def coordinate_ims(batch_size, seq_length, imsize):
    """Build a [B, (T,) H, W, 2] grid of (h, w) coordinates in [-1, 1]."""
    static = False
    if seq_length == 0:
        static = True
        seq_length = 1
    B = batch_size
    T = seq_length
    H, W = imsize
    ones = torch.ones([B, H, W, 1], dtype=torch.float32)
    h = torch.divide(torch.arange(H).to(ones), torch.tensor(H - 1, dtype=torch.float32))
    h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
    w = torch.divide(torch.arange(W).to(ones), torch.tensor(W - 1, dtype=torch.float32))
    w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
    h = torch.stack([h] * T, 1)
    w = torch.stack([w] * T, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
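
# Usage sketch (shapes illustrative, not from the original module): a static
# call (seq_length=0) returns a [B, H, W, 2] grid of (h, w) coordinates in [-1, 1].
#
#   grid = coordinate_ims(batch_size=2, seq_length=0, imsize=[4, 4])
#   assert list(grid.shape) == [2, 4, 4, 2]
#   assert grid[0, 0, 0].tolist() == [-1.0, -1.0]   # top-left corner
#   assert grid[0, -1, -1].tolist() == [1.0, 1.0]   # bottom-right corner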


def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
    """
    Compute the (optionally L2-normalized) dot product between queries [B,N,D]
    and keys [B,N_k,D], returning similarities of shape [B,N,N_k].
    """
    B, N, D_q = queries.size()
    _B, N_k, D_k = keys.size()
    assert D_q == D_k, (queries.shape, keys.shape)
    if normalize:
        queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
        keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)

    outputs = torch.matmul(queries, torch.transpose(keys, 1, 2))
    return outputs
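
# Minimal sketch (hypothetical tensors): with normalize=True the output is the
# cosine similarity between every query and every key, shaped [B, N, N_k].
#
#   q = torch.randn(2, 5, 8)
#   k = torch.randn(2, 7, 8)
#   sims = dot_product_attention(q, k)   # [2, 5, 7], values in [-1, 1]
#   assert list(sims.shape) == [2, 5, 7]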


def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
    """Sample num_points (h, w) indices per batch element from a [B,H,W] probability map."""
    B, H, W = probs.shape
    P = num_points
    N = H * W

    # Flatten and renormalize so the values form a valid distribution.
    probs = probs.reshape(B, N)
    probs = torch.maximum(probs + eps, torch.tensor(0., device=probs.device)) / (probs.sum(dim=-1, keepdim=True) + eps)
    dist = Categorical(probs=probs, validate_args=False)

    indices = dist.sample([P]).permute(1, 0).to(torch.int32)  # [B, P]

    # Convert flat indices to (h, w) pairs, clamped to the image bounds.
    indices_h = torch.minimum(torch.maximum(torch.div(indices, W, rounding_mode='floor'), torch.tensor(0)), torch.tensor(H - 1))
    indices_w = torch.minimum(torch.maximum(torch.fmod(indices, W), torch.tensor(0)), torch.tensor(W - 1))
    indices = torch.stack([indices_h, indices_w], dim=-1)  # [B, P, 2]
    return indices
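
# Sampling sketch (hypothetical input): a probability map that is zero except
# at one pixel returns that pixel's (h, w) index with near-certainty (eps adds
# a vanishing chance of other pixels).
#
#   probs = torch.zeros(1, 4, 4)
#   probs[0, 2, 3] = 1.0
#   inds = sample_image_inds_from_probs(probs, num_points=5)   # [1, 5, 2]
#   assert (inds == torch.tensor([2, 3], dtype=torch.int32)).all()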


def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
    """Compute spatial gradients of a [B,C,H,W] image with a kornia gradient kernel."""
    B, C, H, W = list(image.size())

    kernel = get_spatial_gradient_kernel2d(mode, order)
    if normalize_kernel:
        kernel = normalize_kernel2d(kernel)
    tmp_kernel = kernel.to(image).detach()
    tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
    kernel_flip = tmp_kernel.flip(-3)

    # Pad with replication so the output keeps the input's spatial size.
    padding = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
    out_channels = 3 if (order == 2) else 2
    padded_image = F.pad(image.reshape(B * C, 1, H, W), padding, 'replicate')[:, :, None]
    gradient_image = F.conv3d(padded_image, kernel_flip, padding=0).view(B, C, out_channels, H, W)
    return gradient_image
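
# Shape sketch (assumed input): for order=1 the result stacks the two Sobel
# responses (d/dy, d/dx) per input channel.
#
#   img = torch.rand(2, 3, 16, 16)
#   grad = get_gradient_image(img)   # mode='sobel', order=1 by default
#   assert list(grad.shape) == [2, 3, 2, 16, 16]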


def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
    """
    Sample num_points (h, w) coordinates from high-gradient border regions of the
    input image, normalized to [-1, 1] when normalized_coordinates=True.
    """
    B, C, H, W = list(image.size())
    if mask is not None:
        assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
    else:
        mask = torch.ones(size=(B, 1, H, W)).to(image)

    # Edge strength from first-order Sobel gradient magnitudes.
    gradient_image = get_gradient_image(image * mask, mode='sobel', order=1)
    gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
    if sum_edges:
        edges = gradient_magnitude.sum(1)
    else:
        edges = gradient_magnitude.max(1)[0]

    edges = edges * mask[:, 0]

    coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
    if normalized_coordinates:
        coordinates = coordinates.to(torch.float32)
        coordinates /= torch.tensor([H - 1, W - 1], dtype=torch.float32)[None, None].to(coordinates.device)
        coordinates = 2.0 * coordinates - 1.0
    return coordinates
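
# Usage sketch (hypothetical image): with normalized_coordinates=True each of
# the num_points sampled (h, w) pairs lies in [-1, 1].
#
#   img = torch.rand(2, 3, 32, 32)
#   coords = sample_coordinates_at_borders(img, num_points=8)
#   assert list(coords.shape) == [2, 8, 2]
#   assert coords.abs().max() <= 1.0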


def index_into_images(images, indices, channels_last=False):
    """
    Index into an image at P integer points to get its values.

    images: [B,C,H,W]
    indices: [B,P,2]
    """
    assert indices.size(-1) == 2, indices.size()
    if channels_last:
        images = images.permute(0, 3, 1, 2)
    B, C, H, W = images.shape
    _, P, _ = indices.shape
    inds_h, inds_w = list(indices.to(torch.long).permute(2, 0, 1))
    inds_b = torch.arange(B, dtype=torch.long).unsqueeze(-1).expand(-1, P).to(indices)
    inds = torch.stack([inds_b, inds_h, inds_w], 0).to(torch.long)
    values = images.permute(0, 2, 3, 1)[list(inds)]  # [B, P, C]
    return values
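
# Indexing sketch (hypothetical tensors): gathers the C-dim channel vector at
# each integer (h, w) location, returning [B, P, C].
#
#   imgs = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)
#   pts = torch.tensor([[[0, 0], [3, 3]], [[1, 2], [2, 1]]])   # [2, 2, 2]
#   vals = index_into_images(imgs, pts)
#   assert list(vals.shape) == [2, 2, 3]
#   assert vals[0, 0].tolist() == [0.0, 16.0, 32.0]   # channels at (0, 0)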


def soft_index(images, indices, scale_by_imsize=True):
    """
    Bilinearly sample a [B,C,H,W] image at P (h, w) positions, treating the
    positions as normalized [-1, 1] coordinates when scale_by_imsize=True.
    """
    assert indices.shape[-1] == 2, indices.shape
    B, C, H, W = images.shape
    _, P, _ = indices.shape

    h_inds, w_inds = list(indices.permute(2, 0, 1))
    if scale_by_imsize:
        h_inds = (h_inds + 1.0) * torch.tensor(H).to(h_inds) * 0.5
        w_inds = (w_inds + 1.0) * torch.tensor(W).to(w_inds) * 0.5

    h_inds = torch.maximum(torch.minimum(h_inds, torch.tensor(H - 1).to(h_inds)), torch.tensor(0.).to(h_inds))
    w_inds = torch.maximum(torch.minimum(w_inds, torch.tensor(W - 1).to(w_inds)), torch.tensor(0.).to(w_inds))

    h_floor = torch.floor(h_inds)
    w_floor = torch.floor(w_inds)
    h_ceil = torch.ceil(h_inds)
    w_ceil = torch.ceil(w_inds)

    # Bilinear interpolation weights for the four surrounding pixels.
    bot_right_weight = (h_inds - h_floor) * (w_inds - w_floor)
    bot_left_weight = (h_inds - h_floor) * (w_ceil - w_inds)
    top_right_weight = (h_ceil - h_inds) * (w_inds - w_floor)
    top_left_weight = (h_ceil - h_inds) * (w_ceil - w_inds)

    # Computed but currently unused: the weight sum drops to 0 at exact-integer
    # coordinates, where floor == ceil.
    in_bounds = (bot_right_weight + bot_left_weight + top_right_weight + top_left_weight) > 0.95
    in_bounds = in_bounds.to(torch.float32)

    top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
    top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
    bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
    bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))

    im_vals = top_left_vals * top_left_weight[..., None]
    im_vals += top_right_vals * top_right_weight[..., None]
    im_vals += bot_left_vals * bot_left_weight[..., None]
    im_vals += bot_right_vals * bot_right_weight[..., None]

    im_vals = im_vals.view(B, P, C)

    return im_vals
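
# Bilinear-read sketch (hypothetical values): with scale_by_imsize=True the
# positions are interpreted as normalized (-1, 1) coordinates.
#
#   imgs = torch.rand(1, 3, 8, 8)
#   pos = torch.full((1, 4, 2), 0.1)   # four agents near the image center
#   vals = soft_index(imgs, pos)       # [1, 4, 3]
#   assert list(vals.shape) == [1, 4, 3]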


def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
    """
    Compute how well "fit" each agent is for the position it's at on the plateau,
    according to its "phenotype".

    positions: [B,P,2]
    plateau: [B,H,W,Q]
    phenotypes: [B,P,D] or None
    availability: [B,H,W,A] or None
    """
    B, H, W, Q = plateau.shape
    P = positions.shape[1]
    if phenotypes is None:
        phenotypes = soft_index(plateau.permute(0, 3, 1, 2), positions)

    if availability is not None:
        assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
        A = availability.size(-1)
        assert P % A == 0, (P, A)
        S = P // A
        plateau = availability[..., None] * plateau[..., None, :]
        plateau = plateau.view(B, H, W, A * Q)

    plateau_values = soft_index(plateau.permute(0, 3, 1, 2), positions, scale_by_imsize=True)
    if noise > 0:
        plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)

    if availability is not None:
        plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
        inds = torch.tile(torch.eye(A)[None].expand(B, -1, -1), (1, S, 1))[..., None]
        plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2)
    else:
        plateau_values = l2_normalize(plateau_values)

    # Fitness is the cosine similarity between each phenotype and the local plateau vector.
    compatibility = torch.sum(
        l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True)

    return compatibility
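
# Fitness sketch (hypothetical tensors): compatibility is the cosine similarity
# between each agent's phenotype and the plateau vector under its position.
#
#   plateau = torch.rand(2, 16, 16, 8)
#   positions = torch.rand(2, 4, 2) * 2.0 - 1.0   # normalized coordinates
#   phenos = torch.randn(2, 4, 8)
#   fit = compute_compatibility(positions, plateau, phenos, noise=0.0)
#   assert list(fit.shape) == [2, 4, 1]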


def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
    """Compute the pairwise [B,P,P] IoU matrix between two sets of [B,N,P] masks."""
    B, N, P = masks.shape
    if masks_target is None:
        masks_target = masks
    if mask_thresh is not None:
        masks = (masks > mask_thresh).to(torch.float32)
        masks_target = (masks_target > mask_thresh).to(torch.float32)

    overlaps = masks[..., None] * masks_target[..., None, :]
    I = overlaps.sum(dim=1)
    U = torch.maximum(masks[..., None], masks_target[..., None, :]).sum(dim=1)
    iou = I / torch.maximum(U, torch.tensor(eps, dtype=torch.float32))

    return iou
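
# IoU sketch (hypothetical masks): identical masks give IoU 1 on the diagonal,
# disjoint masks give 0 off the diagonal.
#
#   m = torch.zeros(1, 16, 2)   # two masks over N=16 pixels
#   m[0, :8, 0] = 1.0
#   m[0, 8:, 1] = 1.0
#   iou = compute_pairwise_overlaps(m)   # [1, 2, 2]
#   assert iou[0, 0, 0] == 1.0 and iou[0, 0, 1] == 0.0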


def compete_agents(masks, fitnesses, alive,
                   mask_thresh=0.5, compete_thresh=0.2,
                   sticky_winners=True):
    """
    Kill off agents (i.e. mask channels that are "alive") based on mask overlap
    and the fitness of each agent.

    args:
        masks: [B,N,P]
        fitnesses: [B,P,1]
        alive: [B,P,1]

    returns:
        still_alive: [B,P,1]
    """
    B, N, P = masks.shape
    assert list(alive.shape) == [B, P, 1], alive.shape
    assert list(fitnesses.shape) == [B, P, 1], fitnesses.shape

    # Agents whose masks overlap too much are in dispute.
    overlaps = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)
    disputes = overlaps > compete_thresh

    # An agent never disputes with itself.
    disputes = torch.logical_and(
        disputes, torch.logical_not(
            torch.eye(P, dtype=torch.bool, device=disputes.device).unsqueeze(0).expand(B, -1, -1)))

    # The less fit agent in each dispute is killed.
    killed = torch.logical_and(disputes, fitnesses < torch.transpose(fitnesses, 1, 2))

    if sticky_winners:
        winners = (alive > 0.5)
        losers = torch.logical_not(winners)

        # Currently alive agents cannot be killed by currently dead ones...
        winners_vs_losers = torch.logical_and(winners, torch.transpose(losers, 1, 2))
        killed = torch.logical_and(killed, torch.logical_not(winners_vs_losers))

        # ...and dead agents that dispute with living ones stay dead.
        losers_vs_winners = torch.logical_and(losers, torch.transpose(winners, 1, 2))
        losers_vs_winners_disputes = torch.logical_and(losers_vs_winners, disputes)
        killed = torch.logical_or(killed, losers_vs_winners_disputes)

    # An agent dies if any other agent kills it.
    killed = torch.any(killed, dim=2, keepdim=True)
    alive = torch.logical_not(killed).to(torch.float32)

    return alive
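
# Competition sketch (hypothetical setup): two agents claiming the same pixels
# are in dispute, and the lower-fitness one is killed.
#
#   masks = torch.ones(1, 16, 2)            # complete overlap between 2 agents
#   fit = torch.tensor([[[1.0], [0.5]]])    # agent 0 is fitter
#   alive = torch.ones(1, 2, 1)
#   assert compete_agents(masks, fit, alive).tolist() == [[[1.0], [0.0]]]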


def compute_distance_weighted_vectors(vector_map, positions, mask=None, beta=1.0, eps=1e-8):
    """
    Compute vectors whose values are a weighted mean of vector_map, where the
    weights fall off with distance from each position.
    """
    B, H, W, D = vector_map.shape
    assert positions.size(-1) == 2, positions.size()
    B, P, _ = positions.shape
    N = H * W

    if mask is None:
        mask = torch.ones_like(vector_map[..., 0:1])
    else:
        assert list(mask.shape) == [B, H, W, 1]

    hw_grid = coordinate_ims(B, 0, [H, W]).view(B, N, 2).to(vector_map.device)
    delta_positions = hw_grid[:, None] - positions[:, :, None]
    distances = torch.sqrt(delta_positions[..., 0] ** 2 + delta_positions[..., 1] ** 2 + eps)

    # 2*sqrt(2) is the maximum distance between two points in [-1, 1]^2 coordinates.
    inv_distances = (2.0 * np.sqrt(2.0)) / (distances + eps)
    inv_distances = F.softmax(beta * inv_distances * mask.view(B, 1, N), dim=-1)
    distance_weighted_vectors = torch.sum(
        vector_map.view(B, 1, N, D) * inv_distances[..., None], dim=2, keepdim=False)
    return distance_weighted_vectors
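
# Homing sketch (hypothetical tensors): each position receives a softmax
# distance-weighted average of the vector map.
#
#   vmap = torch.rand(1, 8, 8, 4)
#   pos = torch.rand(1, 3, 2) * 2.0 - 1.0
#   vecs = compute_distance_weighted_vectors(vmap, pos)
#   assert list(vecs.shape) == [1, 3, 4]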


def masks_from_phenotypes(plateau, phenotypes, normalize=True):
    """Compute a [B,N,P] soft mask per phenotype by attention against the plateau."""
    B, H, W, Q = plateau.shape
    N = H * W
    masks = dot_product_attention(
        queries=plateau.view(B, N, Q),
        keys=phenotypes,
        normalize=normalize)
    masks = F.relu(masks)
    return masks
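
# Masking sketch (hypothetical tensors): one soft mask per phenotype, computed
# as a rectified (cosine) similarity against every plateau location.
#
#   plateau = torch.rand(2, 8, 8, 6)
#   phenos = torch.randn(2, 5, 6)
#   masks = masks_from_phenotypes(plateau, phenos)   # [2, 64, 5]
#   assert list(masks.shape) == [2, 64, 5]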


class Competition(nn.Module):
    """Compete M agents over a plateau map to carve it into distinct masks."""

    def __init__(
            self,
            size=None,
            num_masks=16,
            num_competition_rounds=5,
            mask_beta=10.0,
            reduce_func=reduce_max,
            stop_gradient=True,
            stop_gradient_phenotypes=True,
            normalization_func=l2_normalize,
            sum_edges=True,
            mask_thresh=0.5,
            compete_thresh=0.2,
            sticky_winners=True,
            selection_strength=100.0,
            homing_strength=10.0,
            mask_dead_segments=True
    ):
        super().__init__()
        self.num_masks = self.M = num_masks
        self.num_competition_rounds = num_competition_rounds
        self.mask_beta = mask_beta
        self.reduce_func = reduce_func
        self.normalization_func = normalization_func

        # Optionally stop gradients through sampled values and/or phenotypes.
        self.sg_func = lambda x: (x.detach() if stop_gradient else x)
        self.sg_phenotypes_func = lambda x: (x.detach() if stop_gradient_phenotypes else x)

        # How to pool gradient channels when sampling agents at plateau borders.
        self.sum_edges = sum_edges

        # Competition parameters.
        self.mask_thresh = mask_thresh
        self.compete_thresh = compete_thresh
        self.sticky_winners = sticky_winners
        self.selection_strength = selection_strength
        self.homing_strength = homing_strength
        self.mask_dead_segments = mask_dead_segments

        # Shape attributes, filled in when the plateau input is processed.
        self.B = self.T = self.BT = self.N = self.Q = None
        self.size = size
        if self.size:
            assert len(self.size) == 2, self.size

    def reshape_batch_time(self, x, merge=True):
        """Merge ([B,T,...] -> [BT,...]) or split ([BT,...] -> [B,T,...]) batch and time dims."""
        if merge:
            self.is_temporal = True
            B, T = x.size()[0:2]
            if self.B:
                assert (B == self.B), (B, self.B)
            else:
                self.B = B

            if self.T:
                assert (T == self.T), (T, self.T)
            else:
                self.T = T

            assert B * T == (self.B * self.T), (B * T, self.B * self.T)
            if self.BT is None:
                self.BT = self.B * self.T

            return torch.reshape(x, [self.BT] + list(x.size())[2:])

        else:
            BT = x.size()[0]
            assert self.B and self.T, (self.B, self.T)
            if self.BT is not None:
                assert BT == self.BT, (BT, self.BT)
            else:
                self.BT = BT

            return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])

    def process_plateau_input(self, plateau):
        """Reshape the input plateau to [BT,H,W,Q], recording batch/time/space dims."""
        shape = plateau.size()
        if len(shape) == 5:  # [B,T,H,W,Q]
            self.is_temporal = True
            self.B, self.T, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.BT = self.B * self.T
            plateau = self.reshape_batch_time(plateau)
        elif (len(shape) == 4) and (self.size is None):  # [B,H,W,Q]
            self.is_temporal = False
            self.B, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.T = 1
            self.BT = self.B * self.T
        elif (len(shape) == 4) and (self.size is not None):  # [B,T,N,Q]
            self.is_temporal = True
            self.B, self.T, self.N, self.Q = shape
            self.BT = self.B * self.T
            self.H, self.W = self.size
            plateau = self.reshape_batch_time(plateau)
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        elif len(shape) == 3:  # [B,N,Q]
            assert self.size is not None, \
                "You need to specify an image size to reshape the plateau of shape %s" % shape
            self.is_temporal = False
            self.B, self.N, self.Q = shape
            self.T = 1
            self.BT = self.B
            self.H, self.W = self.size
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        else:
            raise ValueError("input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % shape)

        return plateau

    def forward(self,
                plateau,
                agents=None,
                alive=None,
                phenotypes=None,
                compete=True,
                update_pointers=True,
                yoke_phenotypes_to_agents=True,
                noise=0.1
                ):
        """
        Find the uniform regions within the plateau map
        by competition between visual "indices."

        args:
            plateau: [B,[T],H,W,Q] feature map with smooth "plateaus"

        returns:
            masks: [B,[T],H,W,M] <float> one mask in each of M channels
            agents: [B,[T],M,2] <float> positions of agents in normalized coordinates
            alive: [B,[T],M] <float> binary vector indicating which masks are valid
            phenotypes: [B,[T],M,Q]
            unharvested: [B,[T],H,W] <float> map of regions that weren't covered
        """
        plateau = self.process_plateau_input(plateau)
        plateau = self.normalization_func(plateau)

        # Initialize agent positions at high-gradient borders of the plateau.
        if agents is None:
            agents = sample_coordinates_at_borders(
                plateau.permute(0, 3, 1, 2),
                num_points=self.M,
                mask=None,
                sum_edges=self.sum_edges)
        else:
            if self.is_temporal:
                agents = agents.view(self.BT, *agents.shape[2:])

        # Initialize phenotypes by reading the plateau at each agent's position.
        if phenotypes is None:
            phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau.permute(0, 3, 1, 2),
                               agents, scale_by_imsize=True)))
        elif self.is_temporal:
            phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:])

        # Initialize each agent's alive state and fitness.
        if alive is None:
            alive = torch.ones_like(agents[..., -1:])
            fitnesses = compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise)
            alive_mask = None
        else:
            if self.is_temporal:
                alive = alive.view(self.BT, *alive.shape[2:])
            alive_mask = (alive > 0.5).float()
            fitnesses = alive_mask + compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) * (1 - alive_mask)

        alive_t = torch.transpose(alive, 1, 2)

        # Initial masks and the map of not-yet-claimed ("unharvested") pixels.
        masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

        unharvested = torch.minimum(self.reduce_func(masks_pred, dim=-1, keepdim=True), torch.tensor(1.0))
        unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

        # Replace dead agents with fresh ones sampled from unharvested regions.
        if alive_mask is not None:
            new_agents = sample_coordinates_at_borders(
                plateau.permute(0, 3, 1, 2), num_points=self.M,
                mask=unharvested.permute(0, 3, 1, 2),
                sum_edges=self.sum_edges)
            agents = agents * alive_mask + new_agents * (1.0 - alive_mask)

            new_phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau.permute(0, 3, 1, 2),
                               new_agents, scale_by_imsize=True)))
            phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask)

        for r in range(self.num_competition_rounds):

            alive_t = torch.transpose(alive, 1, 2)

            # Agents outside the normalized image bounds cannot claim pixels.
            in_bounds = torch.all(
                torch.logical_and(agents < 1.0, agents > -1.0),
                dim=-1, keepdim=True)

            # Re-evaluate each agent's fitness at its current position.
            if update_pointers and compete:
                fitnesses = compute_compatibility(
                    positions=agents,
                    plateau=plateau,
                    phenotypes=phenotypes,
                    availability=None,
                    noise=noise
                )
                fitnesses *= in_bounds.to(fitnesses)

                # A small index-dependent penalty breaks exact ties deterministically.
                fitnesses -= 0.001 * torch.arange(self.M, dtype=torch.float32)[None, :, None].expand(self.BT, -1, -1).to(fitnesses.device)

            # Optionally tie masks to the plateau region under each agent.
            if yoke_phenotypes_to_agents:
                occupied_regions = self.sg_phenotypes_func(
                    soft_index(plateau.permute(0, 3, 1, 2), agents, scale_by_imsize=True))
                masks_pred = masks_from_phenotypes(plateau, occupied_regions, normalize=True)

            # Kill the less fit agent in each pair of disputing (overlapping) masks.
            if compete:
                alive = compete_agents(masks_pred, fitnesses, alive,
                                       mask_thresh=self.mask_thresh,
                                       compete_thresh=self.compete_thresh,
                                       sticky_winners=self.sticky_winners)

                alive *= in_bounds.to(alive)
                alive_t = torch.transpose(alive, 1, 2)

            if not yoke_phenotypes_to_agents:
                masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

            # Update the map of pixels not claimed by any living agent.
            unharvested = torch.minimum(self.reduce_func(masks_pred * alive_t, dim=-1, keepdim=True),
                                        torch.tensor(1.0, dtype=torch.float32))
            unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

            if update_pointers:
                # Winners pull their phenotypes toward the mean plateau vector of their mask.
                if self.mask_thresh is not None:
                    winner_phenotypes = (masks_pred[..., None] > self.mask_thresh).to(plateau)
                    if self.selection_strength > 0:
                        winner_phenotypes = winner_phenotypes * plateau.view(self.BT, self.N, 1, self.Q)
                        winner_phenotypes = self.normalization_func(winner_phenotypes.mean(dim=1))
                        phenotypes += (alive * winner_phenotypes) * self.selection_strength

                # Losers respawn at borders of unharvested territory...
                alive_mask = (alive > 0.5).to(torch.float32)
                loser_agents = sample_coordinates_at_borders(
                    plateau.permute(0, 3, 1, 2), num_points=self.M,
                    mask=unharvested.permute(0, 3, 1, 2),
                    sum_edges=self.sum_edges)
                agents = agents * alive_mask + loser_agents * (1.0 - alive_mask)

                # ...and home toward distance-weighted unharvested plateau vectors.
                loser_phenotypes = self.normalization_func(
                    compute_distance_weighted_vectors(plateau, agents, mask=unharvested, beta=self.homing_strength))
                phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes
                phenotypes = self.normalization_func(phenotypes)

        # Sharpen the masks with a softmax, suppressing dead agents' channels.
        if self.mask_beta is not None:
            masks_pred = F.softmax(
                self.mask_beta * masks_pred * alive_t -
                self.mask_beta * (1.0 - alive_t), dim=-1)
        if self.mask_dead_segments:
            masks_pred *= alive_t

        masks_pred = masks_pred.view(self.BT, self.H, self.W, self.M)
        if self.is_temporal:
            masks_pred = self.reshape_batch_time(masks_pred, merge=False)
            agents = self.reshape_batch_time(agents, merge=False)
            alive = self.reshape_batch_time(alive, merge=False)
            phenotypes = self.reshape_batch_time(phenotypes, merge=False)
            unharvested = self.reshape_batch_time(unharvested, merge=False)

        return (masks_pred, agents, alive, phenotypes, unharvested)

    @staticmethod
    def masks_to_segments(masks):
        """Convert [...,M] soft masks to an integer segment map by argmax."""
        return masks.argmax(-1)

    @staticmethod
    def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True):
        """Replace each mask's plateau region with that mask's mean phenotype."""
        B, M, _ = alive.shape
        Q = plateau.shape[-1]
        if flatten_masks:
            masks = F.one_hot((alive[..., None, None, :, 0] * masks).argmax(-1), num_classes=M).float()

        flat_plateau = torch.zeros_like(plateau)
        phenotypes = torch.zeros((B, M, Q), device=plateau.device).float()
        for b in range(B):
            m_inds = torch.where(alive[b, :, 0])[0]
            masks_b = masks[b, ..., m_inds]
            num_px = masks_b.sum((0, 1)).clamp(min=1)[:, None]
            phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px
            flat_plateau_b = (masks_b[..., None] * phenos_b[None, None]).sum(-2)

            phenotypes[b, m_inds, :] = phenos_b
            flat_plateau[b] = flat_plateau_b

        _norm = lambda x: F.normalize(x, p=2, dim=-1)
        return (_norm(flat_plateau), _norm(phenotypes))

    @staticmethod
    def plot_agents(agents, alive, size=[128, 128]):
        """Rasterize living agents into a [B,H,W] index map (-1 = background)."""
        B, M, _ = alive.shape
        agent_map = -1 * torch.ones((B, *size), device=alive.device, dtype=torch.long)
        for b in range(B):
            inds = torch.where(alive[b, :, 0])
            for i in inds[0]:
                pos = agents[b, i] * 0.5 + 0.5
                pos = pos * torch.tensor(size, device=pos.device, dtype=pos.dtype)
                # Clamp so floor/ceil indices stay within the map.
                pos = torch.maximum(pos, torch.zeros_like(pos))
                pos = torch.minimum(pos, torch.tensor(size, device=pos.device, dtype=pos.dtype) - 1)
                hmin, wmin = list(torch.floor(pos).long())
                hmax, wmax = list(torch.ceil(pos).long())
                agent_map[b, [hmin, hmin, hmax, hmax], [wmin, wmax, wmin, wmax]] = i

        return agent_map


if __name__ == '__main__':

    Comp = Competition(num_masks=32, num_competition_rounds=5)

    # Build a toy [1,32,32,3] plateau with three vertical color bands.
    left = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([1., 0.2, 0.])
    middle = torch.ones(size=(32, 16)).unsqueeze(-1) * torch.tensor([0., 1., 0.2])
    right = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([0.1, 0., 1.])
    plateau = torch.cat([left, middle, right], dim=-2).unsqueeze(0)

    masks, agents, alive, phenotypes, unharvested = Comp(plateau)
    mask_inds = np.where(alive[0, :, 0].numpy())[0]
    print(np.argmax(masks[0, ...], axis=-1))
    for ind in mask_inds:
        print("num pixels in mask %d ---> %d" % (ind, (np.argmax(masks[0], -1) == ind).sum()))