mridulk committed
Commit 17191f4 · 1 Parent(s): 642d5e2

added data

ldm/data/__init__.py ADDED
File without changes
ldm/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (141 Bytes)
ldm/data/__pycache__/base.cpython-38.pyc ADDED
Binary file (3.62 kB)
ldm/data/__pycache__/constants.cpython-38.pyc ADDED
Binary file (1.24 kB)
ldm/data/__pycache__/custom.cpython-38.pyc ADDED
Binary file (3.16 kB)
ldm/data/__pycache__/custom_cub.cpython-38.pyc ADDED
Binary file (3.15 kB)
ldm/data/__pycache__/i2sb_dataloader.cpython-38.pyc ADDED
Binary file (4.52 kB)
ldm/data/__pycache__/imagenet.cpython-38.pyc ADDED
Binary file (14.5 kB)
ldm/data/__pycache__/phylogeny.cpython-38.pyc ADDED
Binary file (8.57 kB)
ldm/data/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.61 kB)
 
ldm/data/base.py ADDED
@@ -0,0 +1,85 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
import numpy as np
import albumentations
from PIL import Image


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass


class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, horizontalflip=False, random_contrast=False,
                 shiftrotate=False, labels=None, unique_skipped_labels=[]):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        # Optionally drop every example whose class is in unique_skipped_labels.
        self.labels_without_skipped = None
        if len(unique_skipped_labels) != 0:
            self.labels_without_skipped = dict()
            for i in self.labels.keys():
                self.labels_without_skipped[i] = [a for indx, a in enumerate(labels[i]) if labels['class'][indx] not in unique_skipped_labels]
            self._length = len(self.labels_without_skipped['class'])

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            l = [self.rescaler]
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            l.append(self.cropper)
            if horizontalflip:
                l.append(albumentations.HorizontalFlip(p=0.2))
            if shiftrotate:
                l.append(albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=45, border_mode=0,
                                                         value=(int(0.485*255), int(0.456*255), int(0.406*255)), p=0.3))
            if random_contrast:
                l.append(albumentations.RandomBrightnessContrast(p=0.3))
            self.preprocessor = albumentations.Compose(l)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def __getitem__(self, i):
        labels = self.labels if self.labels_without_skipped is None else self.labels_without_skipped
        example = dict()
        example["image"] = self.preprocess_image(labels["file_path_"][i])
        for k in labels:
            example[k] = labels[k][i]
        return example
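Usage note (not part of the diff): a minimal sketch of driving ImagePaths directly, assuming a couple of RGB images on disk; the paths and the 'class' labels below are placeholders.

from ldm.data.base import ImagePaths

# hypothetical image files; class ids are supplied explicitly here
paths = ["/tmp/images/fish_001.png", "/tmp/images/fish_002.png"]
dataset = ImagePaths(paths=paths, size=256, random_crop=False, labels={"class": [0, 1]})
example = dataset[0]
print(example["image"].shape, example["class"])  # (256, 256, 3) in [-1, 1], plus the label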
ldm/data/constants.py ADDED
@@ -0,0 +1,32 @@
DISENTANGLER_DECODER_OUTPUT = 'output'
DISENTANGLER_ENCODER_INPUT = 'in'
QUANTIZED_PHYLO_OUTPUT = 'zq_phylo'
DISENTANGLER_CLASS_OUTPUT = 'class'
QUANTIZED_PHYLO_NONATTRIBUTE_OUTPUT = 'zq_phylo_nonattribute'
DISENTANGLER_NON_ATTRIBUTE_TO_ATTRIBUTE_OUTPUT = 'nonattribate_to_attribute'
DISENTANGLER_NON_ATTRIBUTE_CLASS_OUTPUT = 'adversarial_classifier_output'
DISENTANGLER_ADV_MAPPING_OUTPUT = 'adversarial_mapping_output'
DISENTANGLER_ADV_LEARNING_OUTPUT = 'adversarial_learning_output'
NON_CLASS_TENSORS = [DISENTANGLER_ADV_LEARNING_OUTPUT, DISENTANGLER_ADV_MAPPING_OUTPUT, DISENTANGLER_ENCODER_INPUT, DISENTANGLER_DECODER_OUTPUT, QUANTIZED_PHYLO_OUTPUT, DISENTANGLER_NON_ATTRIBUTE_TO_ATTRIBUTE_OUTPUT, QUANTIZED_PHYLO_NONATTRIBUTE_OUTPUT, DISENTANGLER_NON_ATTRIBUTE_CLASS_OUTPUT]

CLASS_TENSORS = [DISENTANGLER_CLASS_OUTPUT]

DATASET_CLASSNAME = 'class_name'

PHYLOCONFIG_KEY = "phylomodel_params"
LRFACTOR_KEY = "lr_factor"
LRCYCLE = "lr_cycle"
DISENTANGLERTYPE_KEY = 'disentangler_type'
COMPLETE_CKPT_KEY = "posttraining_ckpt"

HISTOGRAMS_FOLDER = 'code_histograms'
HISTOGRAMS_FILE = "histograms.pkl"

DISENTANGLER_PHYLO_LOSS = "/disentangler_phylo_loss"
TRANSFORMER_LOSS = "/loss"
RECLOSS = "/rec_loss"
BASERECLOSS = "/base_true_rec_loss"

TEST_DIR = "results_summary"

TSNE_FOLDER = 'tsne'
ldm/data/custom.py ADDED
@@ -0,0 +1,62 @@
# based on https://github.com/CompVis/taming-transformers

import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False,
                 shiftrotate=False, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = sorted(f.read().splitlines())

        labels = None
        if add_labels:
            # The class name is taken from the parent directory of each image path.
            labels_per_file = list(map(lambda path: path.split('/')[-2], paths))
            labels_set = sorted(list(set(labels_per_file)))
            self.labels_to_idx = {label_name: i for i, label_name in enumerate(labels_set)}

            if class_to_node:
                with open(class_to_node, 'rb') as pickle_file:
                    class_to_node_dict = pickle.load(pickle_file)
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file,
                    'class_to_node': [class_to_node_dict[label_name] for label_name in labels_per_file]
                }
            else:
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file
                }

            self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip,
                               random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)


class CustomTest(CustomTrain):
    def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__(size, test_images_list_file, add_labels=add_labels,
                         unique_skipped_labels=unique_skipped_labels, class_to_node=class_to_node)
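Usage note (not part of the diff): a minimal sketch of instantiating CustomTrain/CustomTest from plain-text file lists (one absolute image path per line, class name taken from the parent folder); the list-file names below are placeholders.

from ldm.data.custom import CustomTrain, CustomTest

train_data = CustomTrain(size=256, training_images_list_file="data/train_images.txt",
                         add_labels=True, horizontalflip=True)
test_data = CustomTest(size=256, test_images_list_file="data/test_images.txt", add_labels=True)
sample = train_data[0]
print(sample["image"].shape, sample["class"], sample["class_name"])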
ldm/data/custom_cub.py ADDED
@@ -0,0 +1,62 @@
# based on https://github.com/CompVis/taming-transformers

import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False,
                 shiftrotate=False, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = sorted(f.read().splitlines())

        labels = None
        if add_labels:
            labels_per_file = list(map(lambda path: path.split('/')[-2], paths))
            # labels_per_file = [i.split('.')[1].replace('_', ' ') for i in labels_per_file]
            labels_set = sorted(list(set(labels_per_file)))
            self.labels_to_idx = {label_name: i for i, label_name in enumerate(labels_set)}

            if class_to_node:
                with open(class_to_node, 'rb') as pickle_file:
                    class_to_node_dict = pickle.load(pickle_file)
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: [class_to_node_dict[label_name] for label_name in labels_per_file],
                    # 'class_to_node': [class_to_node_dict[label_name] for label_name in labels_per_file]
                }
            else:
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file
                }
            self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip,
                               random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)


class CustomTest(CustomTrain):
    def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__(size, test_images_list_file, add_labels=add_labels,
                         unique_skipped_labels=unique_skipped_labels, class_to_node=class_to_node)
ldm/data/i2sb_dataloader.py ADDED
@@ -0,0 +1,134 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
import imageio
from tqdm import tqdm


class pix2pixDataset(Dataset):
    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/", split="train",
                 normalize=True, transforms=None, preload=False, image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exist')

        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        self.direction = direction
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),                                          # convert to torch tensor
                Resize((image_size, image_size), antialias=False),   # resize to image_size x image_size
            ])
        else:
            self.transforms = transforms

        if self.preload:
            self.x_list, self.y_list = (), ()
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        # Each pix2pix image stores the two domains side by side; split it down the middle.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W//2, :], img_array[:, img_W//2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # apply the resize transform
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        return (x/255. - 0.5)*2

    def unnormalize_fn(self, x):
        return ((x/2 + 0.5) * 255).int().clamp(0, 255)  # since these are images

    def __getitem__(self, index):
        # getitem should return x0, x1, y (where y is the class label for class-conditional generation)
        class_cond = None
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            name = self.image_name_list[index]
            x_img, y_img = self.load_every(name)
        # if self.direction == "BtoA":
        #     return x_img, y_img, class_cond
        # elif self.direction == "AtoB":
        #     return y_img, x_img, class_cond
        batch = {
            "image1": x_img,
            "image2": y_img,
        }
        return batch

    def __len__(self):
        return len(self.image_name_list)


class FishDataset(Dataset):
    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train", normalize=True,
                 transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)

        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exist')

        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload

        if transforms is None:
            # self.transforms = Compose([
            #     ToTensor(),
            #     Resize((image_size, image_size), antialias=False),
            # ])
            self.transforms = Compose([
                ToTensor(),  # convert to torch tensor
            ])
        else:
            self.transforms = transforms

        if self.preload:
            self.x_list, self.y_list, self.class_id = (), (), []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                cls_id = int(name.split("_")[-1][:-4])
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
                self.class_id.append(cls_id)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            self.class_id = torch.tensor(self.class_id)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W//2, :], img_array[:, img_W//2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # apply the transforms
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        return (x/255. - 0.5)*2

    def unnormalize_fn(self, x):
        return ((x/2 + 0.5) * 255).int().clamp(0, 255)  # since these are images

    def __getitem__(self, index):
        # returns (x0, x1, class_id); the class id is parsed from the file name suffix
        if self.preload:
            x_img, y_img, class_id = self.x_list[index], self.y_list[index], self.class_id[index]
        else:
            name = self.image_name_list[index]
            class_id = torch.tensor(int(name.split("_")[-1][:-4]))
            x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        return len(self.image_name_list)
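Usage note (not part of the diff): a minimal sketch of batching pix2pixDataset with a standard DataLoader; the data_dir is a placeholder for a local pix2pix-style dataset where each image stores the two domains side by side.

from torch.utils.data import DataLoader
from ldm.data.i2sb_dataloader import pix2pixDataset

dataset = pix2pixDataset(dataset="maps", data_dir="/tmp/datasets_pix2pix/", split="train", image_size=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["image1"].shape, batch["image2"].shape)  # [4, 3, 256, 256] tensors in [-1, 1]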
ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    return dict((v, k) for k, v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
            where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # gets reset later in case interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
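Usage note (not part of the diff): a minimal sketch for ImageNetValidation, assuming a local copy of ILSVRC2012; on first use the class downloads index/synset files and builds a filelist under data_root (the path below is a placeholder).

from ldm.data.imagenet import ImageNetValidation

val_data = ImageNetValidation(data_root="/tmp/ILSVRC2012", config={"size": 256})
example = val_data[0]
print(example["image"].shape, example["class_label"], example["human_label"])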
ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example


class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
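Usage note (not part of the diff): a minimal sketch for the LSUN loaders, assuming the data/lsun filelists and image folders from the standard latent-diffusion data setup are in place.

from ldm.data.lsun import LSUNChurchesValidation

val_data = LSUNChurchesValidation(size=256)
example = val_data[0]
print(example["image"].shape)  # (256, 256, 3), values in [-1, 1]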
ldm/data/phylogeny.py ADDED
@@ -0,0 +1,333 @@
 
import os
import pandas as pd
import math
import pickle
import pprint
pp = pprint.PrettyPrinter(indent=4)

# For phylogeny parsing
# !pip install opentree
from opentree import OT
# !pip install ete3
from ete3 import Tree, PhyloTree

# Constants
Fix_Tree = True
format_ = 1  # 8


class Phylogeny:
    # Phylogeny class for Fish dataset
    # If node_ids is None, it assumes that the tree already exists. Otherwise, you have to pass node_ids (i.e., a list of species names).
    def __init__(self, filePath, node_ids=None, verbose=False):
        # filenames for phylo tree and cached mapping ottid-speciesname
        cleaned_fine_tree_fileName = "cleaned_metadata.tre"
        name_conversion_file = "name_conversion.pkl"
        self.ott_ids = []
        self.ott_id_dict = {}
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.conversionFileNameAndPath = os.path.join(filePath, name_conversion_file)
        self.total_distance = -1  # -1 means we never calculated it before.

        self.distance_matrix = {}
        self.species_groups_within_relative_distance = {}

        self.get_ott_ids(node_ids, verbose=verbose)
        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    # Given two species names, get the phylo distance between them
    def get_distance(self, species1, species2):
        d = None
        if self.distance_matrix[species1][species2] == -1:
            if species1 == species2:
                return 0

            ott_id1 = 'ott' + str(self.ott_id_dict[species1])
            ott_id2 = 'ott' + str(self.ott_id_dict[species2])
            d = self.tree.get_distance(ott_id1, ott_id2)

            self.distance_matrix[species1][species2] = d
        else:
            d = self.distance_matrix[species1][species2]

        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        self.get_species_groups(relative_distance, verbose)
        for species_group in self.species_groups_within_relative_distance[relative_distance]:
            if species in species_group:
                return species_group

        raise ValueError(species + " was not found in " + str(self.species_groups_within_relative_distance[relative_distance]))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        ott_id = 'ott' + str(self.ott_id_dict[species])
        parent = self.get_parent_by_ottid(ott_id, relative_distance, verbose)
        return parent

    def get_distance_between_parents(self, species1, species2, relative_distance):
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        if relative_distance not in self.species_groups_within_relative_distance.keys():
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                parent = parent_node.name
                if parent not in groups.keys():
                    groups[parent] = [species]
                else:
                    groups[parent].append(species)

            self.species_groups_within_relative_distance[relative_distance] = groups.values()

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        if self.node_ids is None:
            self.node_ids = self.ott_id_dict.keys()

        self.init_distance_matrix()

        # For one time, measure distance from all leaves down to root. They all should be equal.
        # Save the value and reuse it.
        if self.total_distance == -1:
            for leaf in self.tree.iter_leaves():
                total_distance = self.tree.get_distance(leaf)  # gets distance to root
                assert math.isclose(self.total_distance, total_distance) or self.total_distance == -1
                self.total_distance = total_distance

        return self.total_distance

    def init_distance_matrix(self):
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_parent_by_ottid(self, ott_id, relative_distance, verbose=False):
        abs_distance = relative_distance*self.total_distance
        species_node = self.tree.search_nodes(name=ott_id)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relative distance: ", relative_distance)

        # keep going up till distance exceeds abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    # node_ids: list of taxa
    # returns: corresponding list of ott_ids
    def get_ott_ids(self, node_ids, verbose=False):
        if not os.path.exists(self.conversionFileNameAndPath):
            if node_ids is None:
                raise TypeError('No existing ottid-speciesnames found. node_ids should be a list of species names.')
            if verbose:
                print('Included taxonomy: ', node_ids, len(node_ids))
            df2 = pd.DataFrame(columns=['in csv', 'in response', 'Same?'])

            # Get the matches
            resp = OT.tnrs_match(node_ids, do_approximate_matching=True)
            matches = resp.response_dict['results']
            unmatched_names = resp.response_dict['unmatched_names']

            # Get the corresponding ott_ids
            ott_ids = set()
            ott_id_dict = {}
            assert len(unmatched_names) == 0  # everything is matched!
            for match_array in matches:
                match_array_matches = match_array['matches']
                assert len(match_array_matches) == 1, match_array['name'] + " has too many matches" + str(list(map(lambda x: x['matched_name'], match_array_matches)))  # we have a single unambiguous match!
                first_match = match_array_matches[0]
                ott_id = first_match['taxon']['ott_id']
                ott_ids.add(ott_id)
                if verbose:
                    # some original and matched names are not exactly the same. Not a bug
                    df2 = df2.append({'in csv': match_array['name'], 'in response': first_match['matched_name'], 'Same?': match_array['name'] == first_match['matched_name']}, ignore_index=True)
                ott_id_dict[match_array['name']] = ott_id
            ott_ids = list(ott_ids)

            if verbose:
                print(df2[df2['Same?'] == False])
                pp.pprint(ott_id_dict)

            with open(self.conversionFileNameAndPath, 'wb') as f:
                pickle.dump([ott_ids, ott_id_dict], f)
        else:
            with open(self.conversionFileNameAndPath, 'rb') as f:
                ott_ids, ott_id_dict = pickle.load(f)

        self.ott_ids = ott_ids
        self.ott_id_dict = ott_id_dict
        print(self.ott_id_dict)

    def fix_tree(self, treeFileNameAndPath):
        tree = PhyloTree(treeFileNameAndPath, format=format_)

        # Special case for Fish dataset: Fix Esox Americanus.
        D = tree.search_nodes(name="mrcaott47023ott496121")[0]
        D.name = "ott496115"
        tree.write(format=format_, outfile=treeFileNameAndPath)

    def get_tree(self, treeFileNameAndPath):
        if not os.path.exists(treeFileNameAndPath):
            output = OT.synth_induced_tree(ott_ids=self.ott_ids, ignore_unknown_ids=False, label_format='id')
            output.tree.write(path=treeFileNameAndPath, schema="newick")

            if Fix_Tree:
                self.fix_tree(treeFileNameAndPath)

        self.tree = PhyloTree(treeFileNameAndPath, format=format_)


class PhylogenyCUB:
    # Phylogeny class for CUB dataset
    def __init__(self, filePath, node_ids=None, verbose=False):
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies.phy"
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies-cub-names.phy"
        cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-27Species-cub-names.phy"
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.total_distance = -1  # -1 means we never calculated it before.

        self.distance_matrix = {}
        self.species_groups_within_relative_distance = {}

        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    # Given two species names, get the phylo distance between them
    def get_distance(self, species1, species2):
        d = None
        if self.distance_matrix[species1][species2] == -1:
            if species1 == species2:
                return 0
            d = self.tree.get_distance(species1, species2)

            self.distance_matrix[species1][species2] = d
        else:
            d = self.distance_matrix[species1][species2]

        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        # NOTE: This implementation was causing inconsistencies since finding the parent.get_leaves() was not equivalent to get_species_groups
        # ott_id = 'ott' + str(self.ott_id_dict[species])
        # return self.get_siblings_by_ottid(ott_id, relative_distance, get_ottids, verbose)

        self.get_species_groups(relative_distance, verbose)
        for species_group in self.species_groups_within_relative_distance[relative_distance]:
            if species in species_group:
                return species_group

        raise ValueError(species + " was not found in " + str(self.species_groups_within_relative_distance[relative_distance]))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        abs_distance = relative_distance*self.total_distance
        species_node = self.tree.search_nodes(name=species)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relative distance: ", relative_distance)

        # keep going up till distance exceeds abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    def get_distance_between_parents(self, species1, species2, relative_distance):
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        if relative_distance not in self.species_groups_within_relative_distance.keys():
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                parent = parent_node.name
                if parent not in groups.keys():
                    groups[parent] = [species]
                else:
                    groups[parent].append(species)

            self.species_groups_within_relative_distance[relative_distance] = groups.values()

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        if self.node_ids is None:
            self.node_ids = sorted([leaf.name for leaf in self.tree.iter_leaves()])

        self.init_distance_matrix()

        # maximum distance between root and leaf node taken as total distance
        leaf_to_root_distances = [self.tree.get_distance(leaf) for leaf in self.tree.iter_leaves()]
        self.total_distance = max(leaf_to_root_distances)

        return self.total_distance

    def init_distance_matrix(self):
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_tree(self, treeFileNameAndPath):
        # if not os.path.exists(treeFileNameAndPath):
        #     output = OT.synth_induced_tree(ott_ids=self.ott_ids, ignore_unknown_ids=False, label_format='id')
        #     output.tree.write(path=treeFileNameAndPath, schema="newick")

        self.tree = PhyloTree(treeFileNameAndPath, format=format_)

        # setting a dummy name to the internal nodes if it is unnamed
        for i, node in enumerate(self.tree.traverse("postorder")):
            if not len(node.name) > 0:
                node.name = str(i)
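Usage note (not part of the diff): a minimal sketch of querying the CUB phylogeny helper, assuming a folder (placeholder path below) that contains the referenced .phy consensus tree.

from ldm.data.phylogeny import PhylogenyCUB

phylo = PhylogenyCUB("/tmp/cub_phylogeny")
species = phylo.getLabelList()
print(phylo.get_distance(species[0], species[1]))       # patristic distance between two species
print(phylo.get_species_groups(relative_distance=0.5))  # species grouped by ancestors at half the tree depth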
ldm/data/utils.py ADDED
@@ -0,0 +1,56 @@
# based on https://github.com/CompVis/taming-transformers

import collections

import torch
from ldm.data.helper_types import Annotation
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format


def custom_collate(batch):
    r"""source: pytorch 1.9.0, only one modification to original code """

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
        return batch  # added
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
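Usage note (not part of the diff): custom_collate is meant as a drop-in collate_fn; it only differs from the PyTorch 1.9.0 default in passing lists of Annotation objects through un-collated. A toy sketch, assuming an older PyTorch where torch._six still exists and ldm.data.helper_types is importable:

import torch
from torch.utils.data import DataLoader, TensorDataset
from ldm.data.utils import custom_collate

dataset = TensorDataset(torch.randn(16, 3, 32, 32))  # toy data just to exercise the collate hook
loader = DataLoader(dataset, batch_size=8, collate_fn=custom_collate)
batch = next(iter(loader))
print(batch[0].shape)  # torch.Size([8, 3, 32, 32])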