phylo-diffusion / ldm /analysis_utils.py
mridulk's picture
added few ldm files
642d5e2
raw
history blame
1.2 kB
import torch
EPS=1e-10
def get_CosineDistance_matrix(features):
if features.dim() >2:
features = features.reshape(features.shape[0], -1)
features_norm = features / (EPS + features.norm(dim=1)[:, None])
ans = torch.mm(features_norm, features_norm.transpose(0,1))
# We want distance, not similarity.
ans = torch.add(-ans, 1.)
return ans
def aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, specimen_distance_matrix, z_size, channels):
unique_sorted_class_names_according_to_class_indx = sorted(set(sorted_class_names_according_to_class_indx))
# species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), 256, 16, 16)
species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), channels, z_size, z_size)
for indx_i, i in enumerate(unique_sorted_class_names_according_to_class_indx):
class_i_indices = [idx for idx, element in enumerate(sorted_class_names_according_to_class_indx) if element == i]
species_dist_matrix[indx_i] = torch.mean(specimen_distance_matrix[class_i_indices,:], dim=0, keepdim=True)
return species_dist_matrix