added data
- ldm/data/__init__.py +0 -0
- ldm/data/__pycache__/__init__.cpython-38.pyc +0 -0
- ldm/data/__pycache__/base.cpython-38.pyc +0 -0
- ldm/data/__pycache__/constants.cpython-38.pyc +0 -0
- ldm/data/__pycache__/custom.cpython-38.pyc +0 -0
- ldm/data/__pycache__/custom_cub.cpython-38.pyc +0 -0
- ldm/data/__pycache__/i2sb_dataloader.cpython-38.pyc +0 -0
- ldm/data/__pycache__/imagenet.cpython-38.pyc +0 -0
- ldm/data/__pycache__/phylogeny.cpython-38.pyc +0 -0
- ldm/data/__pycache__/utils.cpython-38.pyc +0 -0
- ldm/data/base.py +85 -0
- ldm/data/constants.py +32 -0
- ldm/data/custom.py +62 -0
- ldm/data/custom_cub.py +62 -0
- ldm/data/i2sb_dataloader.py +134 -0
- ldm/data/imagenet.py +394 -0
- ldm/data/lsun.py +92 -0
- ldm/data/phylogeny.py +333 -0
- ldm/data/utils.py +56 -0
ldm/data/__init__.py
ADDED
File without changes

ldm/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (141 Bytes)

ldm/data/__pycache__/base.cpython-38.pyc
ADDED
Binary file (3.62 kB)

ldm/data/__pycache__/constants.cpython-38.pyc
ADDED
Binary file (1.24 kB)

ldm/data/__pycache__/custom.cpython-38.pyc
ADDED
Binary file (3.16 kB)

ldm/data/__pycache__/custom_cub.cpython-38.pyc
ADDED
Binary file (3.15 kB)

ldm/data/__pycache__/i2sb_dataloader.cpython-38.pyc
ADDED
Binary file (4.52 kB)

ldm/data/__pycache__/imagenet.cpython-38.pyc
ADDED
Binary file (14.5 kB)

ldm/data/__pycache__/phylogeny.cpython-38.pyc
ADDED
Binary file (8.57 kB)

ldm/data/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.61 kB)

ldm/data/base.py
ADDED
@@ -0,0 +1,85 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass


class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, horizontalflip=False, random_contrast=False, shiftrotate=False, labels=None, unique_skipped_labels=[]):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        self.labels_without_skipped = None
        if len(unique_skipped_labels)!=0:
            self.labels_without_skipped = dict()
            for i in self.labels.keys():
                self.labels_without_skipped[i] = [a for indx, a in enumerate(labels[i]) if labels['class'][indx] not in unique_skipped_labels]
            self._length = len(self.labels_without_skipped['class'])

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            l = [self.rescaler]
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            l.append(self.cropper)
            if horizontalflip==True:
                l.append(albumentations.HorizontalFlip(p=0.2))
            if shiftrotate==True:
                l.append(albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=45, border_mode=0, value=(int(0.485*255), int(0.456*255), int(0.406*255)), p=0.3))
            if random_contrast==True:
                l.append(albumentations.RandomBrightnessContrast(p=0.3))
            self.preprocessor = albumentations.Compose(l)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def __getitem__(self, i):
        labels = self.labels if self.labels_without_skipped is None else self.labels_without_skipped
        example = dict()
        example["image"] = self.preprocess_image(labels["file_path_"][i])
        for k in labels:
            example[k] = labels[k][i]
        return example

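For orientation, a minimal usage sketch of ImagePaths follows; the image paths and the "class" label list are invented for illustration and are not part of this commit.

from ldm.data.base import ImagePaths

# Hypothetical paths and labels, purely for illustration.
dataset = ImagePaths(
    paths=["data/fish/species_a/img_0001.png", "data/fish/species_a/img_0002.png"],
    size=256,
    random_crop=False,
    horizontalflip=True,
    labels={"class": [0, 0]},
)
example = dataset[0]
# example["image"] is a (256, 256, 3) float32 array scaled to [-1, 1];
# example["class"] and example["file_path_"] carry the per-item labels.
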
ldm/data/constants.py
ADDED
@@ -0,0 +1,32 @@
DISENTANGLER_DECODER_OUTPUT = 'output'
DISENTANGLER_ENCODER_INPUT = 'in'
QUANTIZED_PHYLO_OUTPUT = 'zq_phylo'
DISENTANGLER_CLASS_OUTPUT = 'class'
QUANTIZED_PHYLO_NONATTRIBUTE_OUTPUT = 'zq_phylo_nonattribute'
DISENTANGLER_NON_ATTRIBUTE_TO_ATTRIBUTE_OUTPUT = 'nonattribate_to_attribute'
DISENTANGLER_NON_ATTRIBUTE_CLASS_OUTPUT = 'adversarial_classifier_output'
DISENTANGLER_ADV_MAPPING_OUTPUT = 'adversarial_mapping_output'
DISENTANGLER_ADV_LEARNING_OUTPUT = 'adversarial_learning_output'
NON_CLASS_TENSORS = [DISENTANGLER_ADV_LEARNING_OUTPUT, DISENTANGLER_ADV_MAPPING_OUTPUT, DISENTANGLER_ENCODER_INPUT, DISENTANGLER_DECODER_OUTPUT, QUANTIZED_PHYLO_OUTPUT, DISENTANGLER_NON_ATTRIBUTE_TO_ATTRIBUTE_OUTPUT, QUANTIZED_PHYLO_NONATTRIBUTE_OUTPUT, DISENTANGLER_NON_ATTRIBUTE_CLASS_OUTPUT]

CLASS_TENSORS = [DISENTANGLER_CLASS_OUTPUT]

DATASET_CLASSNAME = 'class_name'

PHYLOCONFIG_KEY = "phylomodel_params"
LRFACTOR_KEY = "lr_factor"
LRCYCLE = "lr_cycle"
DISENTANGLERTYPE_KEY = 'disentangler_type'
COMPLETE_CKPT_KEY = "posttraining_ckpt"

HISTOGRAMS_FOLDER = 'code_histograms'
HISTOGRAMS_FILE = "histograms.pkl"

DISENTANGLER_PHYLO_LOSS = "/disentangler_phylo_loss"
TRANSFORMER_LOSS = "/loss"
RECLOSS = "/rec_loss"
BASERECLOSS = "/base_true_rec_loss"

TEST_DIR = "results_summary"

TSNE_FOLDER = 'tsne'

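As a hypothetical illustration (the outputs dict below is not part of this commit), the two lists can be used to split a model-output dictionary into class and non-class tensors:

import ldm.data.constants as CONSTANTS

outputs = {CONSTANTS.DISENTANGLER_CLASS_OUTPUT: "class_logits",   # hypothetical values
           CONSTANTS.QUANTIZED_PHYLO_OUTPUT: "phylo_codes"}
class_outputs = {k: v for k, v in outputs.items() if k in CONSTANTS.CLASS_TENSORS}
non_class_outputs = {k: v for k, v in outputs.items() if k in CONSTANTS.NON_CLASS_TENSORS}
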
ldm/data/custom.py
ADDED
@@ -0,0 +1,62 @@
#based on https://github.com/CompVis/taming-transformers

import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False, shiftrotate=False, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = sorted(f.read().splitlines())

        labels = None
        if add_labels:
            labels_per_file = list(map(lambda path: path.split('/')[-2], paths))
            labels_set = sorted(list(set(labels_per_file)))
            self.labels_to_idx = {label_name: i for i, label_name in enumerate(labels_set)}

            if class_to_node:
                with open(class_to_node, 'rb') as pickle_file:
                    class_to_node_dict = pickle.load(pickle_file)
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file,
                    'class_to_node': [class_to_node_dict[label_name] for label_name in labels_per_file]
                }
                # labels = [self.labels_to_idx[label_name] for label_name in labels_per_file]
            else:
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file
                }

            self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip,
                               random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)


class CustomTest(CustomTrain):
    def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__(size, test_images_list_file, add_labels=add_labels,
                         unique_skipped_labels=unique_skipped_labels, class_to_node=class_to_node)

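A minimal sketch of how CustomTrain might be instantiated; the file-list path is an assumption (one image path per line, with the class name as the parent directory) and is not shipped in this commit.

from ldm.data.custom import CustomTrain

train_data = CustomTrain(
    size=256,
    training_images_list_file="data/fish_train.txt",  # hypothetical list: one image path per line
    horizontalflip=True,
    add_labels=True,
)
example = train_data[0]
# example holds "image", "class" (integer id), "class_name", and "file_path_".
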
ldm/data/custom_cub.py
ADDED
@@ -0,0 +1,62 @@
#based on https://github.com/CompVis/taming-transformers

import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False, shiftrotate=False, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = sorted(f.read().splitlines())

        labels = None
        if add_labels:
            labels_per_file = list(map(lambda path: path.split('/')[-2], paths))
            # labels_per_file = [i.split('.')[1].replace('_', ' ') for i in labels_per_file]
            labels_set = sorted(list(set(labels_per_file)))
            self.labels_to_idx = {label_name: i for i, label_name in enumerate(labels_set)}

            if class_to_node:
                with open(class_to_node, 'rb') as pickle_file:
                    class_to_node_dict = pickle.load(pickle_file)
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: [class_to_node_dict[label_name] for label_name in labels_per_file],
                    # 'class_to_node': [class_to_node_dict[label_name] for label_name in labels_per_file]
                }
                # labels = [self.labels_to_idx[label_name] for label_name in labels_per_file]
            else:
                labels = {
                    CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                    CONSTANTS.DATASET_CLASSNAME: labels_per_file
                }
            self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip,
                               random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)


class CustomTest(CustomTrain):
    def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        super().__init__(size, test_images_list_file, add_labels=add_labels,
                         unique_skipped_labels=unique_skipped_labels, class_to_node=class_to_node)

ldm/data/i2sb_dataloader.py
ADDED
@@ -0,0 +1,134 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
import imageio
from tqdm import tqdm


class pix2pixDataset(Dataset):
    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/", split="train", normalize=True, transforms=None, preload=False, image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exist')

        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        self.direction = direction
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor
                Resize((image_size, image_size), antialias=False),  # Resize to image_size x image_size
            ])
        else:
            self.transforms = transforms

        if self.preload:
            self.x_list, self.y_list = (), ()
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W//2, :], img_array[:, img_W//2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # Apply the resize transform
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        return (x/255. - 0.5)*2

    def unnormalize_fn(self, x):
        return ((x/2 + 0.5) * 255).int().clamp(0, 255)  # since these are images

    def __getitem__(self, index):  # getitem should return x0, x1, y (where y is the class label for class conditional generation)
        class_cond = None
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            name = self.image_name_list[index]
            x_img, y_img = self.load_every(name)
        # if self.direction == "BtoA":
        #     return x_img, y_img, class_cond
        # elif self.direction == "AtoB":
        #     return y_img, x_img, class_cond
        batch = {
            "image1": x_img,
            "image2": y_img,
        }
        return batch

    def __len__(self):
        return len(self.image_name_list)


class FishDataset(Dataset):
    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train", normalize=True, transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)

        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exist')

        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload

        if transforms is None:
            # self.transforms = Compose([
            #     ToTensor(),  # Convert to torch tensor
            #     Resize((image_size, image_size), antialias=False),  # Resize to 256x256
            # ])
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor
            ])
        else:
            self.transforms = transforms

        if self.preload:
            self.x_list, self.y_list, self.class_id = (), (), []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                cls_id = int(name.split("_")[-1][:-4])
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
                self.class_id.append(cls_id)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            self.class_id = torch.tensor(self.class_id)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W//2, :], img_array[:, img_W//2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # Apply the resize transform
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        return (x/255. - 0.5)*2

    def unnormalize_fn(self, x):
        return ((x/2 + 0.5) * 255).int().clamp(0, 255)  # since these are images

    def __getitem__(self, index):  # getitem should return x0, x1, y (where y is the class label for class conditional generation)
        if self.preload:
            x_img, y_img, class_id = self.x_list[index], self.y_list[index], self.class_id[index]
        else:
            name = self.image_name_list[index]
            class_id = torch.tensor(int(name.split("_")[-1][:-4]))
            x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        return len(self.image_name_list)

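A rough usage sketch for pix2pixDataset with a standard DataLoader; it assumes the default data directory layout exists on disk and is not part of this commit.

from torch.utils.data import DataLoader
from ldm.data.i2sb_dataloader import pix2pixDataset

dataset = pix2pixDataset(dataset="maps", split="train", preload=False, image_size=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))
# batch["image1"] and batch["image2"] are float tensors of shape [4, 3, 256, 256]:
# the left and right halves of each paired image, normalized to [-1, 1].
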
ldm/data/imagenet.py
ADDED
@@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    return dict((v, k) for k, v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # gets reset later if interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)

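A minimal sketch of the super-resolution dataset, assuming the ImageNet validation data and the pickled data/imagenet_val_hr_indices.p index file have already been prepared.

from ldm.data.imagenet import ImageNetSRValidation

val_data = ImageNetSRValidation(size=256, degradation="pil_bicubic", downscale_f=4)
example = val_data[0]
# example["image"] is a (256, 256, 3) float32 array in [-1, 1];
# example["LR_image"] is the degraded (64, 64, 3) counterpart.
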
ldm/data/lsun.py
ADDED
@@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example


class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)

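A minimal sketch, assuming data/lsun/church_outdoor_val.txt and the data/lsun/churches image folder are in place.

from ldm.data.lsun import LSUNChurchesValidation

val_data = LSUNChurchesValidation(size=256)
example = val_data[0]
# example["image"] is a center-cropped, resized (256, 256, 3) float32 array in [-1, 1].
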
ldm/data/phylogeny.py
ADDED
@@ -0,0 +1,333 @@
import os
import pandas as pd
import math
import pickle
import pprint
pp = pprint.PrettyPrinter(indent=4)

# For phylogeny parsing
# !pip install opentree
from opentree import OT
# !pip install ete3
from ete3 import Tree, PhyloTree

# Constants
Fix_Tree = True
format_ = 1 #8


class Phylogeny:
    # Phylogeny class for Fish dataset
    # If node_ids is None, it assumes that the tree already exists. Otherwise, you have to pass node_ids (i.e., list of species names).
    def __init__(self, filePath, node_ids=None, verbose=False):
        # filenames for phylo tree and cached mapping ottid-speciesname
        cleaned_fine_tree_fileName = "cleaned_metadata.tre"
        name_conversion_file = "name_conversion.pkl"
        self.ott_ids = []
        self.ott_id_dict = {}
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.conversionFileNameAndPath = os.path.join(filePath, name_conversion_file)
        self.total_distance = -1  # -1 means we never calculated it before.

        self.distance_matrix = {}
        self.species_groups_within_relative_distance = {}

        self.get_ott_ids(node_ids, verbose=verbose)
        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    # Given two species names, get the phylo distance between them
    def get_distance(self, species1, species2):
        d = None
        if self.distance_matrix[species1][species2] == -1:
            if species1 == species2:
                return 0

            ott_id1 = 'ott' + str(self.ott_id_dict[species1])
            ott_id2 = 'ott' + str(self.ott_id_dict[species2])
            d = self.tree.get_distance(ott_id1, ott_id2)

            self.distance_matrix[species1][species2] = d
        else:
            d = self.distance_matrix[species1][species2]

        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        self.get_species_groups(relative_distance, verbose)
        for species_group in self.species_groups_within_relative_distance[relative_distance]:
            if species in species_group:
                return species_group

        raise ValueError(species + " was not found in " + str(self.species_groups_within_relative_distance[relative_distance]))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        ott_id = 'ott' + str(self.ott_id_dict[species])
        parent = self.get_parent_by_ottid(ott_id, relative_distance, verbose)
        return parent

    def get_distance_between_parents(self, species1, species2, relative_distance):
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        if relative_distance not in self.species_groups_within_relative_distance.keys():
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                parent = parent_node.name
                if parent not in groups.keys():
                    groups[parent] = [species]
                else:
                    groups[parent].append(species)

            self.species_groups_within_relative_distance[relative_distance] = groups.values()

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        if self.node_ids is None:
            self.node_ids = self.ott_id_dict.keys()

        self.init_distance_matrix()

        # For one time, measure distance from all leaves down to root. They all should be equal.
        # Save the value and reuse it.
        if self.total_distance==-1:
            for leaf in self.tree.iter_leaves():
                total_distance = self.tree.get_distance(leaf)  # gets distance to root
                assert math.isclose(self.total_distance, total_distance) or self.total_distance==-1
                self.total_distance = total_distance

        return self.total_distance

    def init_distance_matrix(self):
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_parent_by_ottid(self, ott_id, relative_distance, verbose=False):
        abs_distance = relative_distance*self.total_distance
        species_node = self.tree.search_nodes(name=ott_id)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relative distance: ", relative_distance)

        # keep going up till distance exceeds abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    # return ott_id_list
    # node_ids: list of taxa
    # returns: corresponding list of ott_ids
    def get_ott_ids(self, node_ids, verbose=False):
        if not os.path.exists(self.conversionFileNameAndPath):
            if node_ids is None:
                raise TypeError('No existing ottid-speciesnames found. node_ids should be a list of species names.')
            if verbose:
                print('Included taxonomy: ', node_ids, len(node_ids))
            df2 = pd.DataFrame(columns=['in csv', 'in response', 'Same?'])

            # Get the matches
            resp = OT.tnrs_match(node_ids, do_approximate_matching=True)
            matches = resp.response_dict['results']
            unmatched_names = resp.response_dict['unmatched_names']

            # Get the corresponding ott_ids
            ott_ids = set()
            ott_id_dict = {}
            assert len(unmatched_names)==0  # everything is matched!
            for match_array in matches:
                match_array_matches = match_array['matches']
                assert len(match_array_matches)==1, match_array['name'] + " has too many matches" + str(list(map(lambda x: x['matched_name'], match_array_matches)))  # we have a single unambiguous match!
                first_match = match_array_matches[0]
                ott_id = first_match['taxon']['ott_id']
                ott_ids.add(ott_id)
                if verbose:
                    # some original and matched names are not exactly the same. Not a bug
                    df2 = df2.append({'in csv': match_array['name'], 'in response': first_match['matched_name'], 'Same?': match_array['name'] == first_match['matched_name']}, ignore_index=True)
                ott_id_dict[match_array['name']] = ott_id
            ott_ids = list(ott_ids)

            if verbose:
                print(df2[df2['Same?']== False])
                pp.pprint(ott_id_dict)

            with open(self.conversionFileNameAndPath, 'wb') as f:
                pickle.dump([ott_ids, ott_id_dict], f)
        else:
            with open(self.conversionFileNameAndPath, 'rb') as f:
                ott_ids, ott_id_dict = pickle.load(f)

        self.ott_ids = ott_ids
        self.ott_id_dict = ott_id_dict
        print(self.ott_id_dict)

    def fix_tree(self, treeFileNameAndPath):
        tree = PhyloTree(treeFileNameAndPath, format=format_)

        # Special case for Fish dataset: Fix Esox Americanus.
        D = tree.search_nodes(name="mrcaott47023ott496121")[0]
        D.name = "ott496115"
        tree.write(format=format_, outfile=treeFileNameAndPath)

    def get_tree(self, treeFileNameAndPath):
        if not os.path.exists(treeFileNameAndPath):
            output = OT.synth_induced_tree(ott_ids=self.ott_ids, ignore_unknown_ids=False, label_format='id')  # name_and_id ott_ids=list(ott_ids),

            output.tree.write(path=treeFileNameAndPath, schema="newick")

            if Fix_Tree:
                self.fix_tree(treeFileNameAndPath)

        self.tree = PhyloTree(treeFileNameAndPath, format=format_)


class PhylogenyCUB:
    # Phylogeny class for CUB dataset
    def __init__(self, filePath, node_ids=None, verbose=False):
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies.phy"
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies-cub-names.phy"
        cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-27Species-cub-names.phy"
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.total_distance = -1  # -1 means we never calculated it before.

        self.distance_matrix = {}
        self.species_groups_within_relative_distance = {}

        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    # Given two species names, get the phylo distance between them
    def get_distance(self, species1, species2):
        d = None
        if self.distance_matrix[species1][species2] == -1:
            if species1 == species2:
                return 0
            d = self.tree.get_distance(species1, species2)

            self.distance_matrix[species1][species2] = d
        else:
            d = self.distance_matrix[species1][species2]

        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        # NOTE: This implementation was causing inconsistencies since finding the parent.get_leaves() was not equivalent to get_species_groups
        # ott_id = 'ott' + str(self.ott_id_dict[species])
        # return self.get_siblings_by_ottid(ott_id, relative_distance, get_ottids, verbose)

        self.get_species_groups(relative_distance, verbose)
        for species_group in self.species_groups_within_relative_distance[relative_distance]:
            if species in species_group:
                return species_group

        raise ValueError(species + " was not found in " + str(self.species_groups_within_relative_distance[relative_distance]))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        abs_distance = relative_distance*self.total_distance
        species_node = self.tree.search_nodes(name=species)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relative distance: ", relative_distance)

        # keep going up till distance exceeds abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    def get_distance_between_parents(self, species1, species2, relative_distance):
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        if relative_distance not in self.species_groups_within_relative_distance.keys():
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                parent = parent_node.name
                if parent not in groups.keys():
                    groups[parent] = [species]
                else:
                    groups[parent].append(species)

            self.species_groups_within_relative_distance[relative_distance] = groups.values()

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        if self.node_ids is None:
            self.node_ids = sorted([leaf.name for leaf in self.tree.iter_leaves()])

        self.init_distance_matrix()

        # maximum distance between root and leaf node taken as total distance
        leaf_to_root_distances = [self.tree.get_distance(leaf) for leaf in self.tree.iter_leaves()]
        self.total_distance = max(leaf_to_root_distances)

        return self.total_distance

    def init_distance_matrix(self):
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_tree(self, treeFileNameAndPath):
        # if not os.path.exists(treeFileNameAndPath):
        #     output = OT.synth_induced_tree(ott_ids=self.ott_ids, ignore_unknown_ids=False, label_format='id')  # name_and_id ott_ids=list(ott_ids),

        #     output.tree.write(path = treeFileNameAndPath, schema = "newick")

        self.tree = PhyloTree(treeFileNameAndPath, format=format_)

        # setting a dummy name to the internal nodes if it is unnamed
        for i, node in enumerate(self.tree.traverse("postorder")):
            if not len(node.name) > 0:
                node.name = str(i)

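A sketch of how the CUB variant might be queried; the folder path and the species names below are hypothetical and only illustrate the call pattern.

from ldm.data.phylogeny import PhylogenyCUB

phylo = PhylogenyCUB("phylogeny_cub/")  # folder assumed to contain the consensus .phy tree
d = phylo.get_distance("Black_footed_Albatross", "Laysan_Albatross")  # hypothetical leaf names
siblings = phylo.get_siblings_by_name("Black_footed_Albatross", relative_distance=0.3)
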
ldm/data/utils.py
ADDED
@@ -0,0 +1,56 @@
#based on https://github.com/CompVis/taming-transformers

import collections

import torch
from ldm.data.helper_types import Annotation
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format


def custom_collate(batch):
    r"""source: pytorch 1.9.0, only one modification to original code """

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
        return batch  # added
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

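A minimal sketch of plugging custom_collate into a DataLoader; the file list is hypothetical, and the import of torch._six assumes an older PyTorch release that still ships that module.

from torch.utils.data import DataLoader
from ldm.data.custom import CustomTrain
from ldm.data.utils import custom_collate

train_data = CustomTrain(size=256, training_images_list_file="data/fish_train.txt",  # hypothetical list
                         add_labels=True)
loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=custom_collate)
batch = next(iter(loader))
# batch["image"] is a stacked tensor; string fields such as "class_name" stay as lists.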