mridulk commited on
Commit
642d5e2
·
1 Parent(s): 008150e

added few ldm files

Browse files
ldm/analysis_utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ EPS=1e-10
5
+
6
+ def get_CosineDistance_matrix(features):
7
+ if features.dim() >2:
8
+ features = features.reshape(features.shape[0], -1)
9
+
10
+ features_norm = features / (EPS + features.norm(dim=1)[:, None])
11
+ ans = torch.mm(features_norm, features_norm.transpose(0,1))
12
+
13
+ # We want distance, not similarity.
14
+ ans = torch.add(-ans, 1.)
15
+
16
+ return ans
17
+
18
+ def aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, specimen_distance_matrix, z_size, channels):
19
+ unique_sorted_class_names_according_to_class_indx = sorted(set(sorted_class_names_according_to_class_indx))
20
+
21
+ # species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), 256, 16, 16)
22
+ species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), channels, z_size, z_size)
23
+ for indx_i, i in enumerate(unique_sorted_class_names_according_to_class_indx):
24
+ class_i_indices = [idx for idx, element in enumerate(sorted_class_names_according_to_class_indx) if element == i]
25
+ species_dist_matrix[indx_i] = torch.mean(specimen_distance_matrix[class_i_indices,:], dim=0, keepdim=True)
26
+
27
+ return species_dist_matrix
ldm/loading_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #based on https://github.com/CompVis/taming-transformers
2
+
3
+ import yaml
4
+ from omegaconf import OmegaConf
5
+ import torch
6
+ from ldm.util import instantiate_from_config
7
+
8
+ ######### loaders
9
+
10
+ def load_config(config_path, display=False):
11
+ config = OmegaConf.load(config_path)
12
+ if display:
13
+ print(yaml.dump(OmegaConf.to_container(config)))
14
+ return config
15
+
16
+ def load_model_from_config(config, ckpt):
17
+ print(f"Loading model from {ckpt}")
18
+ pl_sd = torch.load(ckpt)#, map_location="cpu")
19
+ sd = pl_sd["state_dict"]
20
+ model = instantiate_from_config(config.model)
21
+ m, u = model.load_state_dict(sd, strict=False)
22
+ model.cuda()
23
+ model.eval()
24
+ return model
25
+
26
+ def load_model(config_path, ckpt_path=None):
27
+ # def load_model(config_path, ckpt_path=None, cuda=False, model_type=VQModel):
28
+ # breakpoint()
29
+ # model = model_type(**config.model.params)
30
+ # if ckpt_path is not None:
31
+ # sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
32
+ # missing, unexpected = model.load_state_dict(sd, strict=True)
33
+ # if cuda:
34
+ # model = model.cuda()
35
+
36
+ config = OmegaConf.load(config_path)
37
+ model = load_model_from_config(config, ckpt_path)
38
+ return model
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0, gamma=0.99, step_size=1000):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.gamma = gamma
48
+ self.step_size = step_size
49
+ self.cycle_lengths = cycle_lengths
50
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
51
+ self.last_f = 0.
52
+ self.verbosity_interval = verbosity_interval
53
+
54
+ def find_in_interval(self, n):
55
+ interval = 0
56
+ for cl in self.cum_cycles[1:]:
57
+ if n <= cl:
58
+ return interval
59
+ interval += 1
60
+
61
+ def schedule(self, n, **kwargs):
62
+ cycle = self.find_in_interval(n)
63
+ n = n - self.cum_cycles[cycle]
64
+ if self.verbosity_interval > 0:
65
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
66
+ f"current cycle {cycle}")
67
+ if n < self.lr_warm_up_steps[cycle]:
68
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
69
+ self.last_f = f
70
+ return f
71
+ else:
72
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
73
+ t = min(t, 1.0)
74
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
75
+ 1 + np.cos(t * np.pi))
76
+ self.last_f = f
77
+ return f
78
+
79
+ def __call__(self, n, **kwargs):
80
+ return self.schedule(n, **kwargs)
81
+
82
+
83
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
84
+
85
+ def schedule(self, n, **kwargs):
86
+ cycle = self.find_in_interval(n)
87
+ n = n - self.cum_cycles[cycle]
88
+ if self.verbosity_interval > 0:
89
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}")
91
+
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
94
+ self.last_f = f
95
+ return f
96
+ else:
97
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
98
+ self.last_f = f
99
+ return f
100
+
101
+ class LambdaLinearScheduler_step(LambdaWarmUpCosineScheduler2):
102
+
103
+ def schedule(self, n, **kwargs):
104
+ cycle = self.find_in_interval(n)
105
+ n = n - self.cum_cycles[cycle]
106
+ if self.verbosity_interval > 0:
107
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
108
+ f"current cycle {cycle}")
109
+
110
+ if n < self.lr_warm_up_steps[cycle]:
111
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
112
+ self.last_f = f
113
+ return f
114
+ else:
115
+ f = self.gamma ** ((n-self.lr_warm_up_steps[cycle]) // self.step_size)
116
+ # f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
117
+ self.last_f = f
118
+ return f
119
+
120
+ # class LambdaCustomScheduler:
ldm/plotting_utils.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #based on https://github.com/CompVis/taming-transformers
3
+
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import os
7
+ from pathlib import Path
8
+ import torchvision
9
+ import torch
10
+ import numpy as np
11
+ from PIL import Image
12
+ import json
13
+ import csv
14
+ import pandas as pd
15
+
16
+ from sklearn.metrics import ConfusionMatrixDisplay
17
+
18
+
19
+ def dump_to_json(dict, ckpt_path, name='results', get_fig_path=True):
20
+
21
+ if get_fig_path:
22
+ root = get_fig_pth(ckpt_path)
23
+ else:
24
+ root = ckpt_path
25
+ if not os.path.exists(root):
26
+ os.mkdir(root)
27
+
28
+ with open(os.path.join(root, name+".json"), "w") as outfile:
29
+ json.dump(dict, outfile)
30
+
31
+
32
+ def save_to_cvs(ckpt_path, postfix, file_name, list_of_created_sequence):
33
+ if ckpt_path is not None:
34
+ root = get_fig_pth(ckpt_path, postfix=postfix)
35
+ else:
36
+ root = postfix
37
+
38
+ file = open(os.path.join(root, file_name), 'w')
39
+ with file:
40
+ write = csv.writer(file)
41
+ write.writerows(list_of_created_sequence)
42
+
43
+ def save_to_txt(arr, ckpt_path, name='results'):
44
+ root = get_fig_pth(ckpt_path)
45
+ with open(os.path.join(root, name+".txt"), "w") as outfile:
46
+ outfile.write(str(arr))
47
+
48
+
49
+
50
+ def save_image_grid(torch_images, ckpt_path=None, subfolder=None, postfix="", nrow=10):
51
+ if ckpt_path is not None:
52
+ root = get_fig_pth(ckpt_path, postfix=subfolder)
53
+ else:
54
+ root = subfolder
55
+
56
+ grid = torchvision.utils.make_grid(torch_images, nrow=nrow)
57
+ grid = torch.clamp(grid, -1., 1.)
58
+
59
+ grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
60
+ grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
61
+ grid = grid.cpu().numpy()
62
+ grid = (grid*255).astype(np.uint8)
63
+ filename = "code_changes_"+postfix+".png"
64
+ path = os.path.join(root, filename)
65
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
66
+ Image.fromarray(grid).save(path, bbox_inches='tight')
67
+
68
+
69
+ def unprocess_image(torch_image):
70
+ torch_image = torch.clamp(torch_image, -1., 1.)
71
+
72
+ torch_image = (torch_image+1.0)/2.0 # -1,1 -> 0,1; c,h,w
73
+ torch_image = torch_image.transpose(0,1).transpose(1,2).squeeze(-1)
74
+ torch_image = torch_image.cpu().numpy()
75
+ torch_image = (torch_image*255).astype(np.uint8)
76
+ return torch_image
77
+
78
+ def save_image(torch_image, image_name, ckpt_path=None, subfolder=None):
79
+ if ckpt_path is not None:
80
+ root = get_fig_pth(ckpt_path, postfix=subfolder)
81
+ else:
82
+ root = subfolder
83
+
84
+ torch_image = unprocess_image(torch_image)
85
+
86
+ filename = image_name+".png"
87
+ path = os.path.join(root, filename)
88
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
89
+ fig = plt.figure()
90
+ plt.imshow(torch_image[0].squeeze())
91
+ fig.savefig(path,bbox_inches='tight',dpi=300)
92
+
93
+
94
+
95
+ def get_fig_pth(ckpt_path, postfix=None):
96
+ figs_postfix = 'figs'
97
+ postfix = os.path.join(figs_postfix, postfix) if postfix is not None else figs_postfix
98
+ parent_path = Path(ckpt_path).parent.parent.absolute()
99
+ fig_path = Path(os.path.join(parent_path, postfix))
100
+ os.makedirs(fig_path, exist_ok=True)
101
+ return fig_path
102
+
103
+ def plot_heatmap(heatmap, ckpt_path=None, title='default', postfix=None):
104
+ if ckpt_path is not None:
105
+ path = get_fig_pth(ckpt_path, postfix=postfix)
106
+ else:
107
+ path = postfix
108
+
109
+ # show
110
+ fig = plt.figure()
111
+ ax = plt.imshow(heatmap, cmap='hot', interpolation='nearest')
112
+ plt.tick_params(left=False, bottom=False)
113
+ # cbar = ax.collections[0].colorbar
114
+ cbar = plt.colorbar(ax)
115
+ cbar.ax.tick_params(labelsize=15)
116
+ plt.axis('off')
117
+ plt.show()
118
+ fig.savefig(os.path.join(path, title+ " heat_map.png"),bbox_inches='tight',dpi=300)
119
+ pd.DataFrame(heatmap.numpy()).to_csv(os.path.join(path, title+ " heat_map.csv"))
120
+
121
+ def plot_heatmap_at_path(heatmap, save_path, ckpt_path=None, title='default', postfix=None):
122
+ if ckpt_path is not None:
123
+ path = get_fig_pth(ckpt_path, postfix=postfix)
124
+ else:
125
+ path = postfix
126
+
127
+ # show
128
+ fig = plt.figure()
129
+ ax = plt.imshow(heatmap, cmap='hot', interpolation='nearest')
130
+ plt.tick_params(left=False, bottom=False)
131
+ # cbar = ax.collections[0].colorbar
132
+ cbar = plt.colorbar(ax)
133
+ cbar.ax.tick_params(labelsize=15)
134
+ plt.axis('off')
135
+ plt.show()
136
+ fig.savefig(os.path.join(save_path, title+ "_heat_map.png"),bbox_inches='tight',dpi=300)
137
+ pd.DataFrame(heatmap.numpy()).to_csv(os.path.join(save_path, title+ "_heat_map.csv"))
138
+
139
+ def plot_confusionmatrix(preds, classes, classnames, ckpt_path, postfix=None, title="", get_fig_path=True):
140
+ fig, ax = plt.subplots(figsize=(30,30))
141
+ preds_max = np.argmax(preds.cpu().numpy(), axis=-1)
142
+ disp = ConfusionMatrixDisplay.from_predictions(classes.cpu().numpy(), preds_max, display_labels=classnames, normalize='true', xticks_rotation='vertical', ax=ax)
143
+ disp.plot()
144
+
145
+ if get_fig_path:
146
+ fig_path = get_fig_pth(ckpt_path, postfix=postfix)
147
+ else:
148
+ fig_path = ckpt_path
149
+ if not os.path.exists(fig_path):
150
+ os.mkdir(fig_path)
151
+
152
+ print(fig_path)
153
+ fig.savefig(os.path.join(fig_path, title+ " heat_map.png"))
154
+
155
+ def plot_confusionmatrix_colormap(preds, classes, classnames, ckpt_path, postfix=None, title="", get_fig_path=True):
156
+ fig, ax = plt.subplots(figsize=(30,30))
157
+ preds_max = np.argmax(preds.cpu().numpy(), axis=-1)
158
+ class_labels = list(range(len(classnames)))
159
+ disp = ConfusionMatrixDisplay.from_predictions(classes.cpu().numpy(), preds_max, display_labels=class_labels, normalize='true', xticks_rotation='vertical', ax=ax, cmap='coolwarm')
160
+ disp.plot()
161
+
162
+ if get_fig_path:
163
+ fig_path = get_fig_pth(ckpt_path, postfix=postfix)
164
+ else:
165
+ fig_path = ckpt_path
166
+ if not os.path.exists(fig_path):
167
+ os.mkdir(fig_path)
168
+
169
+ print(fig_path)
170
+ fig.savefig(os.path.join(fig_path, title+ " heat_map_coolwarm.png"))
171
+
172
+
173
+ class Histogram_plotter:
174
+ def __init__(self, codes_per_phylolevel, n_phylolevels, n_embed,
175
+ converter,
176
+ indx_to_label,
177
+ ckpt_path, directory):
178
+ self.codes_per_phylolevel = codes_per_phylolevel
179
+ self.n_phylolevels = n_phylolevels
180
+ self.n_embed = n_embed
181
+ self.converter = converter
182
+ self.ckpt_path = ckpt_path
183
+ self.directory = directory
184
+ self.indx_to_label = indx_to_label
185
+
186
+ def plot_histograms(self, histograms, species_indx, is_nonattribute=False, prefix="species"):
187
+ fig, axs = plt.subplots(self.codes_per_phylolevel, self.n_phylolevels, figsize = (5*self.n_phylolevels,30))
188
+ for i, ax in enumerate(axs.reshape(-1)):
189
+ ax.hist(histograms[i], density=True, range=(0, self.n_embed-1), bins=self.n_embed)
190
+
191
+ if not is_nonattribute:
192
+ code_location, level = self.converter.get_code_reshaped_index(i)
193
+ ax.set_title("code "+ str(code_location) + "/level " +str(level))
194
+ else:
195
+ ax.set_title("code "+ str(i))
196
+
197
+ plt.show()
198
+ sub_dir = 'attribute' if not is_nonattribute else 'non_attribute'
199
+ fig.savefig(os.path.join(get_fig_pth(self.ckpt_path, postfix=self.directory+'/'+sub_dir), "{}_{}_{}_hostogram.png".format(prefix, species_indx, self.indx_to_label[species_indx])),bbox_inches='tight',dpi=300)
200
+ plt.close(fig)
ldm/util.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ import torch
4
+ import hashlib
5
+ import requests
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from collections import abc
9
+ from einops import rearrange
10
+ from functools import partial
11
+
12
+ import multiprocessing as mp
13
+ from threading import Thread
14
+ from queue import Queue
15
+
16
+ from inspect import isfunction
17
+ from PIL import Image, ImageDraw, ImageFont
18
+
19
+ URL_MAP = {
20
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
21
+ }
22
+
23
+ CKPT_MAP = {
24
+ "vgg_lpips": "vgg.pth"
25
+ }
26
+
27
+ MD5_MAP = {
28
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
29
+ }
30
+
31
+ def md5_hash(path):
32
+ with open(path, "rb") as f:
33
+ content = f.read()
34
+ return hashlib.md5(content).hexdigest()
35
+
36
+ def log_txt_as_img(wh, xc, size=10):
37
+ # wh a tuple of (width, height)
38
+ # xc a list of captions to plot
39
+ b = len(xc)
40
+ txts = list()
41
+ for bi in range(b):
42
+ txt = Image.new("RGB", wh, color="white")
43
+ draw = ImageDraw.Draw(txt)
44
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
45
+ nc = int(40 * (wh[0] / 256))
46
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
47
+
48
+ try:
49
+ draw.text((0, 0), lines, fill="black", font=font)
50
+ except UnicodeEncodeError:
51
+ print("Cant encode string for logging. Skipping.")
52
+
53
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
54
+ txts.append(txt)
55
+ txts = np.stack(txts)
56
+ txts = torch.tensor(txts)
57
+ return txts
58
+
59
+ def download(url, local_path, chunk_size=1024):
60
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
61
+ with requests.get(url, stream=True) as r:
62
+ total_size = int(r.headers.get("content-length", 0))
63
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
64
+ with open(local_path, "wb") as f:
65
+ for data in r.iter_content(chunk_size=chunk_size):
66
+ if data:
67
+ f.write(data)
68
+ pbar.update(chunk_size)
69
+
70
+ def get_ckpt_path(name, root, check=False):
71
+ assert name in URL_MAP
72
+ path = os.path.join(root, CKPT_MAP[name])
73
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
74
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
75
+ download(URL_MAP[name], path)
76
+ md5 = md5_hash(path)
77
+ assert md5 == MD5_MAP[name], md5
78
+ return path
79
+
80
+
81
+ def ismap(x):
82
+ if not isinstance(x, torch.Tensor):
83
+ return False
84
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
85
+
86
+
87
+ def isimage(x):
88
+ if not isinstance(x, torch.Tensor):
89
+ return False
90
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
91
+
92
+
93
+ def exists(x):
94
+ return x is not None
95
+
96
+
97
+ def default(val, d):
98
+ if exists(val):
99
+ return val
100
+ return d() if isfunction(d) else d
101
+
102
+
103
+ def mean_flat(tensor):
104
+ """
105
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
106
+ Take the mean over all non-batch dimensions.
107
+ """
108
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
109
+
110
+
111
+ def count_params(model, verbose=False):
112
+ total_params = sum(p.numel() for p in model.parameters())
113
+ if verbose:
114
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
115
+ return total_params
116
+
117
+
118
+ def instantiate_from_config(config):
119
+ if not "target" in config:
120
+ if config == '__is_first_stage__':
121
+ return None
122
+ elif config == "__is_unconditional__":
123
+ return None
124
+ raise KeyError("Expected key `target` to instantiate.")
125
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
126
+
127
+
128
+ def get_obj_from_str(string, reload=False):
129
+ module, cls = string.rsplit(".", 1)
130
+ if reload:
131
+ module_imp = importlib.import_module(module)
132
+ importlib.reload(module_imp)
133
+ return getattr(importlib.import_module(module, package=None), cls)
134
+
135
+
136
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
137
+ # create dummy dataset instance
138
+
139
+ # run prefetching
140
+ if idx_to_fn:
141
+ res = func(data, worker_id=idx)
142
+ else:
143
+ res = func(data)
144
+ Q.put([idx, res])
145
+ Q.put("Done")
146
+
147
+
148
+ def parallel_data_prefetch(
149
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
150
+ ):
151
+ # if target_data_type not in ["ndarray", "list"]:
152
+ # raise ValueError(
153
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
154
+ # )
155
+ if isinstance(data, np.ndarray) and target_data_type == "list":
156
+ raise ValueError("list expected but function got ndarray.")
157
+ elif isinstance(data, abc.Iterable):
158
+ if isinstance(data, dict):
159
+ print(
160
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
161
+ )
162
+ data = list(data.values())
163
+ if target_data_type == "ndarray":
164
+ data = np.asarray(data)
165
+ else:
166
+ data = list(data)
167
+ else:
168
+ raise TypeError(
169
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
170
+ )
171
+
172
+ if cpu_intensive:
173
+ Q = mp.Queue(1000)
174
+ proc = mp.Process
175
+ else:
176
+ Q = Queue(1000)
177
+ proc = Thread
178
+ # spawn processes
179
+ if target_data_type == "ndarray":
180
+ arguments = [
181
+ [func, Q, part, i, use_worker_id]
182
+ for i, part in enumerate(np.array_split(data, n_proc))
183
+ ]
184
+ else:
185
+ step = (
186
+ int(len(data) / n_proc + 1)
187
+ if len(data) % n_proc != 0
188
+ else int(len(data) / n_proc)
189
+ )
190
+ arguments = [
191
+ [func, Q, part, i, use_worker_id]
192
+ for i, part in enumerate(
193
+ [data[i: i + step] for i in range(0, len(data), step)]
194
+ )
195
+ ]
196
+ processes = []
197
+ for i in range(n_proc):
198
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
199
+ processes += [p]
200
+
201
+ # start processes
202
+ print(f"Start prefetching...")
203
+ import time
204
+
205
+ start = time.time()
206
+ gather_res = [[] for _ in range(n_proc)]
207
+ try:
208
+ for p in processes:
209
+ p.start()
210
+
211
+ k = 0
212
+ while k < n_proc:
213
+ # get result
214
+ res = Q.get()
215
+ if res == "Done":
216
+ k += 1
217
+ else:
218
+ gather_res[res[0]] = res[1]
219
+
220
+ except Exception as e:
221
+ print("Exception: ", e)
222
+ for p in processes:
223
+ p.terminate()
224
+
225
+ raise e
226
+ finally:
227
+ for p in processes:
228
+ p.join()
229
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
230
+
231
+ if target_data_type == 'ndarray':
232
+ if not isinstance(gather_res[0], np.ndarray):
233
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
234
+
235
+ # order outputs
236
+ return np.concatenate(gather_res, axis=0)
237
+ elif target_data_type == 'list':
238
+ out = []
239
+ for r in gather_res:
240
+ out.extend(r)
241
+ return out
242
+ else:
243
+ return gather_res