from kornia.contrib import connected_components
import torch
import pdb
import matplotlib.pyplot as plt
import time
# Commented-out variant with a fixed max_area and an explicit cap (max_ccs) on the number of components:
# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y
# def label_connected_component(labels, max_area=500, min_area=20, max_ccs=128, num_iterations=500):
#     assert len(labels.size()) == 2
#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]
#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations)  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())
#     # find all connected components (id, mask, area, valid)
#     # cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_id = torch.argsort(bincount)[-max_ccs:]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0)  # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]
#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)
#     return out
def reorder_int_labels(x):
    _, y = torch.unique(x, return_inverse=True)
    y -= y.min()
    return y
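# reorder_int_labels remaps arbitrary integer ids to the contiguous range 0..K-1, e.g.
#   reorder_int_labels(torch.tensor([3, 7, 7, 100]))  ->  tensor([0, 1, 1, 2])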
def label_connected_component(labels, min_area=20, topk=256):
    size = labels.size()
    assert len(size) == 2
    # a component covering the entire image (e.g. a uniform background) is rejected
    max_area = size[0] * size[1] - 1
    # per-label binary mask
    unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1], where ? is the number of unique ids
    binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]
    # label connected components
    # cc is an integer tensor; each unique id in cc represents a single connected component
    cc = connected_components(binary_masks.unsqueeze(1), num_iterations=500)  # [?, 1, H, W]
    # reorder the ids in cc so that the cc_area tensor below stays small
    cc = reorder_int_labels(cc)
    # area of each connected component
    cc_area = torch.bincount(cc.long().flatten().cpu()).to(labels.device)  # bincount on GPU is much slower
    num_cc = cc_area.shape[0]
    valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]
    if num_cc < topk:
        selected_cc = torch.arange(num_cc, device=labels.device)
    else:
        _, selected_cc = torch.topk(cc_area, k=topk)
    valid = valid[selected_cc]
    # collapse the 0th dimension: each pixel matches at most one connected component across it
    cc_mask = (cc == selected_cc.reshape(1, -1, 1, 1)).sum(0)  # [k, H, W], k = min(num_cc, topk)
    cc_mask = cc_mask * valid.reshape(-1, 1, 1)
    out = cc_mask.argmax(0)
    return out
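# Minimal usage sketch, assuming kornia's connected_components is available as imported above;
# the synthetic `labels` map below is only for illustration and is not part of the pipeline.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    labels = torch.zeros(64, 64, dtype=torch.long, device=device)
    labels[5:20, 5:20] = 1    # one square region with semantic label 1
    labels[30:60, 30:60] = 2  # a second, larger region with semantic label 2
    cc_labels = label_connected_component(labels, min_area=20, topk=256)
    # cc_labels is a [64, 64] integer map: each valid connected component gets its own id,
    # and pixels outside any valid component collapse to id 0
    print(cc_labels.shape, torch.unique(cc_labels))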
# Commented-out variant with a fixed max_area and no cap on the number of components:
# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y
# def label_connected_component(labels, max_area=500, min_area=20):
#     assert len(labels.size()) == 2
#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]
#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1))  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())
#     # find all connected components (id, mask, area, valid)
#     cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0)  # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]
#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)