Spaces:
Running
Running
import torch | |
# Small constant added to each row norm so an all-zero row does not divide by zero.
EPS = 1e-10


def get_CosineDistance_matrix(features):
    """Return the pairwise cosine-distance matrix for a batch of features.

    Inputs with more than two dimensions are first flattened to
    ``(features.shape[0], -1)``, so each entry along the first axis becomes
    one row vector. Every row is divided by ``(its L2 norm + EPS)``, and the
    result is ``1 - rows @ rows.T``.

    Args:
        features: Tensor whose first axis indexes the samples (at least 2-D).

    Returns:
        Square tensor whose ``[i, j]`` entry is ``1 - cos_sim(row_i, row_j)``.
    """
    flat = features if features.dim() <= 2 else features.reshape(features.shape[0], -1)
    row_norms = flat.norm(dim=1).unsqueeze(1)
    unit_rows = flat / (row_norms + EPS)
    similarity = unit_rows @ unit_rows.t()
    # Callers want a distance rather than a similarity.
    return 1. - similarity
def aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, specimen_distance_matrix, z_size, channels):
    """Average specimen-level tensors into one tensor per class name.

    Rows (dim 0) of ``specimen_distance_matrix`` are grouped by the class
    name found at the same position in
    ``sorted_class_names_according_to_class_indx``. Each group is averaged
    along dim 0. Output row ``k`` corresponds to the ``k``-th distinct class
    name in sorted order.

    Args:
        sorted_class_names_according_to_class_indx: Sequence of class names,
            one per row of ``specimen_distance_matrix``. Names must be
            hashable and mutually orderable.
        specimen_distance_matrix: Tensor indexed along dim 0 by specimen.
            Each group mean is written into a ``(channels, z_size, z_size)``
            slot, so the trailing dims must be compatible with that shape.
            NOTE(review): despite the name this is not a square distance
            matrix here; confirm the expected layout with callers.
        z_size: Spatial size of each output slot.
        channels: Channel count of each output slot.

    Returns:
        Tensor of shape ``(num_unique_names, channels, z_size, z_size)``.
        It is allocated with ``torch.zeros``'s default dtype and device, not
        necessarily those of ``specimen_distance_matrix``.
    """
    # Group specimen indices by class name in one pass. The original code
    # rescanned the whole name list once per class. Indices stay in
    # ascending order, so each group's mean is unchanged.
    indices_by_class = {}
    for idx, name in enumerate(sorted_class_names_according_to_class_indx):
        indices_by_class.setdefault(name, []).append(idx)

    unique_class_names = sorted(indices_by_class)
    species_dist_matrix = torch.zeros(len(unique_class_names), channels, z_size, z_size)
    for species_idx, name in enumerate(unique_class_names):
        # No keepdim: the mean already has the slot's shape, so the
        # assignment does not depend on a leading size-1 dim being dropped.
        species_dist_matrix[species_idx] = torch.mean(
            specimen_distance_matrix[indices_by_class[name], :], dim=0
        )
    return species_dist_matrix