mueller-franzes committed
Commit f85e212
1 Parent(s): 96a28c6
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +25 -0
  2. README.md +62 -2
  3. medical_diffusion/data/augmentation/__init__.py +0 -0
  4. medical_diffusion/data/augmentation/augmentations_2d.py +27 -0
  5. medical_diffusion/data/augmentation/augmentations_3d.py +38 -0
  6. medical_diffusion/data/datamodules/__init__.py +1 -0
  7. medical_diffusion/data/datamodules/datamodule_simple.py +79 -0
  8. medical_diffusion/data/datasets/__init__.py +2 -0
  9. medical_diffusion/data/datasets/dataset_simple_2d.py +198 -0
  10. medical_diffusion/data/datasets/dataset_simple_3d.py +58 -0
  11. medical_diffusion/external/diffusers/attention.py +347 -0
  12. medical_diffusion/external/diffusers/embeddings.py +89 -0
  13. medical_diffusion/external/diffusers/resnet.py +479 -0
  14. medical_diffusion/external/diffusers/taming_discriminator.py +57 -0
  15. medical_diffusion/external/diffusers/unet.py +257 -0
  16. medical_diffusion/external/diffusers/unet_blocks.py +1557 -0
  17. medical_diffusion/external/diffusers/vae.py +857 -0
  18. medical_diffusion/external/stable_diffusion/attention.py +261 -0
  19. medical_diffusion/external/stable_diffusion/lr_schedulers.py +33 -0
  20. medical_diffusion/external/stable_diffusion/unet_openai.py +962 -0
  21. medical_diffusion/external/stable_diffusion/util.py +284 -0
  22. medical_diffusion/external/stable_diffusion/util_attention.py +56 -0
  23. medical_diffusion/external/unet_lucidrains.py +332 -0
  24. medical_diffusion/loss/gan_losses.py +22 -0
  25. medical_diffusion/loss/perceivers.py +27 -0
  26. medical_diffusion/metrics/__init__.py +0 -0
  27. medical_diffusion/metrics/torchmetrics_pr_recall.py +170 -0
  28. medical_diffusion/models/__init__.py +1 -0
  29. medical_diffusion/models/embedders/__init__.py +2 -0
  30. medical_diffusion/models/embedders/cond_embedders.py +27 -0
  31. medical_diffusion/models/embedders/latent_embedders.py +1065 -0
  32. medical_diffusion/models/embedders/time_embedder.py +75 -0
  33. medical_diffusion/models/estimators/__init__.py +1 -0
  34. medical_diffusion/models/estimators/unet.py +186 -0
  35. medical_diffusion/models/estimators/unet2.py +279 -0
  36. medical_diffusion/models/model_base.py +114 -0
  37. medical_diffusion/models/noise_schedulers/__init__.py +2 -0
  38. medical_diffusion/models/noise_schedulers/gaussian_scheduler.py +154 -0
  39. medical_diffusion/models/noise_schedulers/scheduler_base.py +49 -0
  40. medical_diffusion/models/pipelines/__init__.py +1 -0
  41. medical_diffusion/models/pipelines/diffusion_pipeline.py +348 -0
  42. medical_diffusion/models/utils/__init__.py +2 -0
  43. medical_diffusion/models/utils/attention_blocks.py +335 -0
  44. medical_diffusion/models/utils/conv_blocks.py +528 -0
  45. medical_diffusion/utils/math_utils.py +6 -0
  46. medical_diffusion/utils/train_utils.py +88 -0
  47. requirements.txt +17 -0
  48. scripts/evaluate_images.py +129 -0
  49. scripts/evaluate_latent_embedder.py +98 -0
  50. scripts/helpers/dump_discrimnator.py +26 -0
.gitignore ADDED
@@ -0,0 +1,25 @@
+ /**/*
+ !/**/
+ /venv/
+ !*.ipynb
+ !*.gitignore
+ !*.md
+ !*.bat
+ !*.py
+ !*.yml
+ !*.ui
+ !*.yaml
+
+ !requirements.txt
+ !version.txt
+
+ /docs/build
+ !/docs/Makefile
+
+ /build/
+
+
+ /results
+ /scripts/local_trash
+
+ !/media/**
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Medfusion App
- emoji: 🏢
+ emoji: 🔬
  colorFrom: pink
  colorTo: gray
  sdk: streamlit
@@ -10,4 +10,64 @@ pinned: false
  license: mit
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Medfusion - Medical Denoising Diffusion Probabilistic Model
+ =============
+
+ Paper
+ =======
+ Please see: [**Diffusion Probabilistic Models beat GANs on Medical 2D Images**]()
+
+ ![](media/Medfusion.png)
+ *Figure: Medfusion*
+
+ ![](media/animation_eye.gif) ![](media/animation_histo.gif) ![](media/animation_chest.gif)\
+ *Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (warning: color quality is limited by the GIF format)*
+
+ Demo
+ =============
+ [Link]() to the Streamlit app.
+
+ Install
+ =============
+
+ Create a virtual environment and install the packages: \
+ `python -m venv venv` \
+ `source venv/bin/activate`\
+ `pip install -e .`
+
+
+ Get Started
+ =============
+
+ 1 Prepare Data
+ -------------
+
+ * Go to [medical_diffusion/data/datasets/dataset_simple_2d.py](medical_diffusion/data/datasets/dataset_simple_2d.py) and create a new `SimpleDataset2D` or write your own Dataset (a short usage sketch follows this diff).
+
+
+ 2 Train Autoencoder
+ ----------------
+ * Go to [scripts/train_latent_embedder_2d.py](scripts/train_latent_embedder_2d.py) and import your Dataset.
+ * Load your dataset with e.g. `SimpleDataModule`.
+ * Customize the `VAE` to your needs.
+ * (Optional) Train a `VAEGAN` instead, or load a pre-trained `VAE` and set `start_gan_train_step=-1` to start GAN training immediately.
+
+ 2.1 Evaluate Autoencoder
+ ----------------
+ * Use [scripts/evaluate_latent_embedder.py](scripts/evaluate_latent_embedder.py) to evaluate the performance of the autoencoder.
+
+ 3 Train Diffusion
+ ----------------
+ * Go to [scripts/train_diffusion.py](scripts/train_diffusion.py) and import/load your Dataset as before.
+ * Load your pre-trained VAE or VAEGAN with `latent_embedder_checkpoint=...`.
+ * Use `cond_embedder = LabelEmbedder` for conditional training, otherwise `cond_embedder = None`.
+
+ 3.1 Evaluate Diffusion
+ ----------------
+ * Go to [scripts/sample.py](scripts/sample.py) to sample a test image.
+ * Go to [scripts/helpers/sample_dataset.py](scripts/helpers/sample_dataset.py) to sample a more representative sample size.
+ * Use [scripts/evaluate_images.py](scripts/evaluate_images.py) to evaluate the performance of the samples (FID, Precision, Recall).
+
+ Acknowledgment
+ =============
+ * Code builds upon https://github.com/lucidrains/denoising-diffusion-pytorch
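
As a rough illustration of the "Get Started" steps above — a minimal sketch assuming the `SimpleDataset2D` and `SimpleDataModule` classes added in this commit; the path, extension, and hyperparameters are placeholders:

```python
from medical_diffusion.data.datasets import SimpleDataset2D
from medical_diffusion.data.datamodules import SimpleDataModule

# Step 1: point a SimpleDataset2D at a folder of images (placeholder path/extension).
ds = SimpleDataset2D(
    path_root='data/my_images',
    crawler_ext='png',
    image_resize=256,
)

# Step 2/3: wrap it in the datamodule the training scripts consume.
dm = SimpleDataModule(ds_train=ds, batch_size=8, num_workers=4)
# `dm` would then be handed to the PyTorch Lightning trainer inside
# scripts/train_latent_embedder_2d.py or scripts/train_diffusion.py.
```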
medical_diffusion/data/augmentation/__init__.py ADDED
File without changes
medical_diffusion/data/augmentation/augmentations_2d.py ADDED
@@ -0,0 +1,27 @@
+
+ import torch
+ import numpy as np
+
+ class ToTensor16bit(object):
+     """PyTorch cannot handle uint16, only int16, so the image is first converted to int32. Note: this transform also adds a channel dimension."""
+     def __call__(self, image):
+         # return torch.as_tensor(np.array(image, dtype=np.int32)[None])
+         # return torch.from_numpy(np.array(image, np.int32, copy=True)[None])
+         image = np.array(image, np.int32, copy=True)  # [H,W,C] or [H,W]
+         image = np.expand_dims(image, axis=-1) if image.ndim == 2 else image
+         return torch.from_numpy(np.moveaxis(image, -1, 0))  # [C, H, W]
+
+ class Normalize(object):
+     """Rescale the image to [0, 1] and ensure float32 dtype."""
+
+     def __call__(self, image):
+         image = image.type(torch.FloatTensor)
+         return (image-image.min())/(image.max()-image.min())
+
+
+ class RandomBackground(object):
+     """Fill the background (intensity == 0) with random values."""
+
+     def __call__(self, image):
+         image[image==0] = torch.rand(*image[image==0].shape)  # (image.max()-image.min())
+         return image
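
A minimal usage sketch for these transforms, assuming a 16-bit grayscale TIFF as input (the file path is a placeholder):

```python
from PIL import Image
from torchvision import transforms as T
from medical_diffusion.data.augmentation.augmentations_2d import ToTensor16bit, Normalize

# Chain the 16-bit-safe tensor conversion with the [0, 1] rescaling defined above.
transform = T.Compose([
    ToTensor16bit(),  # uint16 PIL image -> int32 tensor with a channel dim, [C, H, W]
    Normalize(),      # rescale to [0, 1] as float32
])

img = Image.open('example_16bit.tif')  # placeholder path
x = transform(img)
print(x.shape, x.dtype, float(x.min()), float(x.max()))
```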
medical_diffusion/data/augmentation/augmentations_3d.py ADDED
@@ -0,0 +1,38 @@
+ import torchio as tio
+ from typing import Union, Optional, Sequence
+ from torchio.typing import TypeTripletInt
+ from torchio import Subject, Image
+ from torchio.utils import to_tuple
+
+ class CropOrPad_None(tio.CropOrPad):
+     def __init__(
+         self,
+         target_shape: Union[int, TypeTripletInt, None] = None,
+         padding_mode: Union[str, float] = 0,
+         mask_name: Optional[str] = None,
+         labels: Optional[Sequence[int]] = None,
+         **kwargs
+     ):
+
+         # WARNING: Ugly workaround to allow None values
+         if target_shape is not None:
+             self.original_target_shape = to_tuple(target_shape, length=3)
+             target_shape = [1 if t_s is None else t_s for t_s in target_shape]
+         super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs)
+
+     def apply_transform(self, subject: Subject):
+         # WARNING: This makes the transformation subject dependent - the reverse transformation must be adapted
+         if self.target_shape is not None:
+             self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)]
+         return super().apply_transform(subject=subject)
+
+
+ class SubjectToTensor(object):
+     """Transforms a TorchIO Subject into a Python dict and changes the axis order from TorchIO to Torch"""
+     def __call__(self, subject: Subject):
+         return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()}
+
+ class ImageToTensor(object):
+     """Transforms a TorchIO Image into a Torch tensor and changes the axis order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W]"""
+     def __call__(self, image: Image):
+         return image.data.swapaxes(1,-1)
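
A short sketch of `ImageToTensor` applied to a TorchIO image (the volume path is a placeholder):

```python
import torchio as tio
from medical_diffusion.data.augmentation.augmentations_3d import ImageToTensor

img = tio.ScalarImage('example_volume.nii')  # placeholder path; data layout [C, W, H, D]
vol = ImageToTensor()(img)                   # swap the first and last spatial axes -> [C, D, H, W]
print(vol.shape)
```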
medical_diffusion/data/datamodules/__init__.py ADDED
@@ -0,0 +1 @@
+ from .datamodule_simple import SimpleDataModule
medical_diffusion/data/datamodules/datamodule_simple.py ADDED
@@ -0,0 +1,79 @@
+
+ import pytorch_lightning as pl
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ import torch.multiprocessing as mp
+ from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler
+
+
+
+ class SimpleDataModule(pl.LightningDataModule):
+
+     def __init__(self,
+                  ds_train: object,
+                  ds_val: object = None,
+                  ds_test: object = None,
+                  batch_size: int = 1,
+                  num_workers: int = mp.cpu_count(),
+                  seed: int = 0,
+                  pin_memory: bool = False,
+                  weights: list = None
+                  ):
+         super().__init__()
+         self.hyperparameters = {**locals()}
+         self.hyperparameters.pop('__class__')
+         self.hyperparameters.pop('self')
+
+         self.ds_train = ds_train
+         self.ds_val = ds_val
+         self.ds_test = ds_test
+
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.seed = seed
+         self.pin_memory = pin_memory
+         self.weights = weights
+
+
+
+     def train_dataloader(self):
+         generator = torch.Generator()
+         generator.manual_seed(self.seed)
+
+         if self.weights is not None:
+             sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator)
+         else:
+             sampler = RandomSampler(self.ds_train, replacement=False, generator=generator)
+         return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers,
+                           sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory)
+
+
+     def val_dataloader(self):
+         generator = torch.Generator()
+         generator.manual_seed(self.seed)
+         if self.ds_val is not None:
+             return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
+                               generator=generator, drop_last=False, pin_memory=self.pin_memory)
+         else:
+             raise AssertionError("A validation set was not initialized.")
+
+
+     def test_dataloader(self):
+         generator = torch.Generator()
+         generator.manual_seed(self.seed)
+         if self.ds_test is not None:
+             return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
+                               generator=generator, drop_last=False, pin_memory=self.pin_memory)
+         else:
+             raise AssertionError("A test set was not initialized.")
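
A minimal sketch of the weighted-sampling path of this datamodule; the `ToyDataset` below is hypothetical and only stands in for a real dataset exposing the repo's `{'source', 'target'}` item format:

```python
import torch
from torch.utils.data import Dataset
from medical_diffusion.data.datamodules import SimpleDataModule

class ToyDataset(Dataset):
    """Hypothetical in-memory dataset mimicking the {'source', 'target'} item format."""
    def __init__(self, n=100):
        self.data = torch.randn(n, 3, 64, 64)
        self.targets = torch.randint(0, 2, (n,))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return {'source': self.data[idx], 'target': int(self.targets[idx])}

ds = ToyDataset()
# Per-sample weights switch train_dataloader() to a WeightedRandomSampler.
weights = [1.0 if t == 0 else 5.0 for t in ds.targets.tolist()]
dm = SimpleDataModule(ds_train=ds, batch_size=8, num_workers=0, weights=weights)

batch = next(iter(dm.train_dataloader()))
print(batch['source'].shape, batch['target'].shape)
```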
medical_diffusion/data/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .dataset_simple_2d import *
+ from .dataset_simple_3d import *
medical_diffusion/data/datasets/dataset_simple_2d.py ADDED
@@ -0,0 +1,198 @@
+
+ import torch.utils.data as data
+ import torch
+ from torch import nn
+ from pathlib import Path
+ from torchvision import transforms as T
+ import pandas as pd
+
+ from PIL import Image
+
+ from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit
+
+ class SimpleDataset2D(data.Dataset):
+     def __init__(
+         self,
+         path_root,
+         item_pointers=[],
+         crawler_ext='tif',  # other options are ['jpg', 'jpeg', 'png', 'tiff'],
+         transform=None,
+         image_resize=None,
+         augment_horizontal_flip=False,
+         augment_vertical_flip=False,
+         image_crop=None,
+     ):
+         super().__init__()
+         self.path_root = Path(path_root)
+         self.crawler_ext = crawler_ext
+         if len(item_pointers):
+             self.item_pointers = item_pointers
+         else:
+             self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
+
+         if transform is None:
+             self.transform = T.Compose([
+                 T.Resize(image_resize) if image_resize is not None else nn.Identity(),
+                 T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
+                 T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
+                 T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
+                 T.ToTensor(),
+                 # T.Lambda(lambda x: torch.cat([x]*3) if x.shape[0]==1 else x),
+                 # ToTensor16bit(),
+                 # Normalize(), # [0, 1.0]
+                 # T.ConvertImageDtype(torch.float),
+                 T.Normalize(mean=0.5, std=0.5)  # WARNING: mean and std are not the target values but rather the values to subtract and divide by: [0, 1] -> [0-0.5, 1-0.5]/0.5 -> [-1, 1]
+             ])
+         else:
+             self.transform = transform
+
+     def __len__(self):
+         return len(self.item_pointers)
+
+     def __getitem__(self, index):
+         rel_path_item = self.item_pointers[index]
+         path_item = self.path_root/rel_path_item
+         # img = Image.open(path_item)
+         img = self.load_item(path_item)
+         return {'uid': rel_path_item.stem, 'source': self.transform(img)}
+
+     def load_item(self, path_item):
+         return Image.open(path_item).convert('RGB')
+         # return cv2.imread(str(path_item), cv2.IMREAD_UNCHANGED) # NOTE: Only CV2 supports 16bit RGB images
+
+     @classmethod
+     def run_item_crawler(cls, path_root, extension, **kwargs):
+         return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]
+
+     def get_weights(self):
+         """Return list of class-weights for WeightedSampling"""
+         return None
+
+
+ class AIROGSDataset(SimpleDataset2D):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.labels = pd.read_csv(self.path_root.parent/'train_labels.csv', index_col='challenge_id')
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, index):
+         uid = self.labels.index[index]
+         path_item = self.path_root/f'{uid}.jpg'
+         img = self.load_item(path_item)
+         str_2_int = {'NRG':0, 'RG':1}  # RG = 3270, NRG = 98172
+         target = str_2_int[self.labels.loc[uid, 'class']]
+         # return {'uid':uid, 'source': self.transform(img), 'target':target}
+         return {'source': self.transform(img), 'target': target}
+
+     def get_weights(self):
+         n_samples = len(self)
+         weight_per_class = 1/self.labels['class'].value_counts(normalize=True)  # {'NRG': 1.03, 'RG': 31.02}
+         weights = [0] * n_samples
+         for index in range(n_samples):
+             target = self.labels.iloc[index]['class']
+             weights[index] = weight_per_class[target]
+         return weights
+
+     @classmethod
+     def run_item_crawler(cls, path_root, extension, **kwargs):
+         """Overwrite to speed up as paths are determined by .csv file anyway"""
+         return []
+
+ class MSIvsMSS_Dataset(SimpleDataset2D):
+     # https://doi.org/10.5281/zenodo.2530835
+     def __getitem__(self, index):
+         rel_path_item = self.item_pointers[index]
+         path_item = self.path_root/rel_path_item
+         img = self.load_item(path_item)
+         uid = rel_path_item.stem
+         str_2_int = {'MSIMUT':0, 'MSS':1}
+         target = str_2_int[path_item.parent.name]
+         return {'uid': uid, 'source': self.transform(img), 'target': target}
+
+
+ class MSIvsMSS_2_Dataset(SimpleDataset2D):
+     # https://doi.org/10.5281/zenodo.3832231
+     def __getitem__(self, index):
+         rel_path_item = self.item_pointers[index]
+         path_item = self.path_root/rel_path_item
+         img = self.load_item(path_item)
+         uid = rel_path_item.stem
+         str_2_int = {'MSIH':0, 'nonMSIH':1}  # patients with MSI-H = MSIH; patients with MSI-L and MSS = NonMSIH)
+         target = str_2_int[path_item.parent.name]
+         # return {'uid':uid, 'source': self.transform(img), 'target':target}
+         return {'source': self.transform(img), 'target': target}
+
+
+ class CheXpert_Dataset(SimpleDataset2D):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         mode = self.path_root.name
+         labels = pd.read_csv(self.path_root.parent/f'{mode}.csv', index_col='Path')
+         self.labels = labels.loc[labels['Frontal/Lateral'] == 'Frontal'].copy()
+         self.labels.index = self.labels.index.str[20:]
+         self.labels.loc[self.labels['Sex'] == 'Unknown', 'Sex'] = 'Female'  # Affects 1 case, must be "female" to match stats in publication
+         self.labels.fillna(2, inplace=True)  # TODO: Find better solution
+         str_2_int = {'Sex': {'Male':0, 'Female':1}, 'Frontal/Lateral':{'Frontal':0, 'Lateral':1}, 'AP/PA':{'AP':0, 'PA':1}}
+         self.labels.replace(str_2_int, inplace=True)
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, index):
+         rel_path_item = self.labels.index[index]
+         path_item = self.path_root/rel_path_item
+         img = self.load_item(path_item)
+         uid = str(rel_path_item)
+         target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly']+1, dtype=torch.long)  # Note: labels are -1=uncertain, 0=negative, 1=positive, NA=not reported -> map to [0, 2], NA=3
+         return {'uid': uid, 'source': self.transform(img), 'target': target}
+
+
+     @classmethod
+     def run_item_crawler(cls, path_root, extension, **kwargs):
+         """Overwrite to speed up as paths are determined by .csv file anyway"""
+         return []
+
+ class CheXpert_2_Dataset(SimpleDataset2D):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         labels = pd.read_csv(self.path_root/'labels/cheXPert_label.csv', index_col=['Path', 'Image Index'])  # Note: 1 and -1 (uncertain) cases count as positives (1), 0 and NA count as negatives (0)
+         labels = labels.loc[labels['fold'] == 'train'].copy()
+         labels = labels.drop(labels='fold', axis=1)
+
+         labels2 = pd.read_csv(self.path_root/'labels/train.csv', index_col='Path')
+         labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
+         labels2 = labels2[['Cardiomegaly',]].copy()
+         labels2[(labels2 < 0) | labels2.isna()] = 2  # 0 = Negative, 1 = Positive, 2 = Uncertain
+         labels = labels.join(labels2['Cardiomegaly'], on=["Path",], rsuffix='_true')
+         # labels = labels[labels['Cardiomegaly_true']!=2]
+
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, index):
+         path_index, image_index = self.labels.index[index]
+         path_item = self.path_root/'data'/f'{image_index:06}.png'
+         img = self.load_item(path_item)
+         uid = image_index
+         target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly'])
+         # return {'uid':uid, 'source': self.transform(img), 'target':target}
+         return {'source': self.transform(img), 'target': target}
+
+     @classmethod
+     def run_item_crawler(cls, path_root, extension, **kwargs):
+         """Overwrite to speed up as paths are determined by .csv file anyway"""
+         return []
+
+     def get_weights(self):
+         n_samples = len(self)
+         weight_per_class = 1/self.labels['Cardiomegaly'].value_counts(normalize=True)
+         # weight_per_class = {2.0: 1.2, 1.0: 8.2, 0.0: 24.3}
+         weights = [0] * n_samples
+         for index in range(n_samples):
+             target = self.labels.loc[self.labels.index[index], 'Cardiomegaly']
+             weights[index] = weight_per_class[target]
+         return weights
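
As a sketch of how a new labeled dataset might be added by subclassing `SimpleDataset2D` (mirroring the `MSIvsMSS_Dataset` pattern above; the class-name mapping and folder layout are hypothetical):

```python
from medical_diffusion.data.datasets.dataset_simple_2d import SimpleDataset2D

class MyLabeledDataset(SimpleDataset2D):
    """Hypothetical dataset where the parent folder name encodes the class label."""
    def __getitem__(self, index):
        rel_path_item = self.item_pointers[index]
        path_item = self.path_root/rel_path_item
        img = self.load_item(path_item)            # PIL image, converted to RGB
        str_2_int = {'healthy': 0, 'diseased': 1}  # placeholder label mapping
        target = str_2_int[path_item.parent.name]
        return {'uid': rel_path_item.stem, 'source': self.transform(img), 'target': target}

# ds = MyLabeledDataset(path_root='data/my_dataset', crawler_ext='png', image_resize=256)
```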
medical_diffusion/data/datasets/dataset_simple_3d.py ADDED
@@ -0,0 +1,58 @@
+
+ import torch.utils.data as data
+ from pathlib import Path
+ from torchvision import transforms as T
+
+
+ import torchio as tio
+
+ from medical_diffusion.data.augmentation.augmentations_3d import ImageToTensor
+
+
+ class SimpleDataset3D(data.Dataset):
+     def __init__(
+         self,
+         path_root,
+         item_pointers=[],
+         crawler_ext=['nii'],  # other options are ['nii.gz'],
+         transform=None,
+         image_resize=None,
+         flip=False,
+         image_crop=None,
+         use_znorm=True,  # Use z-Norm for MRI as scale is arbitrary, otherwise scale intensity to [-1, 1]
+     ):
+         super().__init__()
+         self.path_root = path_root
+         self.crawler_ext = crawler_ext
+
+         if transform is None:
+             self.transform = T.Compose([
+                 tio.Resize(image_resize) if image_resize is not None else tio.Lambda(lambda x: x),
+                 tio.RandomFlip((0,1,2)) if flip else tio.Lambda(lambda x: x),
+                 tio.CropOrPad(image_crop) if image_crop is not None else tio.Lambda(lambda x: x),
+                 tio.ZNormalization() if use_znorm else tio.RescaleIntensity((-1,1)),
+                 ImageToTensor()  # [C, W, H, D] -> [C, D, H, W]
+             ])
+         else:
+             self.transform = transform
+
+         if len(item_pointers):
+             self.item_pointers = item_pointers
+         else:
+             self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
+
+     def __len__(self):
+         return len(self.item_pointers)
+
+     def __getitem__(self, index):
+         rel_path_item = self.item_pointers[index]
+         path_item = self.path_root/rel_path_item
+         img = self.load_item(path_item)
+         return {'uid': rel_path_item.stem, 'source': self.transform(img)}
+
+     def load_item(self, path_item):
+         return tio.ScalarImage(path_item)  # Consider to use this or tio.ScalarLabel over SimpleITK (sitk.ReadImage(str(path_item)))
+
+     @classmethod
+     def run_item_crawler(cls, path_root, extension, **kwargs):
+         return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]
medical_diffusion/external/diffusers/attention.py ADDED
@@ -0,0 +1,347 @@
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class AttentionBlock(nn.Module):
+     """
+     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+     to the N-d case.
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+     Uses three q, k, v linear layers to compute attention.
+
+     Parameters:
+         channels (:obj:`int`): The number of channels in the input and output.
+         num_head_channels (:obj:`int`, *optional*):
+             The number of channels in each head. If None, then `num_heads` = 1.
+         num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+         rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+         eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         num_head_channels: Optional[int] = None,
+         num_groups: int = 32,
+         rescale_output_factor: float = 1.0,
+         eps: float = 1e-5,
+     ):
+         super().__init__()
+         self.channels = channels
+
+         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+         self.num_head_size = num_head_channels
+         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+         # define q,k,v as linear layers
+         self.query = nn.Linear(channels, channels)
+         self.key = nn.Linear(channels, channels)
+         self.value = nn.Linear(channels, channels)
+
+         self.rescale_output_factor = rescale_output_factor
+         self.proj_attn = nn.Linear(channels, channels, 1)
+
+     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+         # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+         new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+         return new_projection
+
+     def forward(self, hidden_states):
+         residual = hidden_states
+         batch, channel, height, width = hidden_states.shape
+
+         # norm
+         hidden_states = self.group_norm(hidden_states)
+
+         hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+         # proj to q, k, v
+         query_proj = self.query(hidden_states)
+         key_proj = self.key(hidden_states)
+         value_proj = self.value(hidden_states)
+
+         # transpose
+         query_states = self.transpose_for_scores(query_proj)
+         key_states = self.transpose_for_scores(key_proj)
+         value_states = self.transpose_for_scores(value_proj)
+
+         # get scores
+         scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+         attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+
+         # compute attention output
+         hidden_states = torch.matmul(attention_probs, value_states)
+
+         hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+         new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+         hidden_states = hidden_states.view(new_hidden_states_shape)
+
+         # compute next hidden_states
+         hidden_states = self.proj_attn(hidden_states)
+         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+         # res connect and rescale
+         hidden_states = (hidden_states + residual) / self.rescale_output_factor
+         return hidden_states
+
+
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+     standard transformer action. Finally, reshape to image.
+
+     Parameters:
+         in_channels (:obj:`int`): The number of channels in the input and output.
+         n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+         d_head (:obj:`int`): The number of channels in each head.
+         depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+         dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
+         context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         n_heads: int,
+         d_head: int,
+         depth: int = 1,
+         dropout: float = 0.0,
+         num_groups: int = 32,
+         context_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.n_heads = n_heads
+         self.d_head = d_head
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+         self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                 for d in range(depth)
+             ]
+         )
+
+         self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def _set_attention_slice(self, slice_size):
+         for block in self.transformer_blocks:
+             block._set_attention_slice(slice_size)
+
+     def forward(self, hidden_states, context=None):
+         # note: if no context is given, cross-attention defaults to self-attention
+         batch, channel, height, weight = hidden_states.shape
+         residual = hidden_states
+         hidden_states = self.norm(hidden_states)
+         hidden_states = self.proj_in(hidden_states)
+         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+         for block in self.transformer_blocks:
+             hidden_states = block(hidden_states, context=context)
+         hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+         hidden_states = self.proj_out(hidden_states)
+         return hidden_states + residual
+
+
+ class BasicTransformerBlock(nn.Module):
+     r"""
+     A basic Transformer block.
+
+     Parameters:
+         dim (:obj:`int`): The number of channels in the input and output.
+         n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+         d_head (:obj:`int`): The number of channels in each head.
+         dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+         gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
+         checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         n_heads: int,
+         d_head: int,
+         dropout=0.0,
+         context_dim: Optional[int] = None,
+         gated_ff: bool = True,
+         checkpoint: bool = True,
+     ):
+         super().__init__()
+         self.attn1 = CrossAttention(
+             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+         )  # is a self-attention
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.attn2 = CrossAttention(
+             query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+         )  # is self-attn if context is none
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.norm3 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+
+     def _set_attention_slice(self, slice_size):
+         self.attn1._slice_size = slice_size
+         self.attn2._slice_size = slice_size
+
+     def forward(self, hidden_states, context=None):
+         hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+         return hidden_states
+
+
+ class CrossAttention(nn.Module):
+     r"""
+     A cross attention layer.
+
+     Parameters:
+         query_dim (:obj:`int`): The number of channels in the query.
+         context_dim (:obj:`int`, *optional*):
+             The number of channels in the context. If not given, defaults to `query_dim`.
+         heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+         dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+         dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+     """
+
+     def __init__(
+         self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
+     ):
+         super().__init__()
+         inner_dim = dim_head * heads
+         context_dim = context_dim if context_dim is not None else query_dim
+
+         self.scale = dim_head**-0.5
+         self.heads = heads
+         # for slice_size > 0 the attention score computation
+         # is split across the batch axis to save memory
+         # You can set slice_size with `set_attention_slice`
+         self._slice_size = None
+
+         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+     def reshape_heads_to_batch_dim(self, tensor):
+         batch_size, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+         return tensor
+
+     def reshape_batch_dim_to_heads(self, tensor):
+         batch_size, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+         return tensor
+
+     def forward(self, hidden_states, context=None, mask=None):
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         query = self.to_q(hidden_states)
+         context = context if context is not None else hidden_states
+         key = self.to_k(context)
+         value = self.to_v(context)
+
+         dim = query.shape[-1]
+
+         query = self.reshape_heads_to_batch_dim(query)
+         key = self.reshape_heads_to_batch_dim(key)
+         value = self.reshape_heads_to_batch_dim(value)
+
+         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+         # attention, what we cannot get enough of
+
+         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+             hidden_states = self._attention(query, key, value)
+         else:
+             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+
+         return self.to_out(hidden_states)
+
+     def _attention(self, query, key, value):
+         attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+         attention_probs = attention_scores.softmax(dim=-1)
+         # compute attention output
+         hidden_states = torch.matmul(attention_probs, value)
+         # reshape hidden_states
+         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+         return hidden_states
+
+     def _sliced_attention(self, query, key, value, sequence_length, dim):
+         batch_size_attention = query.shape[0]
+         hidden_states = torch.zeros(
+             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+         )
+         slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+         for i in range(hidden_states.shape[0] // slice_size):
+             start_idx = i * slice_size
+             end_idx = (i + 1) * slice_size
+             attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+             attn_slice = attn_slice.softmax(dim=-1)
+             attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+
+             hidden_states[start_idx:end_idx] = attn_slice
+
+         # reshape hidden_states
+         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+         return hidden_states
+
+
+ class FeedForward(nn.Module):
+     r"""
+     A feed-forward layer.
+
+     Parameters:
+         dim (:obj:`int`): The number of channels in the input.
+         dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+         mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+         glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+         dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+     """
+
+     def __init__(
+         self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+     ):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = dim_out if dim_out is not None else dim
+         project_in = GEGLU(dim, inner_dim)
+
+         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+     def forward(self, hidden_states):
+         return self.net(hidden_states)
+
+
+ # feedforward
+ class GEGLU(nn.Module):
+     r"""
+     A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+     Parameters:
+         dim_in (:obj:`int`): The number of channels in the input.
+         dim_out (:obj:`int`): The number of channels in the output.
+     """
+
+     def __init__(self, dim_in: int, dim_out: int):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, hidden_states):
+         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+         return hidden_states * F.gelu(gate)
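
A small shape smoke test for the two blocks above (channel and spatial sizes are chosen only for illustration):

```python
import torch
from medical_diffusion.external.diffusers.attention import AttentionBlock, SpatialTransformer

x = torch.randn(2, 64, 16, 16)  # [B, C, H, W]

attn = AttentionBlock(channels=64, num_head_channels=8, num_groups=32)
print(attn(x).shape)            # torch.Size([2, 64, 16, 16])

# Without a context tensor, the cross-attention layers fall back to self-attention.
transformer = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, depth=1)
print(transformer(x).shape)     # torch.Size([2, 64, 16, 16])
```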
medical_diffusion/external/diffusers/embeddings.py ADDED
@@ -0,0 +1,89 @@
+ import math
+ from pydoc import describe
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+
+ def get_timestep_embedding(
+     timesteps: torch.Tensor,
+     embedding_dim: int,
+     flip_sin_to_cos: bool = False,
+     downscale_freq_shift: float = 1,
+     scale: float = 1,
+     max_period: int = 10000,
+ ):
+     """
+     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+     embeddings. :return: an [N x dim] Tensor of positional embeddings.
+     """
+     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+     half_dim = embedding_dim // 2
+     exponent = -math.log(max_period) * torch.arange(
+         start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+     )
+     exponent = exponent / (half_dim - downscale_freq_shift)
+
+     emb = torch.exp(exponent)
+     emb = timesteps[:, None].float() * emb[None, :]
+
+     # scale embeddings
+     emb = scale * emb
+
+     # concat sine and cosine embeddings
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+     # flip sine and cosine embeddings
+     if flip_sin_to_cos:
+         emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+     # zero pad
+     if embedding_dim % 2 == 1:
+         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+     return emb
+
+ class Timesteps(nn.Module):
+     def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+         super().__init__()
+         self.num_channels = num_channels
+         self.flip_sin_to_cos = flip_sin_to_cos
+         self.downscale_freq_shift = downscale_freq_shift
+
+     def forward(self, timesteps):
+         t_emb = get_timestep_embedding(
+             timesteps,
+             self.num_channels,
+             flip_sin_to_cos=self.flip_sin_to_cos,
+             downscale_freq_shift=self.downscale_freq_shift,
+         )
+         return t_emb
+
+ class TimeEmbbeding(nn.Module):
+     def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+         super().__init__()
+
+         self.temb = Timesteps(channel, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+         self.linear_1 = nn.Linear(channel, time_embed_dim)
+         self.act = None
+         if act_fn == "silu":
+             self.act = nn.SiLU()
+         self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+     def forward(self, sample):
+         sample = self.temb(sample)
+         sample = self.linear_1(sample)
+
+         if self.act is not None:
+             sample = self.act(sample)
+
+         sample = self.linear_2(sample)
+         return sample
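
A quick sketch of what these embeddings produce (the dimensions are arbitrary):

```python
import torch
from medical_diffusion.external.diffusers.embeddings import get_timestep_embedding, TimeEmbbeding

t = torch.tensor([0, 10, 500, 999])                # one diffusion timestep per batch element

emb = get_timestep_embedding(t, embedding_dim=64)  # sinusoidal features
print(emb.shape)                                   # torch.Size([4, 64])

time_emb = TimeEmbbeding(channel=64, time_embed_dim=256)  # sinusoidal -> two-layer MLP
print(time_emb(t).shape)                           # torch.Size([4, 256])
```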
medical_diffusion/external/diffusers/resnet.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Upsample2D(nn.Module):
9
+ """
10
+ An upsampling layer with an optional convolution.
11
+
12
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
13
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
14
+ upsampling occurs in the inner-two dimensions.
15
+ """
16
+
17
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
18
+ super().__init__()
19
+ self.channels = channels
20
+ self.out_channels = out_channels or channels
21
+ self.use_conv = use_conv
22
+ self.use_conv_transpose = use_conv_transpose
23
+ self.name = name
24
+
25
+ conv = None
26
+ if use_conv_transpose:
27
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
28
+ elif use_conv:
29
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
30
+
31
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
32
+ if name == "conv":
33
+ self.conv = conv
34
+ else:
35
+ self.Conv2d_0 = conv
36
+
37
+ def forward(self, x):
38
+ assert x.shape[1] == self.channels
39
+ if self.use_conv_transpose:
40
+ return self.conv(x)
41
+
42
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
43
+
44
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
45
+ if self.use_conv:
46
+ if self.name == "conv":
47
+ x = self.conv(x)
48
+ else:
49
+ x = self.Conv2d_0(x)
50
+
51
+ return x
52
+
53
+
54
+ class Downsample2D(nn.Module):
55
+ """
56
+ A downsampling layer with an optional convolution.
57
+
58
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
59
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
60
+ downsampling occurs in the inner-two dimensions.
61
+ """
62
+
63
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
64
+ super().__init__()
65
+ self.channels = channels
66
+ self.out_channels = out_channels or channels
67
+ self.use_conv = use_conv
68
+ self.padding = padding
69
+ stride = 2
70
+ self.name = name
71
+
72
+ if use_conv:
73
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
74
+ else:
75
+ assert self.channels == self.out_channels
76
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
77
+
78
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
79
+ if name == "conv":
80
+ self.Conv2d_0 = conv
81
+ self.conv = conv
82
+ elif name == "Conv2d_0":
83
+ self.conv = conv
84
+ else:
85
+ self.conv = conv
86
+
87
+ def forward(self, x):
88
+ assert x.shape[1] == self.channels
89
+ if self.use_conv and self.padding == 0:
90
+ pad = (0, 1, 0, 1)
91
+ x = F.pad(x, pad, mode="constant", value=0)
92
+
93
+ assert x.shape[1] == self.channels
94
+ x = self.conv(x)
95
+
96
+ return x
97
+
98
+
99
+ class FirUpsample2D(nn.Module):
100
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
101
+ super().__init__()
102
+ out_channels = out_channels if out_channels else channels
103
+ if use_conv:
104
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
105
+ self.use_conv = use_conv
106
+ self.fir_kernel = fir_kernel
107
+ self.out_channels = out_channels
108
+
109
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
110
+ """Fused `upsample_2d()` followed by `Conv2d()`.
111
+
112
+ Args:
113
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
114
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
115
+ order.
116
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
117
+ C]`.
118
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
119
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
120
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
121
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
122
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
123
+
124
+ Returns:
125
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
126
+ `x`.
127
+ """
128
+
129
+ assert isinstance(factor, int) and factor >= 1
130
+
131
+ # Setup filter kernel.
132
+ if kernel is None:
133
+ kernel = [1] * factor
134
+
135
+ # setup kernel
136
+ kernel = torch.tensor(kernel, dtype=torch.float32)
137
+ if kernel.ndim == 1:
138
+ kernel = torch.outer(kernel, kernel)
139
+ kernel /= torch.sum(kernel)
140
+
141
+ kernel = kernel * (gain * (factor**2))
142
+
143
+ if self.use_conv:
144
+ convH = weight.shape[2]
145
+ convW = weight.shape[3]
146
+ inC = weight.shape[1]
147
+
148
+ p = (kernel.shape[0] - factor) - (convW - 1)
149
+
150
+ stride = (factor, factor)
151
+ # Determine data dimensions.
152
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
153
+ output_padding = (
154
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
155
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
156
+ )
157
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
158
+ inC = weight.shape[1]
159
+ num_groups = x.shape[1] // inC
160
+
161
+ # Transpose weights.
162
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
163
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
164
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
165
+
166
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
167
+
168
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
169
+ else:
170
+ p = kernel.shape[0] - factor
171
+ x = upfirdn2d_native(
172
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
173
+ )
174
+
175
+ return x
176
+
177
+ def forward(self, x):
178
+ if self.use_conv:
179
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
180
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
181
+ else:
182
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
183
+
184
+ return height
185
+
186
+
187
+ class FirDownsample2D(nn.Module):
188
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
189
+ super().__init__()
190
+ out_channels = out_channels if out_channels else channels
191
+ if use_conv:
192
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
193
+ self.fir_kernel = fir_kernel
194
+ self.use_conv = use_conv
195
+ self.out_channels = out_channels
196
+
197
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
198
+ """Fused `Conv2d()` followed by `downsample_2d()`.
199
+
200
+ Args:
201
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
202
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
203
+ order.
204
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
205
+ filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
206
+ numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
207
+ factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
208
+ Scaling factor for signal magnitude (default: 1.0).
209
+
210
+ Returns:
211
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
212
+ datatype as `x`.
213
+ """
214
+
215
+ assert isinstance(factor, int) and factor >= 1
216
+ if kernel is None:
217
+ kernel = [1] * factor
218
+
219
+ # setup kernel
220
+ kernel = torch.tensor(kernel, dtype=torch.float32)
221
+ if kernel.ndim == 1:
222
+ kernel = torch.outer(kernel, kernel)
223
+ kernel /= torch.sum(kernel)
224
+
225
+ kernel = kernel * gain
226
+
227
+ if self.use_conv:
228
+ _, _, convH, convW = weight.shape
229
+ p = (kernel.shape[0] - factor) + (convW - 1)
230
+ s = [factor, factor]
231
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
232
+ x = F.conv2d(x, weight, stride=s, padding=0)
233
+ else:
234
+ p = kernel.shape[0] - factor
235
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
236
+
237
+ return x
238
+
239
+ def forward(self, x):
240
+ if self.use_conv:
241
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
242
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
243
+ else:
244
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
245
+
246
+ return x
247
+
248
+
249
+ class ResnetBlock2D(nn.Module):
250
+ def __init__(
251
+ self,
252
+ *,
253
+ in_channels,
254
+ out_channels=None,
255
+ conv_shortcut=False,
256
+ dropout=0.0,
257
+ temb_channels=512,
258
+ groups=32,
259
+ groups_out=None,
260
+ pre_norm=True,
261
+ eps=1e-6,
262
+ non_linearity="swish",
263
+ time_embedding_norm="default",
264
+ kernel=None,
265
+ output_scale_factor=1.0,
266
+ use_in_shortcut=None,
267
+ up=False,
268
+ down=False,
269
+ ):
270
+ super().__init__()
271
+ self.pre_norm = pre_norm
272
+ self.pre_norm = True
273
+ self.in_channels = in_channels
274
+ out_channels = in_channels if out_channels is None else out_channels
275
+ self.out_channels = out_channels
276
+ self.use_conv_shortcut = conv_shortcut
277
+ self.time_embedding_norm = time_embedding_norm
278
+ self.up = up
279
+ self.down = down
280
+ self.output_scale_factor = output_scale_factor
281
+
282
+ if groups_out is None:
283
+ groups_out = groups
284
+
285
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
286
+
287
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
288
+
289
+ if temb_channels is not None:
290
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
291
+ else:
292
+ self.time_emb_proj = None
293
+
294
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
295
+ self.dropout = torch.nn.Dropout(dropout)
296
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
297
+
298
+ if non_linearity == "swish":
299
+ self.nonlinearity = lambda x: F.silu(x)
300
+ elif non_linearity == "mish":
301
+ self.nonlinearity = Mish()
302
+ elif non_linearity == "silu":
303
+ self.nonlinearity = nn.SiLU()
304
+
305
+ self.upsample = self.downsample = None
306
+ if self.up:
307
+ if kernel == "fir":
308
+ fir_kernel = (1, 3, 3, 1)
309
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
310
+ elif kernel == "sde_vp":
311
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
312
+ else:
313
+ self.upsample = Upsample2D(in_channels, use_conv=False)
314
+ elif self.down:
315
+ if kernel == "fir":
316
+ fir_kernel = (1, 3, 3, 1)
317
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
318
+ elif kernel == "sde_vp":
319
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
320
+ else:
321
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
322
+
323
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
324
+
325
+ self.conv_shortcut = None
326
+ if self.use_in_shortcut:
327
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
328
+
329
+ def forward(self, x, temb):
330
+ hidden_states = x
331
+
332
+ # make sure hidden states is in float32
333
+ # when running in half-precision
334
+ hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
335
+ hidden_states = self.nonlinearity(hidden_states)
336
+
337
+ if self.upsample is not None:
338
+ x = self.upsample(x)
339
+ hidden_states = self.upsample(hidden_states)
340
+ elif self.downsample is not None:
341
+ x = self.downsample(x)
342
+ hidden_states = self.downsample(hidden_states)
343
+
344
+ hidden_states = self.conv1(hidden_states)
345
+
346
+ if temb is not None:
347
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
348
+ hidden_states = hidden_states + temb
349
+
350
+ # make sure hidden states is in float32
351
+ # when running in half-precision
352
+ hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
353
+ hidden_states = self.nonlinearity(hidden_states)
354
+
355
+ hidden_states = self.dropout(hidden_states)
356
+ hidden_states = self.conv2(hidden_states)
357
+
358
+ if self.conv_shortcut is not None:
359
+ x = self.conv_shortcut(x)
360
+
361
+ out = (x + hidden_states) / self.output_scale_factor
362
+
363
+ return out
364
+
365
+
366
+ class Mish(torch.nn.Module):
367
+ def forward(self, x):
368
+ return x * torch.tanh(torch.nn.functional.softplus(x))
369
+
370
+
371
+ def upsample_2d(x, kernel=None, factor=2, gain=1):
372
+ r"""Upsample2D a batch of 2D images with the given filter.
373
+
374
+ Args:
375
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
376
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
377
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
378
+ multiple of the upsampling factor.
379
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
380
+ C]`.
381
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
382
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
383
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
384
+
385
+ Returns:
386
+ Tensor of the shape `[N, C, H * factor, W * factor]`
387
+ """
388
+ assert isinstance(factor, int) and factor >= 1
389
+ if kernel is None:
390
+ kernel = [1] * factor
391
+
392
+ kernel = torch.tensor(kernel, dtype=torch.float32)
393
+ if kernel.ndim == 1:
394
+ kernel = torch.outer(kernel, kernel)
395
+ kernel /= torch.sum(kernel)
396
+
397
+ kernel = kernel * (gain * (factor**2))
398
+ p = kernel.shape[0] - factor
399
+ return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
400
+
401
+
402
+ def downsample_2d(x, kernel=None, factor=2, gain=1):
403
+ r"""Downsample2D a batch of 2D images with the given filter.
404
+
405
+ Args:
406
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
407
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
408
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
409
+ shape is a multiple of the downsampling factor.
410
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
411
+ C]`.
412
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
413
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
414
+ factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
415
+
416
+ Returns:
417
+ Tensor of the shape `[N, C, H // factor, W // factor]`
418
+ """
419
+
420
+ assert isinstance(factor, int) and factor >= 1
421
+ if kernel is None:
422
+ kernel = [1] * factor
423
+
424
+ kernel = torch.tensor(kernel, dtype=torch.float32)
425
+ if kernel.ndim == 1:
426
+ kernel = torch.outer(kernel, kernel)
427
+ kernel /= torch.sum(kernel)
428
+
429
+ kernel = kernel * gain
430
+ p = kernel.shape[0] - factor
431
+ return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
432
+
433
+
434
+ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
435
+ up_x = up_y = up
436
+ down_x = down_y = down
437
+ pad_x0 = pad_y0 = pad[0]
438
+ pad_x1 = pad_y1 = pad[1]
439
+
440
+ _, channel, in_h, in_w = input.shape
441
+ input = input.reshape(-1, in_h, in_w, 1)
442
+
443
+ _, in_h, in_w, minor = input.shape
444
+ kernel_h, kernel_w = kernel.shape
445
+
446
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
447
+
448
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
449
+ if input.device.type == "mps":
450
+ out = out.to("cpu")
451
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
452
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
453
+
454
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
455
+ out = out.to(input.device) # Move back to mps if necessary
456
+ out = out[
457
+ :,
458
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
459
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
460
+ :,
461
+ ]
462
+
463
+ out = out.permute(0, 3, 1, 2)
464
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
465
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
466
+ out = F.conv2d(out, w)
467
+ out = out.reshape(
468
+ -1,
469
+ minor,
470
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
471
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
472
+ )
473
+ out = out.permute(0, 2, 3, 1)
474
+ out = out[:, ::down_y, ::down_x, :]
475
+
476
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
477
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
478
+
479
+ return out.view(-1, channel, out_h, out_w)
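The FIR helpers above are self-contained, so a quick sanity check is possible once the package is importable. A minimal sketch (the import path is assumed from this repository's layout; the default `kernel=None` box filter reduces to nearest-neighbor upsampling and average-pool downsampling):

    import torch
    from medical_diffusion.external.diffusers.resnet import upsample_2d, downsample_2d

    x = torch.randn(1, 3, 16, 16)      # [N, C, H, W]
    up = upsample_2d(x, factor=2)      # -> [1, 3, 32, 32]
    down = downsample_2d(x, factor=2)  # -> [1, 3, 8, 8]
    assert up.shape == (1, 3, 32, 32) and down.shape == (1, 3, 8, 8)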
medical_diffusion/external/diffusers/taming_discriminator.py ADDED
@@ -0,0 +1,57 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+
6
+
7
+ class NLayerDiscriminator(nn.Module):
8
+ """Defines a PatchGAN discriminator as in Pix2Pix
9
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
10
+ """
11
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
12
+ """Construct a PatchGAN discriminator
13
+ Parameters:
14
+ input_nc (int) -- the number of channels in input images
15
+ ndf (int) -- the number of filters in the first conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ use_actnorm (bool) -- if True, use ActNorm instead of BatchNorm (not implemented here)
18
+ """
19
+ super(NLayerDiscriminator, self).__init__()
20
+ if not use_actnorm:
21
+ norm_layer = nn.BatchNorm2d
22
+ else:
23
+ raise NotImplementedError
24
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
25
+ use_bias = norm_layer.func != nn.BatchNorm2d
26
+ else:
27
+ use_bias = norm_layer != nn.BatchNorm2d
28
+
29
+ kw = 4
30
+ padw = 1
31
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
32
+ nf_mult = 1
33
+ nf_mult_prev = 1
34
+ for n in range(1, n_layers): # gradually increase the number of filters
35
+ nf_mult_prev = nf_mult
36
+ nf_mult = min(2 ** n, 8)
37
+ sequence += [
38
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
39
+ norm_layer(ndf * nf_mult),
40
+ nn.LeakyReLU(0.2, True)
41
+ ]
42
+
43
+ nf_mult_prev = nf_mult
44
+ nf_mult = min(2 ** n_layers, 8)
45
+ sequence += [
46
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
47
+ norm_layer(ndf * nf_mult),
48
+ nn.LeakyReLU(0.2, True)
49
+ ]
50
+
51
+ sequence += [
52
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
53
+ self.main = nn.Sequential(*sequence)
54
+
55
+ def forward(self, input):
56
+ """Standard forward."""
57
+ return self.main(input)
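As a usage sketch (import path assumed from the repository layout): with the default settings the PatchGAN head maps an RGB image to a one-channel logit map, one logit per receptive-field patch.

    import torch
    from medical_diffusion.external.diffusers.taming_discriminator import NLayerDiscriminator

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    logits = disc(torch.randn(2, 3, 256, 256))
    print(logits.shape)  # torch.Size([2, 1, 30, 30]) for 256x256 inputs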
medical_diffusion/external/diffusers/unet.py ADDED
@@ -0,0 +1,257 @@
1
+
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+
9
+
10
+ from .embeddings import TimeEmbbeding
11
+
12
+ from .unet_blocks import (
13
+ CrossAttnDownBlock2D,
14
+ CrossAttnUpBlock2D,
15
+ DownBlock2D,
16
+ UNetMidBlock2DCrossAttn,
17
+ UpBlock2D,
18
+ get_down_block,
19
+ get_up_block,
20
+ )
21
+
22
+ class TimestepEmbedding(nn.Module):
23
+ def __init__(self, channel, time_embed_dim, act_fn="silu"):
24
+ super().__init__()
25
+
26
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
27
+ self.act = None
28
+ if act_fn == "silu":
29
+ self.act = nn.SiLU()
30
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
31
+
32
+ def forward(self, sample):
33
+ sample = self.linear_1(sample)
34
+
35
+ if self.act is not None:
36
+ sample = self.act(sample)
37
+
38
+ sample = self.linear_2(sample)
39
+ return sample
40
+
41
+
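+ # Usage sketch (illustrative only): this class maps an already-computed
+ # (batch, channel) sinusoidal embedding to (batch, time_embed_dim), e.g.
+ #   emb = TimestepEmbedding(channel=320, time_embed_dim=1280)
+ #   emb(torch.randn(4, 320)).shape   # torch.Size([4, 1280])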
42
+ class UNet2DConditionModel(nn.Module):
43
+ r"""
44
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
45
+ and returns a sample-shaped output.
46
+
47
+
48
+ Parameters:
49
+ sample_size (`int`, *optional*): The size of the input sample.
50
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
51
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
52
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
53
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
54
+ Whether to flip the sin to cos in the time embedding.
55
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
56
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
57
+ The tuple of downsample blocks to use.
58
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
59
+ The tuple of upsample blocks to use.
60
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
61
+ The tuple of output channels for each block.
62
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
63
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
64
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
65
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
66
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
67
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
68
+ cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
69
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
70
+ """
71
+
72
+ _supports_gradient_checkpointing = True
73
+
74
+
75
+ def __init__(
76
+ self,
77
+ sample_size: Optional[int] = None,
78
+ in_channels: int = 4,
79
+ out_channels: int = 4,
80
+ center_input_sample: bool = False,
81
+ flip_sin_to_cos: bool = True,
82
+ freq_shift: int = 0,
83
+ down_block_types: Tuple[str] = (
84
+ "CrossAttnDownBlock2D",
85
+ "CrossAttnDownBlock2D",
86
+ "CrossAttnDownBlock2D",
87
+ "DownBlock2D",
88
+ ),
89
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
90
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
91
+ layers_per_block: int = 2,
92
+ downsample_padding: int = 1,
93
+ mid_block_scale_factor: float = 1,
94
+ act_fn: str = "silu",
95
+ norm_num_groups: int = 32,
96
+ norm_eps: float = 1e-5,
97
+ cross_attention_dim: int = 768,
98
+ attention_head_dim: int = 8,
99
+ ):
100
+ super().__init__()
101
+
102
+ self.sample_size = sample_size
103
+ time_embed_dim = block_out_channels[0] * 4
104
+
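+ # hard-coded binary class conditioning: two integer labels are embedded to the
+ # cross-attention dimension and used as context in forward()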
105
+ self.emb = nn.Embedding(2, cross_attention_dim)
106
+
107
+ # input
108
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
109
+
110
+ # time
111
+ self.time_embedding = TimeEmbbeding(block_out_channels[0], time_embed_dim)
112
+
113
+ self.down_blocks = nn.ModuleList([])
114
+ self.mid_block = None
115
+ self.up_blocks = nn.ModuleList([])
116
+
117
+ # down
118
+ output_channel = block_out_channels[0]
119
+ for i, down_block_type in enumerate(down_block_types):
120
+ input_channel = output_channel
121
+ output_channel = block_out_channels[i]
122
+ is_final_block = i == len(block_out_channels) - 1
123
+
124
+ down_block = get_down_block(
125
+ down_block_type,
126
+ num_layers=layers_per_block,
127
+ in_channels=input_channel,
128
+ out_channels=output_channel,
129
+ temb_channels=time_embed_dim,
130
+ add_downsample=not is_final_block,
131
+ resnet_eps=norm_eps,
132
+ resnet_act_fn=act_fn,
133
+ resnet_groups=norm_num_groups,
134
+ cross_attention_dim=cross_attention_dim,
135
+ attn_num_head_channels=attention_head_dim,
136
+ downsample_padding=downsample_padding,
137
+ )
138
+ self.down_blocks.append(down_block)
139
+
140
+ # mid
141
+ self.mid_block = UNetMidBlock2DCrossAttn(
142
+ in_channels=block_out_channels[-1],
143
+ temb_channels=time_embed_dim,
144
+ resnet_eps=norm_eps,
145
+ resnet_act_fn=act_fn,
146
+ output_scale_factor=mid_block_scale_factor,
147
+ resnet_time_scale_shift="default",
148
+ cross_attention_dim=cross_attention_dim,
149
+ attn_num_head_channels=attention_head_dim,
150
+ resnet_groups=norm_num_groups,
151
+ )
152
+
153
+ # up
154
+ reversed_block_out_channels = list(reversed(block_out_channels))
155
+ output_channel = reversed_block_out_channels[0]
156
+ for i, up_block_type in enumerate(up_block_types):
157
+ prev_output_channel = output_channel
158
+ output_channel = reversed_block_out_channels[i]
159
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
160
+
161
+ is_final_block = i == len(block_out_channels) - 1
162
+
163
+ up_block = get_up_block(
164
+ up_block_type,
165
+ num_layers=layers_per_block + 1,
166
+ in_channels=input_channel,
167
+ out_channels=output_channel,
168
+ prev_output_channel=prev_output_channel,
169
+ temb_channels=time_embed_dim,
170
+ add_upsample=not is_final_block,
171
+ resnet_eps=norm_eps,
172
+ resnet_act_fn=act_fn,
173
+ resnet_groups=norm_num_groups,
174
+ cross_attention_dim=cross_attention_dim,
175
+ attn_num_head_channels=attention_head_dim,
176
+ )
177
+ self.up_blocks.append(up_block)
178
+ prev_output_channel = output_channel
179
+
180
+ # out
181
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
182
+ self.conv_act = nn.SiLU()
183
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
184
+
185
+
186
+
187
+ def forward(
188
+ self,
189
+ sample: torch.FloatTensor,
190
+ t: torch.Tensor,
191
+ encoder_hidden_states: torch.Tensor = None,
192
+ self_cond: torch.Tensor = None
193
+ ):
194
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy input tensor
+ t (`torch.Tensor`, `float` or `int`): (batch,) timesteps
+ encoder_hidden_states (`torch.LongTensor`): (batch,) integer condition labels, embedded via `self.emb`
+ self_cond (`torch.Tensor`, *optional*): unused, kept for interface compatibility
+
+ Returns:
+ `tuple`: the predicted sample tensor and an (empty) list of auxiliary outputs.
+ """
+ encoder_hidden_states = self.emb(encoder_hidden_states)
+ # encoder_hidden_states = None # ------------------------ WARNING Disabled ---------------------
207
+ # 0. center input if necessary
208
+ # if self.config.center_input_sample:
209
+ # sample = 2 * sample - 1.0
210
+
211
+ # 1. time
212
+ t_emb = self.time_embedding(t)
213
+
214
+ # 2. pre-process
215
+ sample = self.conv_in(sample)
216
+
217
+ # 3. down
218
+ down_block_res_samples = (sample,)
219
+ for downsample_block in self.down_blocks:
220
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
221
+ sample, res_samples = downsample_block(
222
+ hidden_states=sample,
223
+ temb=t_emb,
224
+ encoder_hidden_states=encoder_hidden_states,
225
+ )
226
+ else:
227
+ sample, res_samples = downsample_block(hidden_states=sample, temb=t_emb)
228
+
229
+ down_block_res_samples += res_samples
230
+
231
+ # 4. mid
232
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states=encoder_hidden_states)
233
+
234
+ # 5. up
235
+ for upsample_block in self.up_blocks:
236
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
237
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
238
+
239
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
240
+ sample = upsample_block(
241
+ hidden_states=sample,
242
+ temb=t_emb,
243
+ res_hidden_states_tuple=res_samples,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ )
246
+ else:
247
+ sample = upsample_block(hidden_states=sample, temb=t_emb, res_hidden_states_tuple=res_samples)
248
+
249
+ # 6. post-process
250
+ # make sure hidden states is in float32
251
+ # when running in half-precision
252
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
253
+ sample = self.conv_act(sample)
254
+ sample = self.conv_out(sample)
255
+
256
+
257
+ return sample, []
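The down/up skip handling above follows a simple stack discipline: each down block appends its residuals to `down_block_res_samples`, and each up block pops exactly `len(upsample_block.resnets)` of them from the end. A minimal, self-contained sketch of that bookkeeping (names and block sizes here are illustrative, not taken from the repository):

    res_stack = ("d0", "d1", "d2", "d3", "d4")   # residuals pushed during the down pass
    for n_resnets in (2, 2, 1):                  # hypothetical up blocks
        res_samples = res_stack[-n_resnets:]     # residuals consumed by this up block
        res_stack = res_stack[:-n_resnets]
        print(res_samples)
    # ('d3', 'd4') -> ('d1', 'd2') -> ('d0',)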
medical_diffusion/external/diffusers/unet_blocks.py ADDED
@@ -0,0 +1,1557 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
+
+ import numpy as np
+ import torch
18
+ from torch import nn
19
+
20
+ from .attention import AttentionBlock, SpatialTransformer
21
+ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
22
+
23
+
24
+ def get_down_block(
25
+ down_block_type,
26
+ num_layers,
27
+ in_channels,
28
+ out_channels,
29
+ temb_channels,
30
+ add_downsample,
31
+ resnet_eps,
32
+ resnet_act_fn,
33
+ attn_num_head_channels,
34
+ resnet_groups=None,
35
+ cross_attention_dim=None,
36
+ downsample_padding=None,
37
+ ):
38
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
39
+ if down_block_type == "DownBlock2D":
40
+ return DownBlock2D(
41
+ num_layers=num_layers,
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ temb_channels=temb_channels,
45
+ add_downsample=add_downsample,
46
+ resnet_eps=resnet_eps,
47
+ resnet_act_fn=resnet_act_fn,
48
+ resnet_groups=resnet_groups,
49
+ downsample_padding=downsample_padding,
50
+ )
51
+ elif down_block_type == "AttnDownBlock2D":
52
+ return AttnDownBlock2D(
53
+ num_layers=num_layers,
54
+ in_channels=in_channels,
55
+ out_channels=out_channels,
56
+ temb_channels=temb_channels,
57
+ add_downsample=add_downsample,
58
+ resnet_eps=resnet_eps,
59
+ resnet_act_fn=resnet_act_fn,
60
+ resnet_groups=resnet_groups,
61
+ downsample_padding=downsample_padding,
62
+ attn_num_head_channels=attn_num_head_channels,
63
+ )
64
+ elif down_block_type == "CrossAttnDownBlock2D":
65
+ if cross_attention_dim is None:
66
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
67
+ return CrossAttnDownBlock2D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ )
80
+ elif down_block_type == "SkipDownBlock2D":
81
+ return SkipDownBlock2D(
82
+ num_layers=num_layers,
83
+ in_channels=in_channels,
84
+ out_channels=out_channels,
85
+ temb_channels=temb_channels,
86
+ add_downsample=add_downsample,
87
+ resnet_eps=resnet_eps,
88
+ resnet_act_fn=resnet_act_fn,
89
+ downsample_padding=downsample_padding,
90
+ )
91
+ elif down_block_type == "AttnSkipDownBlock2D":
92
+ return AttnSkipDownBlock2D(
93
+ num_layers=num_layers,
94
+ in_channels=in_channels,
95
+ out_channels=out_channels,
96
+ temb_channels=temb_channels,
97
+ add_downsample=add_downsample,
98
+ resnet_eps=resnet_eps,
99
+ resnet_act_fn=resnet_act_fn,
100
+ downsample_padding=downsample_padding,
101
+ attn_num_head_channels=attn_num_head_channels,
102
+ )
103
+ elif down_block_type == "DownEncoderBlock2D":
104
+ return DownEncoderBlock2D(
105
+ num_layers=num_layers,
106
+ in_channels=in_channels,
107
+ out_channels=out_channels,
108
+ add_downsample=add_downsample,
109
+ resnet_eps=resnet_eps,
110
+ resnet_act_fn=resnet_act_fn,
111
+ resnet_groups=resnet_groups,
112
+ downsample_padding=downsample_padding,
113
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
114
+
115
+
116
+ def get_up_block(
117
+ up_block_type,
118
+ num_layers,
119
+ in_channels,
120
+ out_channels,
121
+ prev_output_channel,
122
+ temb_channels,
123
+ add_upsample,
124
+ resnet_eps,
125
+ resnet_act_fn,
126
+ attn_num_head_channels,
127
+ resnet_groups=None,
128
+ cross_attention_dim=None,
129
+ ):
130
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
131
+ if up_block_type == "UpBlock2D":
132
+ return UpBlock2D(
133
+ num_layers=num_layers,
134
+ in_channels=in_channels,
135
+ out_channels=out_channels,
136
+ prev_output_channel=prev_output_channel,
137
+ temb_channels=temb_channels,
138
+ add_upsample=add_upsample,
139
+ resnet_eps=resnet_eps,
140
+ resnet_act_fn=resnet_act_fn,
141
+ resnet_groups=resnet_groups,
142
+ )
143
+ elif up_block_type == "CrossAttnUpBlock2D":
144
+ if cross_attention_dim is None:
145
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
146
+ return CrossAttnUpBlock2D(
147
+ num_layers=num_layers,
148
+ in_channels=in_channels,
149
+ out_channels=out_channels,
150
+ prev_output_channel=prev_output_channel,
151
+ temb_channels=temb_channels,
152
+ add_upsample=add_upsample,
153
+ resnet_eps=resnet_eps,
154
+ resnet_act_fn=resnet_act_fn,
155
+ resnet_groups=resnet_groups,
156
+ cross_attention_dim=cross_attention_dim,
157
+ attn_num_head_channels=attn_num_head_channels,
158
+ )
159
+ elif up_block_type == "AttnUpBlock2D":
160
+ return AttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ in_channels=in_channels,
163
+ out_channels=out_channels,
164
+ prev_output_channel=prev_output_channel,
165
+ temb_channels=temb_channels,
166
+ add_upsample=add_upsample,
167
+ resnet_eps=resnet_eps,
168
+ resnet_act_fn=resnet_act_fn,
169
+ resnet_groups=resnet_groups,
170
+ attn_num_head_channels=attn_num_head_channels,
171
+ )
172
+ elif up_block_type == "SkipUpBlock2D":
173
+ return SkipUpBlock2D(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ prev_output_channel=prev_output_channel,
178
+ temb_channels=temb_channels,
179
+ add_upsample=add_upsample,
180
+ resnet_eps=resnet_eps,
181
+ resnet_act_fn=resnet_act_fn,
182
+ )
183
+ elif up_block_type == "AttnSkipUpBlock2D":
184
+ return AttnSkipUpBlock2D(
185
+ num_layers=num_layers,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ attn_num_head_channels=attn_num_head_channels,
194
+ )
195
+ elif up_block_type == "UpDecoderBlock2D":
196
+ return UpDecoderBlock2D(
197
+ num_layers=num_layers,
198
+ in_channels=in_channels,
199
+ out_channels=out_channels,
200
+ add_upsample=add_upsample,
201
+ resnet_eps=resnet_eps,
202
+ resnet_act_fn=resnet_act_fn,
203
+ resnet_groups=resnet_groups,
204
+ )
205
+ raise ValueError(f"{up_block_type} does not exist.")
206
+
207
+
208
+ class UNetMidBlock2D(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels: int,
212
+ temb_channels: int,
213
+ dropout: float = 0.0,
214
+ num_layers: int = 1,
215
+ resnet_eps: float = 1e-6,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_act_fn: str = "swish",
218
+ resnet_groups: int = 32,
219
+ resnet_pre_norm: bool = True,
220
+ attn_num_head_channels=1,
221
+ attention_type="default",
222
+ output_scale_factor=1.0,
223
+ **kwargs,
224
+ ):
225
+ super().__init__()
226
+
227
+ self.attention_type = attention_type
228
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
229
+
230
+ # there is always at least one resnet
231
+ resnets = [
232
+ ResnetBlock2D(
233
+ in_channels=in_channels,
234
+ out_channels=in_channels,
235
+ temb_channels=temb_channels,
236
+ eps=resnet_eps,
237
+ groups=resnet_groups,
238
+ dropout=dropout,
239
+ time_embedding_norm=resnet_time_scale_shift,
240
+ non_linearity=resnet_act_fn,
241
+ output_scale_factor=output_scale_factor,
242
+ pre_norm=resnet_pre_norm,
243
+ )
244
+ ]
245
+ attentions = []
246
+
247
+ for _ in range(num_layers):
248
+ attentions.append(
249
+ AttentionBlock(
250
+ in_channels,
251
+ num_head_channels=attn_num_head_channels,
252
+ rescale_output_factor=output_scale_factor,
253
+ eps=resnet_eps,
254
+ num_groups=resnet_groups,
255
+ )
256
+ )
257
+ resnets.append(
258
+ ResnetBlock2D(
259
+ in_channels=in_channels,
260
+ out_channels=in_channels,
261
+ temb_channels=temb_channels,
262
+ eps=resnet_eps,
263
+ groups=resnet_groups,
264
+ dropout=dropout,
265
+ time_embedding_norm=resnet_time_scale_shift,
266
+ non_linearity=resnet_act_fn,
267
+ output_scale_factor=output_scale_factor,
268
+ pre_norm=resnet_pre_norm,
269
+ )
270
+ )
271
+
272
+ self.attentions = nn.ModuleList(attentions)
273
+ self.resnets = nn.ModuleList(resnets)
274
+
275
+ def forward(self, hidden_states, temb=None, encoder_states=None):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
278
+ if self.attention_type == "default":
279
+ hidden_states = attn(hidden_states)
280
+ else:
281
+ hidden_states = attn(hidden_states, encoder_states)
282
+ hidden_states = resnet(hidden_states, temb)
283
+
284
+ return hidden_states
285
+
286
+
287
+ class UNetMidBlock2DCrossAttn(nn.Module):
288
+ def __init__(
289
+ self,
290
+ in_channels: int,
291
+ temb_channels: int,
292
+ dropout: float = 0.0,
293
+ num_layers: int = 1,
294
+ resnet_eps: float = 1e-6,
295
+ resnet_time_scale_shift: str = "default",
296
+ resnet_act_fn: str = "swish",
297
+ resnet_groups: int = 32,
298
+ resnet_pre_norm: bool = True,
299
+ attn_num_head_channels=1,
300
+ attention_type="default",
301
+ output_scale_factor=1.0,
302
+ cross_attention_dim=1280,
303
+ **kwargs,
304
+ ):
305
+ super().__init__()
306
+
307
+ self.attention_type = attention_type
308
+ self.attn_num_head_channels = attn_num_head_channels
309
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
310
+
311
+ # there is always at least one resnet
312
+ resnets = [
313
+ ResnetBlock2D(
314
+ in_channels=in_channels,
315
+ out_channels=in_channels,
316
+ temb_channels=temb_channels,
317
+ eps=resnet_eps,
318
+ groups=resnet_groups,
319
+ dropout=dropout,
320
+ time_embedding_norm=resnet_time_scale_shift,
321
+ non_linearity=resnet_act_fn,
322
+ output_scale_factor=output_scale_factor,
323
+ pre_norm=resnet_pre_norm,
324
+ )
325
+ ]
326
+ attentions = []
327
+
328
+ for _ in range(num_layers):
329
+ attentions.append(
330
+ SpatialTransformer(
331
+ in_channels,
332
+ attn_num_head_channels,
333
+ in_channels // attn_num_head_channels,
334
+ depth=1,
335
+ context_dim=cross_attention_dim,
336
+ num_groups=resnet_groups,
337
+ )
338
+ )
339
+ resnets.append(
340
+ ResnetBlock2D(
341
+ in_channels=in_channels,
342
+ out_channels=in_channels,
343
+ temb_channels=temb_channels,
344
+ eps=resnet_eps,
345
+ groups=resnet_groups,
346
+ dropout=dropout,
347
+ time_embedding_norm=resnet_time_scale_shift,
348
+ non_linearity=resnet_act_fn,
349
+ output_scale_factor=output_scale_factor,
350
+ pre_norm=resnet_pre_norm,
351
+ )
352
+ )
353
+
354
+ self.attentions = nn.ModuleList(attentions)
355
+ self.resnets = nn.ModuleList(resnets)
356
+
357
+ def set_attention_slice(self, slice_size):
358
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
359
+ raise ValueError(
360
+ f"Make sure slice_size {slice_size} is a divisor of "
361
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
362
+ )
363
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
364
+ raise ValueError(
365
+ f"Chunk_size {slice_size} has to be smaller or equal to "
366
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
367
+ )
368
+
369
+ for attn in self.attentions:
370
+ attn._set_attention_slice(slice_size)
371
+
372
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
373
+ hidden_states = self.resnets[0](hidden_states, temb)
374
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
375
+ hidden_states = attn(hidden_states, encoder_hidden_states)
376
+ hidden_states = resnet(hidden_states, temb)
377
+
378
+ return hidden_states
379
+
380
+
381
+ class AttnDownBlock2D(nn.Module):
382
+ def __init__(
383
+ self,
384
+ in_channels: int,
385
+ out_channels: int,
386
+ temb_channels: int,
387
+ dropout: float = 0.0,
388
+ num_layers: int = 1,
389
+ resnet_eps: float = 1e-6,
390
+ resnet_time_scale_shift: str = "default",
391
+ resnet_act_fn: str = "swish",
392
+ resnet_groups: int = 32,
393
+ resnet_pre_norm: bool = True,
394
+ attn_num_head_channels=1,
395
+ attention_type="default",
396
+ output_scale_factor=1.0,
397
+ downsample_padding=1,
398
+ add_downsample=True,
399
+ ):
400
+ super().__init__()
401
+ resnets = []
402
+ attentions = []
403
+
404
+ self.attention_type = attention_type
405
+
406
+ for i in range(num_layers):
407
+ in_channels = in_channels if i == 0 else out_channels
408
+ resnets.append(
409
+ ResnetBlock2D(
410
+ in_channels=in_channels,
411
+ out_channels=out_channels,
412
+ temb_channels=temb_channels,
413
+ eps=resnet_eps,
414
+ groups=resnet_groups,
415
+ dropout=dropout,
416
+ time_embedding_norm=resnet_time_scale_shift,
417
+ non_linearity=resnet_act_fn,
418
+ output_scale_factor=output_scale_factor,
419
+ pre_norm=resnet_pre_norm,
420
+ )
421
+ )
422
+ attentions.append(
423
+ AttentionBlock(
424
+ out_channels,
425
+ num_head_channels=attn_num_head_channels,
426
+ rescale_output_factor=output_scale_factor,
427
+ eps=resnet_eps,
428
+ num_groups=resnet_groups,
429
+ )
430
+ )
431
+
432
+ self.attentions = nn.ModuleList(attentions)
433
+ self.resnets = nn.ModuleList(resnets)
434
+
435
+ if add_downsample:
436
+ self.downsamplers = nn.ModuleList(
437
+ [
438
+ Downsample2D(
439
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
440
+ )
441
+ ]
442
+ )
443
+ else:
444
+ self.downsamplers = None
445
+
446
+ def forward(self, hidden_states, temb=None):
447
+ output_states = ()
448
+
449
+ for resnet, attn in zip(self.resnets, self.attentions):
450
+ hidden_states = resnet(hidden_states, temb)
451
+ hidden_states = attn(hidden_states)
452
+ output_states += (hidden_states,)
453
+
454
+ if self.downsamplers is not None:
455
+ for downsampler in self.downsamplers:
456
+ hidden_states = downsampler(hidden_states)
457
+
458
+ output_states += (hidden_states,)
459
+
460
+ return hidden_states, output_states
461
+
462
+
463
+ class CrossAttnDownBlock2D(nn.Module):
464
+ def __init__(
465
+ self,
466
+ in_channels: int,
467
+ out_channels: int,
468
+ temb_channels: int,
469
+ dropout: float = 0.0,
470
+ num_layers: int = 1,
471
+ resnet_eps: float = 1e-6,
472
+ resnet_time_scale_shift: str = "default",
473
+ resnet_act_fn: str = "swish",
474
+ resnet_groups: int = 32,
475
+ resnet_pre_norm: bool = True,
476
+ attn_num_head_channels=1,
477
+ cross_attention_dim=1280,
478
+ attention_type="default",
479
+ output_scale_factor=1.0,
480
+ downsample_padding=1,
481
+ add_downsample=True,
482
+ ):
483
+ super().__init__()
484
+ resnets = []
485
+ attentions = []
486
+
487
+ self.attention_type = attention_type
488
+ self.attn_num_head_channels = attn_num_head_channels
489
+
490
+ for i in range(num_layers):
491
+ in_channels = in_channels if i == 0 else out_channels
492
+ resnets.append(
493
+ ResnetBlock2D(
494
+ in_channels=in_channels,
495
+ out_channels=out_channels,
496
+ temb_channels=temb_channels,
497
+ eps=resnet_eps,
498
+ groups=resnet_groups,
499
+ dropout=dropout,
500
+ time_embedding_norm=resnet_time_scale_shift,
501
+ non_linearity=resnet_act_fn,
502
+ output_scale_factor=output_scale_factor,
503
+ pre_norm=resnet_pre_norm,
504
+ )
505
+ )
506
+ attentions.append(
507
+ SpatialTransformer(
508
+ out_channels,
509
+ attn_num_head_channels,
510
+ out_channels // attn_num_head_channels,
511
+ depth=1,
512
+ context_dim=cross_attention_dim,
513
+ num_groups=resnet_groups,
514
+ )
515
+ )
516
+ self.attentions = nn.ModuleList(attentions)
517
+ self.resnets = nn.ModuleList(resnets)
518
+
519
+ if add_downsample:
520
+ self.downsamplers = nn.ModuleList(
521
+ [
522
+ Downsample2D(
523
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
524
+ )
525
+ ]
526
+ )
527
+ else:
528
+ self.downsamplers = None
529
+
530
+ self.gradient_checkpointing = False
531
+
532
+ def set_attention_slice(self, slice_size):
533
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
534
+ raise ValueError(
535
+ f"Make sure slice_size {slice_size} is a divisor of "
536
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
537
+ )
538
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
539
+ raise ValueError(
540
+ f"Chunk_size {slice_size} has to be smaller or equal to "
541
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
542
+ )
543
+
544
+ for attn in self.attentions:
545
+ attn._set_attention_slice(slice_size)
546
+
547
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
548
+ output_states = ()
549
+
550
+ for resnet, attn in zip(self.resnets, self.attentions):
551
+ if self.training and self.gradient_checkpointing:
552
+
553
+ def create_custom_forward(module):
554
+ def custom_forward(*inputs):
555
+ return module(*inputs)
556
+
557
+ return custom_forward
558
+
559
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
560
+ hidden_states = torch.utils.checkpoint.checkpoint(
561
+ create_custom_forward(attn), hidden_states, encoder_hidden_states
562
+ )
563
+ else:
564
+ hidden_states = resnet(hidden_states, temb)
565
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
566
+
567
+ output_states += (hidden_states,)
568
+
569
+ if self.downsamplers is not None:
570
+ for downsampler in self.downsamplers:
571
+ hidden_states = downsampler(hidden_states)
572
+
573
+ output_states += (hidden_states,)
574
+
575
+ return hidden_states, output_states
576
+
577
+
578
+ class DownBlock2D(nn.Module):
579
+ def __init__(
580
+ self,
581
+ in_channels: int,
582
+ out_channels: int,
583
+ temb_channels: int,
584
+ dropout: float = 0.0,
585
+ num_layers: int = 1,
586
+ resnet_eps: float = 1e-6,
587
+ resnet_time_scale_shift: str = "default",
588
+ resnet_act_fn: str = "swish",
589
+ resnet_groups: int = 32,
590
+ resnet_pre_norm: bool = True,
591
+ output_scale_factor=1.0,
592
+ add_downsample=True,
593
+ downsample_padding=1,
594
+ ):
595
+ super().__init__()
596
+ resnets = []
597
+
598
+ for i in range(num_layers):
599
+ in_channels = in_channels if i == 0 else out_channels
600
+ resnets.append(
601
+ ResnetBlock2D(
602
+ in_channels=in_channels,
603
+ out_channels=out_channels,
604
+ temb_channels=temb_channels,
605
+ eps=resnet_eps,
606
+ groups=resnet_groups,
607
+ dropout=dropout,
608
+ time_embedding_norm=resnet_time_scale_shift,
609
+ non_linearity=resnet_act_fn,
610
+ output_scale_factor=output_scale_factor,
611
+ pre_norm=resnet_pre_norm,
612
+ )
613
+ )
614
+
615
+ self.resnets = nn.ModuleList(resnets)
616
+
617
+ if add_downsample:
618
+ self.downsamplers = nn.ModuleList(
619
+ [
620
+ Downsample2D(
621
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
622
+ )
623
+ ]
624
+ )
625
+ else:
626
+ self.downsamplers = None
627
+
628
+ self.gradient_checkpointing = False
629
+
630
+ def forward(self, hidden_states, temb=None):
631
+ output_states = ()
632
+
633
+ for resnet in self.resnets:
634
+ if self.training and self.gradient_checkpointing:
635
+
636
+ def create_custom_forward(module):
637
+ def custom_forward(*inputs):
638
+ return module(*inputs)
639
+
640
+ return custom_forward
641
+
642
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
643
+ else:
644
+ hidden_states = resnet(hidden_states, temb)
645
+
646
+ output_states += (hidden_states,)
647
+
648
+ if self.downsamplers is not None:
649
+ for downsampler in self.downsamplers:
650
+ hidden_states = downsampler(hidden_states)
651
+
652
+ output_states += (hidden_states,)
653
+
654
+ return hidden_states, output_states
655
+
656
+
657
+ class DownEncoderBlock2D(nn.Module):
658
+ def __init__(
659
+ self,
660
+ in_channels: int,
661
+ out_channels: int,
662
+ dropout: float = 0.0,
663
+ num_layers: int = 1,
664
+ resnet_eps: float = 1e-6,
665
+ resnet_time_scale_shift: str = "default",
666
+ resnet_act_fn: str = "swish",
667
+ resnet_groups: int = 32,
668
+ resnet_pre_norm: bool = True,
669
+ output_scale_factor=1.0,
670
+ add_downsample=True,
671
+ downsample_padding=1,
672
+ ):
673
+ super().__init__()
674
+ resnets = []
675
+
676
+ for i in range(num_layers):
677
+ in_channels = in_channels if i == 0 else out_channels
678
+ resnets.append(
679
+ ResnetBlock2D(
680
+ in_channels=in_channels,
681
+ out_channels=out_channels,
682
+ temb_channels=None,
683
+ eps=resnet_eps,
684
+ groups=resnet_groups,
685
+ dropout=dropout,
686
+ time_embedding_norm=resnet_time_scale_shift,
687
+ non_linearity=resnet_act_fn,
688
+ output_scale_factor=output_scale_factor,
689
+ pre_norm=resnet_pre_norm,
690
+ )
691
+ )
692
+
693
+ self.resnets = nn.ModuleList(resnets)
694
+
695
+ if add_downsample:
696
+ self.downsamplers = nn.ModuleList(
697
+ [
698
+ Downsample2D(
699
+ out_channels if len(resnets)>0 else in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
700
+ )
701
+ ]
702
+ )
703
+ else:
704
+ self.downsamplers = None
705
+
706
+ def forward(self, hidden_states):
707
+ for resnet in self.resnets:
708
+ hidden_states = resnet(hidden_states, temb=None)
709
+
710
+ if self.downsamplers is not None:
711
+ for downsampler in self.downsamplers:
712
+ hidden_states = downsampler(hidden_states)
713
+
714
+ return hidden_states
715
+
716
+
717
+ class AttnDownEncoderBlock2D(nn.Module):
718
+ def __init__(
719
+ self,
720
+ in_channels: int,
721
+ out_channels: int,
722
+ dropout: float = 0.0,
723
+ num_layers: int = 1,
724
+ resnet_eps: float = 1e-6,
725
+ resnet_time_scale_shift: str = "default",
726
+ resnet_act_fn: str = "swish",
727
+ resnet_groups: int = 32,
728
+ resnet_pre_norm: bool = True,
729
+ attn_num_head_channels=1,
730
+ output_scale_factor=1.0,
731
+ add_downsample=True,
732
+ downsample_padding=1,
733
+ ):
734
+ super().__init__()
735
+ resnets = []
736
+ attentions = []
737
+
738
+ for i in range(num_layers):
739
+ in_channels = in_channels if i == 0 else out_channels
740
+ resnets.append(
741
+ ResnetBlock2D(
742
+ in_channels=in_channels,
743
+ out_channels=out_channels,
744
+ temb_channels=None,
745
+ eps=resnet_eps,
746
+ groups=resnet_groups,
747
+ dropout=dropout,
748
+ time_embedding_norm=resnet_time_scale_shift,
749
+ non_linearity=resnet_act_fn,
750
+ output_scale_factor=output_scale_factor,
751
+ pre_norm=resnet_pre_norm,
752
+ )
753
+ )
754
+ attentions.append(
755
+ AttentionBlock(
756
+ out_channels,
757
+ num_head_channels=attn_num_head_channels,
758
+ rescale_output_factor=output_scale_factor,
759
+ eps=resnet_eps,
760
+ num_groups=resnet_groups,
761
+ )
762
+ )
763
+
764
+ self.attentions = nn.ModuleList(attentions)
765
+ self.resnets = nn.ModuleList(resnets)
766
+
767
+ if add_downsample:
768
+ self.downsamplers = nn.ModuleList(
769
+ [
770
+ Downsample2D(
771
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
772
+ )
773
+ ]
774
+ )
775
+ else:
776
+ self.downsamplers = None
777
+
778
+ def forward(self, hidden_states):
779
+ for resnet, attn in zip(self.resnets, self.attentions):
780
+ hidden_states = resnet(hidden_states, temb=None)
781
+ hidden_states = attn(hidden_states)
782
+
783
+ if self.downsamplers is not None:
784
+ for downsampler in self.downsamplers:
785
+ hidden_states = downsampler(hidden_states)
786
+
787
+ return hidden_states
788
+
789
+
790
+ class AttnSkipDownBlock2D(nn.Module):
791
+ def __init__(
792
+ self,
793
+ in_channels: int,
794
+ out_channels: int,
795
+ temb_channels: int,
796
+ dropout: float = 0.0,
797
+ num_layers: int = 1,
798
+ resnet_eps: float = 1e-6,
799
+ resnet_time_scale_shift: str = "default",
800
+ resnet_act_fn: str = "swish",
801
+ resnet_pre_norm: bool = True,
802
+ attn_num_head_channels=1,
803
+ attention_type="default",
804
+ output_scale_factor=np.sqrt(2.0),
805
+ downsample_padding=1,
806
+ add_downsample=True,
807
+ ):
808
+ super().__init__()
809
+ self.attentions = nn.ModuleList([])
810
+ self.resnets = nn.ModuleList([])
811
+
812
+ self.attention_type = attention_type
813
+
814
+ for i in range(num_layers):
815
+ in_channels = in_channels if i == 0 else out_channels
816
+ self.resnets.append(
817
+ ResnetBlock2D(
818
+ in_channels=in_channels,
819
+ out_channels=out_channels,
820
+ temb_channels=temb_channels,
821
+ eps=resnet_eps,
822
+ groups=min(in_channels // 4, 32),
823
+ groups_out=min(out_channels // 4, 32),
824
+ dropout=dropout,
825
+ time_embedding_norm=resnet_time_scale_shift,
826
+ non_linearity=resnet_act_fn,
827
+ output_scale_factor=output_scale_factor,
828
+ pre_norm=resnet_pre_norm,
829
+ )
830
+ )
831
+ self.attentions.append(
832
+ AttentionBlock(
833
+ out_channels,
834
+ num_head_channels=attn_num_head_channels,
835
+ rescale_output_factor=output_scale_factor,
836
+ eps=resnet_eps,
837
+ )
838
+ )
839
+
840
+ if add_downsample:
841
+ self.resnet_down = ResnetBlock2D(
842
+ in_channels=out_channels,
843
+ out_channels=out_channels,
844
+ temb_channels=temb_channels,
845
+ eps=resnet_eps,
846
+ groups=min(out_channels // 4, 32),
847
+ dropout=dropout,
848
+ time_embedding_norm=resnet_time_scale_shift,
849
+ non_linearity=resnet_act_fn,
850
+ output_scale_factor=output_scale_factor,
851
+ pre_norm=resnet_pre_norm,
852
+ use_in_shortcut=True,
853
+ down=True,
854
+ kernel="fir",
855
+ )
856
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
857
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
858
+ else:
859
+ self.resnet_down = None
860
+ self.downsamplers = None
861
+ self.skip_conv = None
862
+
863
+ def forward(self, hidden_states, temb=None, skip_sample=None):
864
+ output_states = ()
865
+
866
+ for resnet, attn in zip(self.resnets, self.attentions):
867
+ hidden_states = resnet(hidden_states, temb)
868
+ hidden_states = attn(hidden_states)
869
+ output_states += (hidden_states,)
870
+
871
+ if self.downsamplers is not None:
872
+ hidden_states = self.resnet_down(hidden_states, temb)
873
+ for downsampler in self.downsamplers:
874
+ skip_sample = downsampler(skip_sample)
875
+
876
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
877
+
878
+ output_states += (hidden_states,)
879
+
880
+ return hidden_states, output_states, skip_sample
881
+
882
+
883
+ class SkipDownBlock2D(nn.Module):
884
+ def __init__(
885
+ self,
886
+ in_channels: int,
887
+ out_channels: int,
888
+ temb_channels: int,
889
+ dropout: float = 0.0,
890
+ num_layers: int = 1,
891
+ resnet_eps: float = 1e-6,
892
+ resnet_time_scale_shift: str = "default",
893
+ resnet_act_fn: str = "swish",
894
+ resnet_pre_norm: bool = True,
895
+ output_scale_factor=np.sqrt(2.0),
896
+ add_downsample=True,
897
+ downsample_padding=1,
898
+ ):
899
+ super().__init__()
900
+ self.resnets = nn.ModuleList([])
901
+
902
+ for i in range(num_layers):
903
+ in_channels = in_channels if i == 0 else out_channels
904
+ self.resnets.append(
905
+ ResnetBlock2D(
906
+ in_channels=in_channels,
907
+ out_channels=out_channels,
908
+ temb_channels=temb_channels,
909
+ eps=resnet_eps,
910
+ groups=min(in_channels // 4, 32),
911
+ groups_out=min(out_channels // 4, 32),
912
+ dropout=dropout,
913
+ time_embedding_norm=resnet_time_scale_shift,
914
+ non_linearity=resnet_act_fn,
915
+ output_scale_factor=output_scale_factor,
916
+ pre_norm=resnet_pre_norm,
917
+ )
918
+ )
919
+
920
+ if add_downsample:
921
+ self.resnet_down = ResnetBlock2D(
922
+ in_channels=out_channels,
923
+ out_channels=out_channels,
924
+ temb_channels=temb_channels,
925
+ eps=resnet_eps,
926
+ groups=min(out_channels // 4, 32),
927
+ dropout=dropout,
928
+ time_embedding_norm=resnet_time_scale_shift,
929
+ non_linearity=resnet_act_fn,
930
+ output_scale_factor=output_scale_factor,
931
+ pre_norm=resnet_pre_norm,
932
+ use_in_shortcut=True,
933
+ down=True,
934
+ kernel="fir",
935
+ )
936
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
937
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
938
+ else:
939
+ self.resnet_down = None
940
+ self.downsamplers = None
941
+ self.skip_conv = None
942
+
943
+ def forward(self, hidden_states, temb=None, skip_sample=None):
944
+ output_states = ()
945
+
946
+ for resnet in self.resnets:
947
+ hidden_states = resnet(hidden_states, temb)
948
+ output_states += (hidden_states,)
949
+
950
+ if self.downsamplers is not None:
951
+ hidden_states = self.resnet_down(hidden_states, temb)
952
+ for downsampler in self.downsamplers:
953
+ skip_sample = downsampler(skip_sample)
954
+
955
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
956
+
957
+ output_states += (hidden_states,)
958
+
959
+ return hidden_states, output_states, skip_sample
960
+
961
+
962
+ class AttnUpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ dropout: float = 0.0,
970
+ num_layers: int = 1,
971
+ resnet_eps: float = 1e-6,
972
+ resnet_time_scale_shift: str = "default",
973
+ resnet_act_fn: str = "swish",
974
+ resnet_groups: int = 32,
975
+ resnet_pre_norm: bool = True,
976
+ attention_type="default",
977
+ attn_num_head_channels=1,
978
+ output_scale_factor=1.0,
979
+ add_upsample=True,
980
+ ):
981
+ super().__init__()
982
+ resnets = []
983
+ attentions = []
984
+
985
+ self.attention_type = attention_type
986
+
987
+ for i in range(num_layers):
988
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
989
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
990
+
991
+ resnets.append(
992
+ ResnetBlock2D(
993
+ in_channels=resnet_in_channels + res_skip_channels,
994
+ out_channels=out_channels,
995
+ temb_channels=temb_channels,
996
+ eps=resnet_eps,
997
+ groups=resnet_groups,
998
+ dropout=dropout,
999
+ time_embedding_norm=resnet_time_scale_shift,
1000
+ non_linearity=resnet_act_fn,
1001
+ output_scale_factor=output_scale_factor,
1002
+ pre_norm=resnet_pre_norm,
1003
+ )
1004
+ )
1005
+ attentions.append(
1006
+ AttentionBlock(
1007
+ out_channels,
1008
+ num_head_channels=attn_num_head_channels,
1009
+ rescale_output_factor=output_scale_factor,
1010
+ eps=resnet_eps,
1011
+ num_groups=resnet_groups,
1012
+ )
1013
+ )
1014
+
1015
+ self.attentions = nn.ModuleList(attentions)
1016
+ self.resnets = nn.ModuleList(resnets)
1017
+
1018
+ if add_upsample:
1019
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1020
+ else:
1021
+ self.upsamplers = None
1022
+
1023
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1024
+ for resnet, attn in zip(self.resnets, self.attentions):
1025
+ # pop res hidden states
1026
+ res_hidden_states = res_hidden_states_tuple[-1]
1027
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1028
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1029
+
1030
+ hidden_states = resnet(hidden_states, temb)
1031
+ hidden_states = attn(hidden_states)
1032
+
1033
+ if self.upsamplers is not None:
1034
+ for upsampler in self.upsamplers:
1035
+ hidden_states = upsampler(hidden_states)
1036
+
1037
+ return hidden_states
1038
+
1039
+
1040
+ class CrossAttnUpBlock2D(nn.Module):
1041
+ def __init__(
1042
+ self,
1043
+ in_channels: int,
1044
+ out_channels: int,
1045
+ prev_output_channel: int,
1046
+ temb_channels: int,
1047
+ dropout: float = 0.0,
1048
+ num_layers: int = 1,
1049
+ resnet_eps: float = 1e-6,
1050
+ resnet_time_scale_shift: str = "default",
1051
+ resnet_act_fn: str = "swish",
1052
+ resnet_groups: int = 32,
1053
+ resnet_pre_norm: bool = True,
1054
+ attn_num_head_channels=1,
1055
+ cross_attention_dim=1280,
1056
+ attention_type="default",
1057
+ output_scale_factor=1.0,
1058
+ downsample_padding=1,
1059
+ add_upsample=True,
1060
+ ):
1061
+ super().__init__()
1062
+ resnets = []
1063
+ attentions = []
1064
+
1065
+ self.attention_type = attention_type
1066
+ self.attn_num_head_channels = attn_num_head_channels
1067
+
1068
+ for i in range(num_layers):
1069
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1070
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1071
+
1072
+ resnets.append(
1073
+ ResnetBlock2D(
1074
+ in_channels=resnet_in_channels + res_skip_channels,
1075
+ out_channels=out_channels,
1076
+ temb_channels=temb_channels,
1077
+ eps=resnet_eps,
1078
+ groups=resnet_groups,
1079
+ dropout=dropout,
1080
+ time_embedding_norm=resnet_time_scale_shift,
1081
+ non_linearity=resnet_act_fn,
1082
+ output_scale_factor=output_scale_factor,
1083
+ pre_norm=resnet_pre_norm,
1084
+ )
1085
+ )
1086
+ attentions.append(
1087
+ SpatialTransformer(
1088
+ out_channels,
1089
+ attn_num_head_channels,
1090
+ out_channels // attn_num_head_channels,
1091
+ depth=1,
1092
+ context_dim=cross_attention_dim,
1093
+ num_groups=resnet_groups,
1094
+ )
1095
+ )
1096
+ self.attentions = nn.ModuleList(attentions)
1097
+ self.resnets = nn.ModuleList(resnets)
1098
+
1099
+ if add_upsample:
1100
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1101
+ else:
1102
+ self.upsamplers = None
1103
+
1104
+ self.gradient_checkpointing = False
1105
+
1106
+ def set_attention_slice(self, slice_size):
1107
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1108
+ raise ValueError(
1109
+ f"Make sure slice_size {slice_size} is a divisor of "
1110
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1111
+ )
1112
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1113
+ raise ValueError(
1114
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1115
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1116
+ )
1117
+
1118
+ for attn in self.attentions:
1119
+ attn._set_attention_slice(slice_size)
1120
+
1121
+ self.gradient_checkpointing = False
1122
+
1123
+ def forward(
1124
+ self,
1125
+ hidden_states,
1126
+ res_hidden_states_tuple,
1127
+ temb=None,
1128
+ encoder_hidden_states=None,
1129
+ ):
1130
+ for resnet, attn in zip(self.resnets, self.attentions):
1131
+ # pop res hidden states
1132
+ res_hidden_states = res_hidden_states_tuple[-1]
1133
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1134
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1135
+
1136
+ if self.training and self.gradient_checkpointing:
1137
+
1138
+ def create_custom_forward(module):
1139
+ def custom_forward(*inputs):
1140
+ return module(*inputs)
1141
+
1142
+ return custom_forward
1143
+
1144
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1145
+ hidden_states = torch.utils.checkpoint.checkpoint(
1146
+ create_custom_forward(attn), hidden_states, encoder_hidden_states
1147
+ )
1148
+ else:
1149
+ hidden_states = resnet(hidden_states, temb)
1150
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
1151
+
1152
+ if self.upsamplers is not None:
1153
+ for upsampler in self.upsamplers:
1154
+ hidden_states = upsampler(hidden_states)
1155
+
1156
+ return hidden_states
1157
+
1158
+
1159
+ class UpBlock2D(nn.Module):
1160
+ def __init__(
1161
+ self,
1162
+ in_channels: int,
1163
+ prev_output_channel: int,
1164
+ out_channels: int,
1165
+ temb_channels: int,
1166
+ dropout: float = 0.0,
1167
+ num_layers: int = 1,
1168
+ resnet_eps: float = 1e-6,
1169
+ resnet_time_scale_shift: str = "default",
1170
+ resnet_act_fn: str = "swish",
1171
+ resnet_groups: int = 32,
1172
+ resnet_pre_norm: bool = True,
1173
+ output_scale_factor=1.0,
1174
+ add_upsample=True,
1175
+ ):
1176
+ super().__init__()
1177
+ resnets = []
1178
+
1179
+ for i in range(num_layers):
1180
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1181
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1182
+
1183
+ resnets.append(
1184
+ ResnetBlock2D(
1185
+ in_channels=resnet_in_channels + res_skip_channels,
1186
+ out_channels=out_channels,
1187
+ temb_channels=temb_channels,
1188
+ eps=resnet_eps,
1189
+ groups=resnet_groups,
1190
+ dropout=dropout,
1191
+ time_embedding_norm=resnet_time_scale_shift,
1192
+ non_linearity=resnet_act_fn,
1193
+ output_scale_factor=output_scale_factor,
1194
+ pre_norm=resnet_pre_norm,
1195
+ )
1196
+ )
1197
+
1198
+ self.resnets = nn.ModuleList(resnets)
1199
+
1200
+ if add_upsample:
1201
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1202
+ else:
1203
+ self.upsamplers = None
1204
+
1205
+ self.gradient_checkpointing = False
1206
+
1207
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1208
+ for resnet in self.resnets:
1209
+ # pop res hidden states
1210
+ res_hidden_states = res_hidden_states_tuple[-1]
1211
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1212
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1213
+
1214
+ if self.training and self.gradient_checkpointing:
1215
+
1216
+ def create_custom_forward(module):
1217
+ def custom_forward(*inputs):
1218
+ return module(*inputs)
1219
+
1220
+ return custom_forward
1221
+
1222
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1223
+ else:
1224
+ hidden_states = resnet(hidden_states, temb)
1225
+
1226
+ if self.upsamplers is not None:
1227
+ for upsampler in self.upsamplers:
1228
+ hidden_states = upsampler(hidden_states)
1229
+
1230
+ return hidden_states
1231
+
1232
+
1233
+ class UpDecoderBlock2D(nn.Module):
1234
+ def __init__(
1235
+ self,
1236
+ in_channels: int,
1237
+ out_channels: int,
1238
+ dropout: float = 0.0,
1239
+ num_layers: int = 1,
1240
+ resnet_eps: float = 1e-6,
1241
+ resnet_time_scale_shift: str = "default",
1242
+ resnet_act_fn: str = "swish",
1243
+ resnet_groups: int = 32,
1244
+ resnet_pre_norm: bool = True,
1245
+ output_scale_factor=1.0,
1246
+ add_upsample=True,
1247
+ ):
1248
+ super().__init__()
1249
+ resnets = []
1250
+
1251
+ for i in range(num_layers):
1252
+ input_channels = in_channels if i == 0 else out_channels
1253
+
1254
+ resnets.append(
1255
+ ResnetBlock2D(
1256
+ in_channels=input_channels,
1257
+ out_channels=out_channels,
1258
+ temb_channels=None,
1259
+ eps=resnet_eps,
1260
+ groups=resnet_groups,
1261
+ dropout=dropout,
1262
+ time_embedding_norm=resnet_time_scale_shift,
1263
+ non_linearity=resnet_act_fn,
1264
+ output_scale_factor=output_scale_factor,
1265
+ pre_norm=resnet_pre_norm,
1266
+ )
1267
+ )
1268
+
1269
+ self.resnets = nn.ModuleList(resnets)
1270
+
1271
+ if add_upsample:
1272
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1273
+ else:
1274
+ self.upsamplers = None
1275
+
1276
+ def forward(self, hidden_states):
1277
+ for resnet in self.resnets:
1278
+ hidden_states = resnet(hidden_states, temb=None)
1279
+
1280
+ if self.upsamplers is not None:
1281
+ for upsampler in self.upsamplers:
1282
+ hidden_states = upsampler(hidden_states)
1283
+
1284
+ return hidden_states
1285
+
1286
+
1287
+ class AttnUpDecoderBlock2D(nn.Module):
1288
+ def __init__(
1289
+ self,
1290
+ in_channels: int,
1291
+ out_channels: int,
1292
+ dropout: float = 0.0,
1293
+ num_layers: int = 1,
1294
+ resnet_eps: float = 1e-6,
1295
+ resnet_time_scale_shift: str = "default",
1296
+ resnet_act_fn: str = "swish",
1297
+ resnet_groups: int = 32,
1298
+ resnet_pre_norm: bool = True,
1299
+ attn_num_head_channels=1,
1300
+ output_scale_factor=1.0,
1301
+ add_upsample=True,
1302
+ ):
1303
+ super().__init__()
1304
+ resnets = []
1305
+ attentions = []
1306
+
1307
+ for i in range(num_layers):
1308
+ input_channels = in_channels if i == 0 else out_channels
1309
+
1310
+ resnets.append(
1311
+ ResnetBlock2D(
1312
+ in_channels=input_channels,
1313
+ out_channels=out_channels,
1314
+ temb_channels=None,
1315
+ eps=resnet_eps,
1316
+ groups=resnet_groups,
1317
+ dropout=dropout,
1318
+ time_embedding_norm=resnet_time_scale_shift,
1319
+ non_linearity=resnet_act_fn,
1320
+ output_scale_factor=output_scale_factor,
1321
+ pre_norm=resnet_pre_norm,
1322
+ )
1323
+ )
1324
+ attentions.append(
1325
+ AttentionBlock(
1326
+ out_channels,
1327
+ num_head_channels=attn_num_head_channels,
1328
+ rescale_output_factor=output_scale_factor,
1329
+ eps=resnet_eps,
1330
+ num_groups=resnet_groups,
1331
+ )
1332
+ )
1333
+
1334
+ self.attentions = nn.ModuleList(attentions)
1335
+ self.resnets = nn.ModuleList(resnets)
1336
+
1337
+ if add_upsample:
1338
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1339
+ else:
1340
+ self.upsamplers = None
1341
+
1342
+ def forward(self, hidden_states):
1343
+ for resnet, attn in zip(self.resnets, self.attentions):
1344
+ hidden_states = resnet(hidden_states, temb=None)
1345
+ hidden_states = attn(hidden_states)
1346
+
1347
+ if self.upsamplers is not None:
1348
+ for upsampler in self.upsamplers:
1349
+ hidden_states = upsampler(hidden_states)
1350
+
1351
+ return hidden_states
1352
+
1353
+
1354
+ class AttnSkipUpBlock2D(nn.Module):
1355
+ def __init__(
1356
+ self,
1357
+ in_channels: int,
1358
+ prev_output_channel: int,
1359
+ out_channels: int,
1360
+ temb_channels: int,
1361
+ dropout: float = 0.0,
1362
+ num_layers: int = 1,
1363
+ resnet_eps: float = 1e-6,
1364
+ resnet_time_scale_shift: str = "default",
1365
+ resnet_act_fn: str = "swish",
1366
+ resnet_pre_norm: bool = True,
1367
+ attn_num_head_channels=1,
1368
+ attention_type="default",
1369
+ output_scale_factor=np.sqrt(2.0),
1370
+ upsample_padding=1,
1371
+ add_upsample=True,
1372
+ ):
1373
+ super().__init__()
1374
+ self.attentions = nn.ModuleList([])
1375
+ self.resnets = nn.ModuleList([])
1376
+
1377
+ self.attention_type = attention_type
1378
+
1379
+ for i in range(num_layers):
1380
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1381
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1382
+
1383
+ self.resnets.append(
1384
+ ResnetBlock2D(
1385
+ in_channels=resnet_in_channels + res_skip_channels,
1386
+ out_channels=out_channels,
1387
+ temb_channels=temb_channels,
1388
+ eps=resnet_eps,
1389
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1390
+ groups_out=min(out_channels // 4, 32),
1391
+ dropout=dropout,
1392
+ time_embedding_norm=resnet_time_scale_shift,
1393
+ non_linearity=resnet_act_fn,
1394
+ output_scale_factor=output_scale_factor,
1395
+ pre_norm=resnet_pre_norm,
1396
+ )
1397
+ )
1398
+
1399
+ self.attentions.append(
1400
+ AttentionBlock(
1401
+ out_channels,
1402
+ num_head_channels=attn_num_head_channels,
1403
+ rescale_output_factor=output_scale_factor,
1404
+ eps=resnet_eps,
1405
+ )
1406
+ )
1407
+
1408
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1409
+ if add_upsample:
1410
+ self.resnet_up = ResnetBlock2D(
1411
+ in_channels=out_channels,
1412
+ out_channels=out_channels,
1413
+ temb_channels=temb_channels,
1414
+ eps=resnet_eps,
1415
+ groups=min(out_channels // 4, 32),
1416
+ groups_out=min(out_channels // 4, 32),
1417
+ dropout=dropout,
1418
+ time_embedding_norm=resnet_time_scale_shift,
1419
+ non_linearity=resnet_act_fn,
1420
+ output_scale_factor=output_scale_factor,
1421
+ pre_norm=resnet_pre_norm,
1422
+ use_in_shortcut=True,
1423
+ up=True,
1424
+ kernel="fir",
1425
+ )
1426
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1427
+ self.skip_norm = torch.nn.GroupNorm(
1428
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1429
+ )
1430
+ self.act = nn.SiLU()
1431
+ else:
1432
+ self.resnet_up = None
1433
+ self.skip_conv = None
1434
+ self.skip_norm = None
1435
+ self.act = None
1436
+
1437
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1438
+ for resnet in self.resnets:
1439
+ # pop res hidden states
1440
+ res_hidden_states = res_hidden_states_tuple[-1]
1441
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1442
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1443
+
1444
+ hidden_states = resnet(hidden_states, temb)
1445
+
1446
+ hidden_states = self.attentions[0](hidden_states)
1447
+
1448
+ if skip_sample is not None:
1449
+ skip_sample = self.upsampler(skip_sample)
1450
+ else:
1451
+ skip_sample = 0
1452
+
1453
+ if self.resnet_up is not None:
1454
+ skip_sample_states = self.skip_norm(hidden_states)
1455
+ skip_sample_states = self.act(skip_sample_states)
1456
+ skip_sample_states = self.skip_conv(skip_sample_states)
1457
+
1458
+ skip_sample = skip_sample + skip_sample_states
1459
+
1460
+ hidden_states = self.resnet_up(hidden_states, temb)
1461
+
1462
+ return hidden_states, skip_sample
1463
+
1464
+
1465
+ class SkipUpBlock2D(nn.Module):
1466
+ def __init__(
1467
+ self,
1468
+ in_channels: int,
1469
+ prev_output_channel: int,
1470
+ out_channels: int,
1471
+ temb_channels: int,
1472
+ dropout: float = 0.0,
1473
+ num_layers: int = 1,
1474
+ resnet_eps: float = 1e-6,
1475
+ resnet_time_scale_shift: str = "default",
1476
+ resnet_act_fn: str = "swish",
1477
+ resnet_pre_norm: bool = True,
1478
+ output_scale_factor=np.sqrt(2.0),
1479
+ add_upsample=True,
1480
+ upsample_padding=1,
1481
+ ):
1482
+ super().__init__()
1483
+ self.resnets = nn.ModuleList([])
1484
+
1485
+ for i in range(num_layers):
1486
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1487
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1488
+
1489
+ self.resnets.append(
1490
+ ResnetBlock2D(
1491
+ in_channels=resnet_in_channels + res_skip_channels,
1492
+ out_channels=out_channels,
1493
+ temb_channels=temb_channels,
1494
+ eps=resnet_eps,
1495
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1496
+ groups_out=min(out_channels // 4, 32),
1497
+ dropout=dropout,
1498
+ time_embedding_norm=resnet_time_scale_shift,
1499
+ non_linearity=resnet_act_fn,
1500
+ output_scale_factor=output_scale_factor,
1501
+ pre_norm=resnet_pre_norm,
1502
+ )
1503
+ )
1504
+
1505
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1506
+ if add_upsample:
1507
+ self.resnet_up = ResnetBlock2D(
1508
+ in_channels=out_channels,
1509
+ out_channels=out_channels,
1510
+ temb_channels=temb_channels,
1511
+ eps=resnet_eps,
1512
+ groups=min(out_channels // 4, 32),
1513
+ groups_out=min(out_channels // 4, 32),
1514
+ dropout=dropout,
1515
+ time_embedding_norm=resnet_time_scale_shift,
1516
+ non_linearity=resnet_act_fn,
1517
+ output_scale_factor=output_scale_factor,
1518
+ pre_norm=resnet_pre_norm,
1519
+ use_in_shortcut=True,
1520
+ up=True,
1521
+ kernel="fir",
1522
+ )
1523
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1524
+ self.skip_norm = torch.nn.GroupNorm(
1525
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1526
+ )
1527
+ self.act = nn.SiLU()
1528
+ else:
1529
+ self.resnet_up = None
1530
+ self.skip_conv = None
1531
+ self.skip_norm = None
1532
+ self.act = None
1533
+
1534
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1535
+ for resnet in self.resnets:
1536
+ # pop res hidden states
1537
+ res_hidden_states = res_hidden_states_tuple[-1]
1538
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1539
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1540
+
1541
+ hidden_states = resnet(hidden_states, temb)
1542
+
1543
+ if skip_sample is not None:
1544
+ skip_sample = self.upsampler(skip_sample)
1545
+ else:
1546
+ skip_sample = 0
1547
+
1548
+ if self.resnet_up is not None:
1549
+ skip_sample_states = self.skip_norm(hidden_states)
1550
+ skip_sample_states = self.act(skip_sample_states)
1551
+ skip_sample_states = self.skip_conv(skip_sample_states)
1552
+
1553
+ skip_sample = skip_sample + skip_sample_states
1554
+
1555
+ hidden_states = self.resnet_up(hidden_states, temb)
1556
+
1557
+ return hidden_states, skip_sample
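A minimal sketch of how these up blocks consume the skip connections gathered on the way down (channel counts and shapes below are illustrative assumptions, not values used elsewhere in this repository): each resnet pops one tensor from the end of res_hidden_states_tuple and concatenates it channel-wise before running.

import torch

block = UpBlock2D(in_channels=32, prev_output_channel=128, out_channels=64,
                  temb_channels=512, num_layers=3, add_upsample=True)

hidden = torch.randn(2, 128, 16, 16)      # output of the previous (deeper) block
skips = (torch.randn(2, 32, 16, 16),      # consumed last, matches in_channels
         torch.randn(2, 64, 16, 16),
         torch.randn(2, 64, 16, 16))      # consumed first (popped from the end of the tuple)
temb = torch.randn(2, 512)

out = block(hidden, skips, temb)          # -> (2, 64, 32, 32) after the final Upsample2D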
medical_diffusion/external/diffusers/vae.py ADDED
@@ -0,0 +1,857 @@
1
+
2
+
3
+ from typing import Optional, Tuple, Union
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from itertools import chain
11
+
12
+ from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
13
+ from .taming_discriminator import NLayerDiscriminator
14
+ from medical_diffusion.models import BasicModel
15
+ from torchvision.utils import save_image
16
+
17
+ from torch.distributions.normal import Normal
18
+ from torch.distributions import kl_divergence
19
+
20
+ class Encoder(nn.Module):
21
+ def __init__(
22
+ self,
23
+ in_channels=3,
24
+ out_channels=3,
25
+ down_block_types=("DownEncoderBlock2D",),
26
+ block_out_channels=(64,),
27
+ layers_per_block=2,
28
+ norm_num_groups=32,
29
+ act_fn="silu",
30
+ double_z=True,
31
+ ):
32
+ super().__init__()
33
+ self.layers_per_block = layers_per_block
34
+
35
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
36
+
37
+ self.mid_block = None
38
+ self.down_blocks = nn.ModuleList([])
39
+
40
+ # down
41
+ output_channel = block_out_channels[0]
42
+ for i, down_block_type in enumerate(down_block_types):
43
+ input_channel = output_channel
44
+ output_channel = block_out_channels[i+1]
45
+ is_final_block = False #i == len(block_out_channels) - 1
46
+
47
+ down_block = get_down_block(
48
+ down_block_type,
49
+ num_layers=self.layers_per_block,
50
+ in_channels=input_channel,
51
+ out_channels=output_channel,
52
+ add_downsample=not is_final_block,
53
+ resnet_eps=1e-6,
54
+ downsample_padding=0,
55
+ resnet_act_fn=act_fn,
56
+ resnet_groups=norm_num_groups,
57
+ attn_num_head_channels=None,
58
+ temb_channels=None,
59
+ )
60
+ self.down_blocks.append(down_block)
61
+
62
+ # mid
63
+ self.mid_block = UNetMidBlock2D(
64
+ in_channels=block_out_channels[-1],
65
+ resnet_eps=1e-6,
66
+ resnet_act_fn=act_fn,
67
+ output_scale_factor=1,
68
+ resnet_time_scale_shift="default",
69
+ attn_num_head_channels=None,
70
+ resnet_groups=norm_num_groups,
71
+ temb_channels=None,
72
+ )
73
+
74
+ # out
75
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
76
+ self.conv_act = nn.SiLU()
77
+
78
+ conv_out_channels = 2 * out_channels if double_z else out_channels
79
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
80
+
81
+ def forward(self, x):
82
+ sample = x
83
+ sample = self.conv_in(sample)
84
+
85
+ # down
86
+ for down_block in self.down_blocks:
87
+ sample = down_block(sample)
88
+
89
+ # middle
90
+ sample = self.mid_block(sample)
91
+
92
+ # post-process
93
+ sample = self.conv_norm_out(sample)
94
+ sample = self.conv_act(sample)
95
+ sample = self.conv_out(sample)
96
+
97
+ return sample
98
+
99
+
100
+ class Decoder(nn.Module):
101
+ def __init__(
102
+ self,
103
+ in_channels=3,
104
+ out_channels=3,
105
+ up_block_types=("UpDecoderBlock2D",),
106
+ block_out_channels=(64,),
107
+ layers_per_block=2,
108
+ norm_num_groups=32,
109
+ act_fn="silu",
110
+ ):
111
+ super().__init__()
112
+ self.layers_per_block = layers_per_block
113
+
114
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
115
+
116
+ self.mid_block = None
117
+ self.up_blocks = nn.ModuleList([])
118
+
119
+ # mid
120
+ self.mid_block = UNetMidBlock2D(
121
+ in_channels=block_out_channels[-1],
122
+ resnet_eps=1e-6,
123
+ resnet_act_fn=act_fn,
124
+ output_scale_factor=1,
125
+ resnet_time_scale_shift="default",
126
+ attn_num_head_channels=None,
127
+ resnet_groups=norm_num_groups,
128
+ temb_channels=None,
129
+ )
130
+
131
+ # up
132
+ reversed_block_out_channels = list(reversed(block_out_channels))
133
+ output_channel = reversed_block_out_channels[0]
134
+ for i, up_block_type in enumerate(up_block_types):
135
+ prev_output_channel = output_channel
136
+ output_channel = reversed_block_out_channels[i+1]
137
+
138
+ is_final_block = False # i == len(block_out_channels) - 1
139
+
140
+ up_block = get_up_block(
141
+ up_block_type,
142
+ num_layers=self.layers_per_block + 1,
143
+ in_channels=prev_output_channel,
144
+ out_channels=output_channel,
145
+ prev_output_channel=None,
146
+ add_upsample=not is_final_block,
147
+ resnet_eps=1e-6,
148
+ resnet_act_fn=act_fn,
149
+ resnet_groups=norm_num_groups,
150
+ attn_num_head_channels=None,
151
+ temb_channels=None,
152
+ )
153
+ self.up_blocks.append(up_block)
154
+ prev_output_channel = output_channel
155
+
156
+ # out
157
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
158
+ self.conv_act = nn.SiLU()
159
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
160
+
161
+ def forward(self, z):
162
+ sample = z
163
+ sample = self.conv_in(sample)
164
+
165
+ # middle
166
+ sample = self.mid_block(sample)
167
+
168
+ # up
169
+ for up_block in self.up_blocks:
170
+ sample = up_block(sample)
171
+
172
+ # post-process
173
+ sample = self.conv_norm_out(sample)
174
+ sample = self.conv_act(sample)
175
+ sample = self.conv_out(sample)
176
+
177
+ return sample
178
+
179
+
180
+ class VectorQuantizer(nn.Module):
181
+ """
182
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
183
+ multiplications and allows for post-hoc remapping of indices.
184
+ """
185
+
186
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
187
+ # backwards compatibility the original implementation kept the buggy version as its
188
+ # default; here legacy defaults to False, which applies beta to the commitment term.
189
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=False):
190
+ super().__init__()
191
+ self.n_e = n_e
192
+ self.e_dim = e_dim
193
+ self.beta = beta
194
+ self.legacy = legacy
195
+
196
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
197
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
198
+
199
+ self.remap = remap
200
+ if self.remap is not None:
201
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
202
+ self.re_embed = self.used.shape[0]
203
+ self.unknown_index = unknown_index # "random" or "extra" or integer
204
+ if self.unknown_index == "extra":
205
+ self.unknown_index = self.re_embed
206
+ self.re_embed = self.re_embed + 1
207
+ print(
208
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
209
+ f"Using {self.unknown_index} for unknown indices."
210
+ )
211
+ else:
212
+ self.re_embed = n_e
213
+
214
+ self.sane_index_shape = sane_index_shape
215
+
216
+ def remap_to_used(self, inds):
217
+ ishape = inds.shape
218
+ assert len(ishape) > 1
219
+ inds = inds.reshape(ishape[0], -1)
220
+ used = self.used.to(inds)
221
+ match = (inds[:, :, None] == used[None, None, ...]).long()
222
+ new = match.argmax(-1)
223
+ unknown = match.sum(2) < 1
224
+ if self.unknown_index == "random":
225
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
226
+ else:
227
+ new[unknown] = self.unknown_index
228
+ return new.reshape(ishape)
229
+
230
+ def unmap_to_all(self, inds):
231
+ ishape = inds.shape
232
+ assert len(ishape) > 1
233
+ inds = inds.reshape(ishape[0], -1)
234
+ used = self.used.to(inds)
235
+ if self.re_embed > self.used.shape[0]: # extra token
236
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
237
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
238
+ return back.reshape(ishape)
239
+
240
+ def forward(self, z):
241
+ # reshape z -> (batch, height, width, channel) and flatten
242
+ z = z.permute(0, 2, 3, 1).contiguous()
243
+ z_flattened = z.view(-1, self.e_dim)
244
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
245
+
246
+ d = (
247
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
248
+ + torch.sum(self.embedding.weight**2, dim=1)
249
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
250
+ )
251
+
252
+ min_encoding_indices = torch.argmin(d, dim=1)
253
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
254
+ perplexity = None
255
+ min_encodings = None
256
+
257
+ # compute loss for embedding
258
+ if not self.legacy:
259
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
260
+ else:
261
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
262
+
263
+ # preserve gradients
264
+ z_q = z + (z_q - z).detach()
265
+
266
+ # reshape back to match original input shape
267
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
268
+
269
+ if self.remap is not None:
270
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
271
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
272
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
273
+
274
+ if self.sane_index_shape:
275
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
276
+
277
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
278
+
279
+ def get_codebook_entry(self, indices, shape):
280
+ # shape specifying (batch, height, width, channel)
281
+ if self.remap is not None:
282
+ indices = indices.reshape(shape[0], -1) # add batch axis
283
+ indices = self.unmap_to_all(indices)
284
+ indices = indices.reshape(-1) # flatten again
285
+
286
+ # get quantized latent vectors
287
+ z_q = self.embedding(indices)
288
+
289
+ if shape is not None:
290
+ z_q = z_q.view(shape)
291
+ # reshape back to match original input shape
292
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
293
+
294
+ return z_q
295
+
296
+
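A minimal sketch of the quantizer in isolation (codebook size and latent shape are illustrative assumptions): the forward pass returns the quantized latent, the codebook/commitment loss and the chosen indices, and the straight-through estimator z + (z_q - z).detach() keeps gradients flowing into the encoder output.

quantizer = VectorQuantizer(n_e=64, e_dim=3, beta=0.25)
z = torch.randn(2, 3, 8, 8, requires_grad=True)    # encoder output; channel dim must equal e_dim
z_q, vq_loss, (_, _, indices) = quantizer(z)
assert z_q.shape == z.shape                         # (2, 3, 8, 8)
z_q.sum().backward()                                # gradients reach z despite the hard assignment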
297
+ class DiagonalGaussianDistribution(object):
298
+ def __init__(self, parameters, deterministic=False):
299
+ self.batch_size = parameters.shape[0]
300
+ self.parameters = parameters
301
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
302
+ # self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
303
+ self.deterministic = deterministic
304
+ self.std = torch.exp(0.5 * self.logvar)
305
+ self.var = torch.exp(self.logvar)
306
+ if self.deterministic:
307
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
308
+
309
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
310
+ device = self.parameters.device
311
+ sample_device = "cpu" if device.type == "mps" else device
312
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
313
+ x = self.mean + self.std * sample
314
+ return x
315
+
316
+ def kl(self, other=None):
317
+ if self.deterministic:
318
+ return torch.Tensor([0.0])
319
+ else:
320
+ if other is None:
321
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)/self.batch_size
322
+ else:
323
+ return 0.5 * torch.sum(
324
+ torch.pow(self.mean - other.mean, 2) / other.var
325
+ + self.var / other.var
326
+ - 1.0
327
+ - self.logvar
328
+ + other.logvar,
329
+ )/self.batch_size
330
+
331
+ # q_z_x = Normal(self.mean, self.logvar.mul(.5).exp())
332
+ # p_z = Normal(torch.zeros_like(self.mean), torch.ones_like(self.logvar))
333
+ # kl_div = kl_divergence(q_z_x, p_z).sum(1).mean()
334
+ # return kl_div
335
+
336
+ def nll(self, sample, dims=[1, 2, 3]):
337
+ if self.deterministic:
338
+ return torch.Tensor([0.0])
339
+ logtwopi = np.log(2.0 * np.pi)
340
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
341
+
342
+ def mode(self):
343
+ return self.mean
344
+
345
+
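For orientation, kl() above is the closed-form KL divergence to a standard normal, summed over all non-batch dimensions and divided by the batch size; it can be cross-checked against the Normal/kl_divergence imports at the top of this file (the parameter shape below is an illustrative assumption).

params = torch.randn(4, 6, 8, 8)                    # first half of dim 1 is the mean, second half the logvar
dist = DiagonalGaussianDistribution(params)
q = Normal(dist.mean, dist.std)
p = Normal(torch.zeros_like(dist.mean), torch.ones_like(dist.std))
reference = kl_divergence(q, p).sum() / params.shape[0]
assert torch.allclose(dist.kl(), reference, rtol=1e-4, atol=1e-4)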
346
+ class VQModel(nn.Module):
347
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
348
+ Kavukcuoglu.
349
+
350
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
351
+ implements for all the model (such as downloading or saving, etc.)
352
+
353
+ Parameters:
354
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
355
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
356
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
357
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
358
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
359
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
360
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
361
+ obj:`(64,)`): Tuple of block output channels.
362
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
363
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
364
+ sample_size (`int`, *optional*, defaults to `32`): TODO
365
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
366
+ """
367
+
368
+
369
+ def __init__(
370
+ self,
371
+ in_channels: int = 3,
372
+ out_channels: int = 3,
373
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
374
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
375
+ block_out_channels: Tuple[int] = (32, 64, 128, 256),
376
+ layers_per_block: int = 1,
377
+ act_fn: str = "silu",
378
+ latent_channels: int = 3,
379
+ sample_size: int = 32,
380
+ num_vq_embeddings: int = 256,
381
+ norm_num_groups: int = 32,
382
+ ):
383
+ super().__init__()
384
+
385
+ # pass init params to Encoder
386
+ self.encoder = Encoder(
387
+ in_channels=in_channels,
388
+ out_channels=latent_channels,
389
+ down_block_types=down_block_types,
390
+ block_out_channels=block_out_channels,
391
+ layers_per_block=layers_per_block,
392
+ act_fn=act_fn,
393
+ norm_num_groups=norm_num_groups,
394
+ double_z=False,
395
+ )
396
+
397
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
398
+ self.quantize = VectorQuantizer(
399
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
400
+ )
401
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
402
+
403
+ # pass init params to Decoder
404
+ self.decoder = Decoder(
405
+ in_channels=latent_channels,
406
+ out_channels=out_channels,
407
+ up_block_types=up_block_types,
408
+ block_out_channels=block_out_channels,
409
+ layers_per_block=layers_per_block,
410
+ act_fn=act_fn,
411
+ norm_num_groups=norm_num_groups,
412
+ )
413
+
414
+ # def encode(self, x: torch.FloatTensor):
415
+ # z = self.encoder(x)
416
+ # z = self.quant_conv(z)
417
+ # return z
418
+
419
+ def encode(self, x, return_loss=True, force_quantize= True):
420
+ z = self.encoder(x)
421
+ z = self.quant_conv(z)
422
+
423
+ if force_quantize:
424
+ z_q, emb_loss, _ = self.quantize(z)
425
+ else:
426
+ z_q, emb_loss = z, None
427
+
428
+ if return_loss:
429
+ return z_q, emb_loss
430
+ else:
431
+ return z_q
432
+
433
+ def decode(self, z_q) -> torch.FloatTensor:
434
+ z_q = self.post_quant_conv(z_q)
435
+ x = self.decoder(z_q)
436
+ return x
437
+
438
+ # def decode(self, z: torch.FloatTensor, return_loss=True, force_quantize: bool = True) -> torch.FloatTensor:
439
+ # if force_quantize:
440
+ # z_q, emb_loss, _ = self.quantize(z)
441
+ # else:
442
+ # z_q, emb_loss = z, None
443
+
444
+ # z_q = self.post_quant_conv(z_q)
445
+ # x = self.decoder(z_q)
446
+
447
+ # if return_loss:
448
+ # return x, emb_loss
449
+ # else:
450
+ # return x
451
+
452
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
453
+ r"""
454
+ Args:
455
+ sample (`torch.FloatTensor`): Input sample.
456
+ """
457
+ # h = self.encode(sample)
458
+ h, emb_loss = self.encode(sample)
459
+ dec = self.decode(h)
460
+ # dec, emb_loss = self.decode(h)
461
+
462
+ return dec, emb_loss
463
+
464
+
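A minimal usage sketch of the VQ autoencoder above (image size and settings are illustrative assumptions, not training defaults): encode() quantizes unless force_quantize=False, and forward() returns the reconstruction together with the embedding loss.

model = VQModel(in_channels=3, out_channels=3, latent_channels=3, num_vq_embeddings=256)
x = torch.randn(1, 3, 64, 64)
recon, emb_loss = model(x)                  # reconstruction + codebook/commitment loss
z_q = model.encode(x, return_loss=False)    # quantized latent, here (1, 3, 8, 8)
x_hat = model.decode(z_q)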
465
+ class AutoencoderKL(nn.Module):
466
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
467
+ and Max Welling.
468
+
469
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
470
+ implements for all the model (such as downloading or saving, etc.)
471
+
472
+ Parameters:
473
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
474
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
475
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
476
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
477
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
478
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
479
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
480
+ obj:`(64,)`): Tuple of block output channels.
481
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
482
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
483
+ sample_size (`int`, *optional*, defaults to `32`): TODO
484
+ """
485
+
486
+
487
+ def __init__(
488
+ self,
489
+ in_channels: int = 3,
490
+ out_channels: int = 3,
491
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D","DownEncoderBlock2D",),
492
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D",),
493
+ block_out_channels: Tuple[int] = (32, 64, 128, 128),
494
+ layers_per_block: int = 1,
495
+ act_fn: str = "silu",
496
+ latent_channels: int = 3,
497
+ norm_num_groups: int = 32,
498
+ sample_size: int = 32,
499
+ ):
500
+ super().__init__()
501
+
502
+ # pass init params to Encoder
503
+ self.encoder = Encoder(
504
+ in_channels=in_channels,
505
+ out_channels=latent_channels,
506
+ down_block_types=down_block_types,
507
+ block_out_channels=block_out_channels,
508
+ layers_per_block=layers_per_block,
509
+ act_fn=act_fn,
510
+ norm_num_groups=norm_num_groups,
511
+ double_z=True,
512
+ )
513
+
514
+ # pass init params to Decoder
515
+ self.decoder = Decoder(
516
+ in_channels=latent_channels,
517
+ out_channels=out_channels,
518
+ up_block_types=up_block_types,
519
+ block_out_channels=block_out_channels,
520
+ layers_per_block=layers_per_block,
521
+ norm_num_groups=norm_num_groups,
522
+ act_fn=act_fn,
523
+ )
524
+
525
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
526
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
527
+
528
+ def encode(self, x: torch.FloatTensor):
529
+ h = self.encoder(x)
530
+ moments = self.quant_conv(h)
531
+ posterior = DiagonalGaussianDistribution(moments)
532
+ return posterior
533
+
534
+ def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
535
+ z = self.post_quant_conv(z)
536
+ dec = self.decoder(z)
537
+ return dec
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = True,
543
+ generator: Optional[torch.Generator] = None,
544
+ ) -> torch.FloatTensor:
545
+ r"""
546
+ Args:
547
+ sample (`torch.FloatTensor`): Input sample.
548
+ sample_posterior (`bool`, *optional*, defaults to `False`):
549
+ Whether to sample from the posterior.
550
+ """
551
+ x = sample
552
+ posterior = self.encode(x)
553
+ if sample_posterior:
554
+ z = posterior.sample(generator=generator)
555
+ else:
556
+ z = posterior.mode()
557
+ kl_loss = posterior.kl()
558
+ dec = self.decode(z)
559
+ return dec, kl_loss
560
+
561
+
562
+
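A matching sketch for the KL autoencoder above (sizes again illustrative; the block configuration is passed explicitly to keep it one entry shorter than block_out_channels): encode() returns a DiagonalGaussianDistribution, and forward() samples from it and returns the reconstruction together with the KL term.

vae = AutoencoderKL(
    down_block_types=("DownEncoderBlock2D",) * 3,
    up_block_types=("UpDecoderBlock2D",) * 3,
    block_out_channels=(32, 64, 128, 256),
    latent_channels=3,
)
x = torch.randn(1, 3, 64, 64)
recon, kl_loss = vae(x)                     # sample_posterior=True by default here
posterior = vae.encode(x)
z = posterior.sample()                      # or posterior.mode() for a deterministic latent
x_hat = vae.decode(z)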
563
+ class VQVAEWrapper(BasicModel):
564
+ def __init__(
565
+ self,
566
+ in_ch: int = 3,
567
+ out_ch: int = 3,
568
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",),
569
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",),
570
+ block_out_channels: Tuple[int] = (32, 64, 128, 256, ),
571
+ layers_per_block: int = 1,
572
+ act_fn: str = "silu",
573
+ latent_channels: int = 3,
574
+ sample_size: int = 32,
575
+ num_vq_embeddings: int = 64,
576
+ norm_num_groups: int = 32,
577
+
578
+ optimizer=torch.optim.AdamW,
579
+ optimizer_kwargs={},
580
+ lr_scheduler=None,
581
+ lr_scheduler_kwargs={},
582
+ loss=torch.nn.MSELoss,
583
+ loss_kwargs={}
584
+ ):
585
+ super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs)
586
+ self.model = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels,
587
+ layers_per_block, act_fn, latent_channels, sample_size, num_vq_embeddings, norm_num_groups)
588
+
589
+ def forward(self, sample):
590
+ return self.model(sample)
591
+
592
+ def encode(self, x):
593
+ z = self.model.encode(x, return_loss=False)
594
+ return z
595
+
596
+ def decode(self, z):
597
+ x = self.model.decode(z)
598
+ return x
599
+
600
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
601
+ # ------------------------- Get Source/Target ---------------------------
602
+ x = batch['source']
603
+ target = x
604
+
605
+ # ------------------------- Run Model ---------------------------
606
+ pred, vq_loss = self(x)
607
+
608
+ # ------------------------- Compute Loss ---------------------------
609
+ loss = self.loss_fct(pred, target)
610
+ loss += vq_loss
611
+
612
+ # --------------------- Compute Metrics -------------------------------
613
+ results = {'loss':loss}
614
+ with torch.no_grad():
615
+ results['L2'] = torch.nn.functional.mse_loss(pred, target)
616
+ results['L1'] = torch.nn.functional.l1_loss(pred, target)
617
+
618
+ # ----------------- Log Scalars ----------------------
619
+ for metric_name, metric_val in results.items():
620
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
621
+
622
+ # ----------------- Save Image ------------------------------
623
+ if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0:
624
+ def norm(x):
625
+ return (x-x.min())/(x.max()-x.min())
626
+
627
+ images = [x, pred]
628
+ log_step = self.global_step // self.trainer.log_every_n_steps
629
+ path_out = Path(self.logger.log_dir)/'images'
630
+ path_out.mkdir(parents=True, exist_ok=True)
631
+ images = torch.cat([norm(img) for img in images])
632
+ save_image(images, path_out/f'sample_{log_step}.png')
633
+
634
+ return loss
635
+
636
+ def hinge_d_loss(logits_real, logits_fake):
637
+ loss_real = torch.mean(F.relu(1. - logits_real))
638
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
639
+ d_loss = 0.5 * (loss_real + loss_fake)
640
+ return d_loss
641
+
642
+ def vanilla_d_loss(logits_real, logits_fake):
643
+ d_loss = 0.5 * (
644
+ torch.mean(F.softplus(-logits_real)) +
645
+ torch.mean(F.softplus(logits_fake)))
646
+ return d_loss
647
+
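For orientation, both discriminator losses above expect raw (unbounded) logits; a perfectly confident discriminator drives the hinge loss to zero, while the softplus-based variant never quite reaches it (the logit values are a small worked assumption).

logits_real = torch.tensor([3.0, 2.0])      # real samples scored high
logits_fake = torch.tensor([-3.0, -2.0])    # fakes scored low
hinge_d_loss(logits_real, logits_fake)      # -> 0.0, both margins satisfied
vanilla_d_loss(logits_real, logits_fake)    # -> ~0.088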
648
+ class VQGAN(BasicModel):
649
+ def __init__(
650
+ self,
651
+ in_ch: int = 3,
652
+ out_ch: int = 3,
653
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",),
654
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",),
655
+ block_out_channels: Tuple[int] = (32, 64, 128, 256, ),
656
+ layers_per_block: int = 1,
657
+ act_fn: str = "silu",
658
+ latent_channels: int = 3,
659
+ sample_size: int = 32,
660
+ num_vq_embeddings: int = 64,
661
+ norm_num_groups: int = 32,
662
+
663
+ start_gan_train_step = 50000, # NOTE: the global step increases with each optimizer step, i.e. twice per batch
664
+ gan_loss_weight: float = 1.0, # aka discriminator loss weight
665
+ perceptual_loss_weight: float = 1.0,
666
+ embedding_loss_weight: float = 1.0,
667
+
668
+ optimizer=torch.optim.AdamW,
669
+ optimizer_kwargs={},
670
+ lr_scheduler=None,
671
+ lr_scheduler_kwargs={},
672
+ loss=torch.nn.MSELoss,
673
+ loss_kwargs={}
674
+ ):
675
+ super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs)
676
+ self.vqvae = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels, layers_per_block, act_fn,
677
+ latent_channels, sample_size, num_vq_embeddings, norm_num_groups)
678
+ self.discriminator = NLayerDiscriminator(in_ch)
679
+ # self.perceiver = ... # Currently not supported, would require another trained NN
680
+
681
+ self.start_gan_train_step = start_gan_train_step
682
+ self.perceptual_loss_weight = perceptual_loss_weight
683
+ self.gan_loss_weight = gan_loss_weight
684
+ self.embedding_loss_weight = embedding_loss_weight
685
+
686
+ def forward(self, x, condition=None):
687
+ return self.vqvae(x)
688
+
689
+ def encode(self, x):
690
+ z = self.vqvae.encode(x, return_loss=False)
691
+ return z
692
+
693
+ def decode(self, z):
694
+ x = self.vqvae.decode(z)
695
+ return x
696
+
697
+
698
+ def compute_lambda(self, rec_loss, gan_loss, eps=1e-4):
699
+ """Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
700
+ last_layer = self.vqvae.decoder.conv_out.weight
701
+ rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
702
+ gan_grads = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0]
703
+ d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
704
+ d_weight = torch.clamp(d_weight, 0.0, 1e4)
705
+ return d_weight.detach()
706
+
707
+
708
+
709
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
710
+ x = batch['source']
711
+ # condition = batch.get('target', None)
712
+
713
+ pred, vq_emb_loss = self.vqvae(x)
714
+
715
+ if optimizer_idx == 0:
716
+ # ------ VAE -------
717
+ vq_img_loss = F.mse_loss(pred, x)
718
+ vq_per_loss = 0.0 #self.perceiver(pred, x)
719
+ rec_loss = vq_img_loss+self.perceptual_loss_weight*vq_per_loss
720
+
721
+ # ------- GAN -----
722
+ if step > self.start_gan_train_step:
723
+ gan_loss = -torch.mean(self.discriminator(pred))
724
+ lambda_weight = self.compute_lambda(rec_loss, gan_loss)
725
+ gan_loss = gan_loss*lambda_weight
726
+ else:
727
+ gan_loss = torch.tensor([0.0], requires_grad=True, device=x.device)
728
+
729
+ loss = self.gan_loss_weight*gan_loss+rec_loss+self.embedding_loss_weight*vq_emb_loss
730
+
731
+ elif optimizer_idx == 1:
732
+ if step > self.start_gan_train_step//2:
733
+ logits_real = self.discriminator(x.detach())
734
+ logits_fake = self.discriminator(pred.detach())
735
+ loss = hinge_d_loss(logits_real, logits_fake)
736
+ else:
737
+ loss = torch.tensor([0.0], requires_grad=True, device=x.device)
738
+
739
+ # --------------------- Compute Metrics -------------------------------
740
+ results = {'loss':loss.detach(), f'loss_{optimizer_idx}':loss.detach()}
741
+ with torch.no_grad():
742
+ results[f'L2'] = torch.nn.functional.mse_loss(pred, x)
743
+ results[f'L1'] = torch.nn.functional.l1_loss(pred, x)
744
+
745
+ # ----------------- Log Scalars ----------------------
746
+ for metric_name, metric_val in results.items():
747
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
748
+
749
+ # ----------------- Save Image ------------------------------
750
+ if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
751
+ def norm(x):
752
+ return (x-x.min())/(x.max()-x.min())
753
+
754
+ images = torch.cat([x, pred])
755
+ log_step = self.global_step // self.trainer.log_every_n_steps
756
+ path_out = Path(self.logger.log_dir)/'images'
757
+ path_out.mkdir(parents=True, exist_ok=True)
758
+ images = torch.stack([norm(img) for img in images])
759
+ save_image(images, path_out/f'sample_{log_step}.png')
760
+
761
+ return loss
762
+
763
+ def configure_optimizers(self):
764
+ opt_vae = self.optimizer(self.vqvae.parameters(), **self.optimizer_kwargs)
765
+ opt_disc = self.optimizer(self.discriminator.parameters(), **self.optimizer_kwargs)
766
+ if self.lr_scheduler is not None:
767
+ scheduler = [
768
+ {
769
+ 'scheduler': self.lr_scheduler(opt_vae, **self.lr_scheduler_kwargs),
770
+ 'interval': 'step',
771
+ 'frequency': 1
772
+ },
773
+ {
774
+ 'scheduler': self.lr_scheduler(opt_disc, **self.lr_scheduler_kwargs),
775
+ 'interval': 'step',
776
+ 'frequency': 1
777
+ },
778
+ ]
779
+ else:
780
+ scheduler = []
781
+
782
+ return [opt_vae, opt_disc], scheduler
783
+
784
+ class VAEWrapper(BasicModel):
785
+ def __init__(
786
+ self,
787
+ in_ch: int = 3,
788
+ out_ch: int = 3,
789
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), # "DownEncoderBlock2D", "DownEncoderBlock2D",
790
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D","UpDecoderBlock2D" ), # "UpDecoderBlock2D", "UpDecoderBlock2D",
791
+ block_out_channels: Tuple[int] = (32, 64, 128, 256), # 128, 256
792
+ layers_per_block: int = 1,
793
+ act_fn: str = "silu",
794
+ latent_channels: int = 3,
795
+ norm_num_groups: int = 32,
796
+ sample_size: int = 32,
797
+
798
+ optimizer=torch.optim.AdamW,
799
+ optimizer_kwargs={'lr':1e-4, 'weight_decay':1e-3, 'amsgrad':True},
800
+ lr_scheduler=None,
801
+ lr_scheduler_kwargs={},
802
+ # loss=torch.nn.MSELoss, # WARNING: No Effect
803
+ # loss_kwargs={'reduction': 'mean'}
804
+ ):
805
+ super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs ) # loss, loss_kwargs
806
+ self.model = AutoencoderKL(in_ch, out_ch, down_block_types, up_block_types, block_out_channels,
807
+ layers_per_block, act_fn, latent_channels, norm_num_groups, sample_size)
808
+
809
+ self.logvar = nn.Parameter(torch.zeros(size=())) # Better weighting between KL and MSE, see (https://arxiv.org/abs/1903.05789), also used by Taming-Transformer/Stable Diffusion
810
+
811
+ def forward(self, sample):
812
+ return self.model(sample)
813
+
814
+ def encode(self, x):
815
+ z = self.model.encode(x) # Latent space but not yet mapped to discrete embedding vectors
816
+ return z.sample(generator=None)
817
+
818
+ def decode(self, z):
819
+ x = self.model.decode(z)
820
+ return x
821
+
822
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
823
+ # ------------------------- Get Source/Target ---------------------------
824
+ x = batch['source']
825
+ target = x
826
+ HALF_LOG_TWO_PI = 0.91893 # log(2pi)/2
827
+
828
+ # ------------------------- Run Model ---------------------------
829
+ pred, kl_loss = self(x)
830
+
831
+ # ------------------------- Compute Loss ---------------------------
832
+ loss = torch.sum( torch.square(pred-target))/x.shape[0] #torch.sum( torch.square((pred-target)/torch.exp(self.logvar))/2 + self.logvar + HALF_LOG_TWO_PI )/x.shape[0]
833
+ loss += kl_loss
834
+
835
+ # --------------------- Compute Metrics -------------------------------
836
+ results = {'loss':loss.detach()}
837
+ with torch.no_grad():
838
+ results['L2'] = torch.nn.functional.mse_loss(pred, target)
839
+ results['L1'] = torch.nn.functional.l1_loss(pred, target)
840
+
841
+ # ----------------- Log Scalars ----------------------
842
+ for metric_name, metric_val in results.items():
843
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
844
+
845
+ # ----------------- Save Image ------------------------------
846
+ if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0:
847
+ def norm(x):
848
+ return (x-x.min())/(x.max()-x.min())
849
+
850
+ images = torch.cat([x, pred])
851
+ log_step = self.global_step // self.trainer.log_every_n_steps
852
+ path_out = Path(self.logger.log_dir)/'images'
853
+ path_out.mkdir(parents=True, exist_ok=True)
854
+ images = torch.stack([norm(img) for img in images])
855
+ save_image(images, path_out/f'sample_{log_step}.png')
856
+
857
+ return loss
medical_diffusion/external/stable_diffusion/attention.py ADDED
@@ -0,0 +1,261 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from .util_attention import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return{el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)
174
+ context = default(context, x)
175
+ k = self.to_k(context)
176
+ v = self.to_v(context)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ if exists(mask):
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198
+ super().__init__()
199
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203
+ self.norm1 = nn.LayerNorm(dim)
204
+ self.norm2 = nn.LayerNorm(dim)
205
+ self.norm3 = nn.LayerNorm(dim)
206
+ self.checkpoint = checkpoint
207
+
208
+ def forward(self, x, context=None):
209
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210
+
211
+ def _forward(self, x, context=None):
212
+ x = self.attn1(self.norm1(x)) + x
213
+ x = self.attn2(self.norm2(x), context=context) + x
214
+ x = self.ff(self.norm3(x)) + x
215
+ return x
216
+
217
+
218
+ class SpatialTransformer(nn.Module):
219
+ """
220
+ Transformer block for image-like data.
221
+ First, project the input (aka embedding)
222
+ and reshape to b, t, d.
223
+ Then apply standard transformer action.
224
+ Finally, reshape to image
225
+ """
226
+ def __init__(self, in_channels, n_heads, d_head,
227
+ depth=1, dropout=0., context_dim=None):
228
+ super().__init__()
229
+ self.in_channels = in_channels
230
+ inner_dim = n_heads * d_head
231
+ self.norm = Normalize(in_channels)
232
+
233
+ self.proj_in = nn.Conv2d(in_channels,
234
+ inner_dim,
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0)
238
+
239
+ self.transformer_blocks = nn.ModuleList(
240
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241
+ for d in range(depth)]
242
+ )
243
+
244
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
245
+ in_channels,
246
+ kernel_size=1,
247
+ stride=1,
248
+ padding=0))
249
+
250
+ def forward(self, x, context=None):
251
+ # note: if no context is given, cross-attention defaults to self-attention
252
+ b, c, h, w = x.shape
253
+ x_in = x
254
+ x = self.norm(x)
255
+ x = self.proj_in(x)
256
+ x = rearrange(x, 'b c h w -> b (h w) c')
257
+ for block in self.transformer_blocks:
258
+ x = block(x, context=context)
259
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260
+ x = self.proj_out(x)
261
+ return x + x_in
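A shape-level sketch of the transformer block above (channel counts and context length are illustrative assumptions): the block flattens the h*w positions into tokens, optionally cross-attends to an external context sequence, and restores the spatial layout before adding the residual.

st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=32)
x = torch.randn(2, 64, 8, 8)
context = torch.randn(2, 77, 32)            # e.g. a text-embedding sequence (assumed length)
y = st(x, context=context)                  # -> (2, 64, 8, 8)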
medical_diffusion/external/stable_diffusion/lr_schedulers.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+
3
+ class LambdaLinearScheduler:
4
+ def __init__(self, warm_up_steps=[10000,], f_min=[1.0,], f_max=[1.0,], f_start=[1.e-6], cycle_lengths=[10000000000000], verbosity_interval=0):
5
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
6
+ self.lr_warm_up_steps = warm_up_steps
7
+ self.f_start = f_start
8
+ self.f_min = f_min
9
+ self.f_max = f_max
10
+ self.cycle_lengths = cycle_lengths
11
+ self.cum_cycles = torch.cumsum(torch.tensor([0] + list(self.cycle_lengths)), 0)
12
+ self.last_f = 0.
13
+ self.verbosity_interval = verbosity_interval
14
+
15
+ def find_in_interval(self, n):
16
+ interval = 0
17
+ for cl in self.cum_cycles[1:]:
18
+ if n <= cl:
19
+ return interval
20
+ interval += 1
21
+
22
+ def schedule(self, n, **kwargs):
23
+ cycle = self.find_in_interval(n)
24
+ n = n - self.cum_cycles[cycle]
25
+
26
+ if n < self.lr_warm_up_steps[cycle]:
27
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
28
+ self.last_f = f
29
+ return f
30
+ else:
31
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
32
+ self.last_f = f
33
+ return f
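The scheduler only produces a multiplicative factor per step; one plausible wiring (an assumption here, mirroring how Stable Diffusion uses its lambda schedulers) is through torch.optim.lr_scheduler.LambdaLR, with a hypothetical model and optimizer:

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
sched_fn = LambdaLinearScheduler(warm_up_steps=[1000])   # other arguments keep their defaults
lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn.schedule)
for step in range(5):
    opt.step()
    lr_sched.step()            # effective lr = 1e-4 * sched_fn.schedule(step)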
medical_diffusion/external/stable_diffusion/unet_openai.py ADDED
@@ -0,0 +1,962 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from .attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ """
64
+ Any module where forward() takes timestep embeddings as a second argument.
65
+ """
66
+
67
+ @abstractmethod
68
+ def forward(self, x, emb):
69
+ """
70
+ Apply the module to `x` given `emb` timestep embeddings.
71
+ """
72
+
73
+
74
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
+ """
76
+ A sequential module that passes timestep embeddings to the children that
77
+ support it as an extra input.
78
+ """
79
+
80
+ def forward(self, x, emb, context=None):
81
+ for layer in self:
82
+ if isinstance(layer, TimestepBlock):
83
+ x = layer(x, emb)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
90
+
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ 'Learned 2x upsampling without padding'
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size=32,
446
+ in_channels=4,
447
+ model_channels=256,
448
+ out_channels=4,
449
+ num_res_blocks=2,
450
+ attention_resolutions=[4,2,1],
451
+ dropout=0,
452
+ channel_mult=(1, 2, 4),
453
+ conv_resample=True,
454
+ dims=2,
455
+ num_classes=None,
456
+ use_checkpoint=False,
457
+ use_fp16=False,
458
+ num_heads=8,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
+ **kwargs
470
+ ):
471
+ super().__init__()
472
+ if use_spatial_transformer:
473
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
474
+
475
+ if context_dim is not None:
476
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
477
+ # from omegaconf.listconfig import ListConfig
478
+ # if type(context_dim) == ListConfig:
479
+ # context_dim = list(context_dim)
480
+
481
+ if num_heads_upsample == -1:
482
+ num_heads_upsample = num_heads
483
+
484
+ if num_heads == -1:
485
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
486
+
487
+ if num_head_channels == -1:
488
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
489
+
490
+ self.image_size = image_size
491
+ self.in_channels = in_channels
492
+ self.model_channels = model_channels
493
+ self.out_channels = out_channels
494
+ self.num_res_blocks = num_res_blocks
495
+ self.attention_resolutions = attention_resolutions
496
+ self.dropout = dropout
497
+ self.channel_mult = channel_mult
498
+ self.conv_resample = conv_resample
499
+ self.num_classes = num_classes
500
+ self.use_checkpoint = use_checkpoint
501
+ self.dtype = th.float16 if use_fp16 else th.float32
502
+ self.num_heads = num_heads
503
+ self.num_head_channels = num_head_channels
504
+ self.num_heads_upsample = num_heads_upsample
505
+ self.predict_codebook_ids = n_embed is not None
506
+
507
+ time_embed_dim = model_channels * 4
508
+ self.time_embed = nn.Sequential(
509
+ linear(model_channels, time_embed_dim),
510
+ nn.SiLU(),
511
+ linear(time_embed_dim, time_embed_dim),
512
+ )
513
+
514
+ if self.num_classes is not None:
515
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
516
+
517
+ self.input_blocks = nn.ModuleList(
518
+ [
519
+ TimestepEmbedSequential(
520
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
521
+ )
522
+ ]
523
+ )
524
+ self._feature_size = model_channels
525
+ input_block_chans = [model_channels]
526
+ ch = model_channels
527
+ ds = 1
528
+ for level, mult in enumerate(channel_mult):
529
+ for _ in range(num_res_blocks):
530
+ layers = [
531
+ ResBlock(
532
+ ch,
533
+ time_embed_dim,
534
+ dropout,
535
+ out_channels=mult * model_channels,
536
+ dims=dims,
537
+ use_checkpoint=use_checkpoint,
538
+ use_scale_shift_norm=use_scale_shift_norm,
539
+ )
540
+ ]
541
+ ch = mult * model_channels
542
+ if ds in attention_resolutions:
543
+ if num_head_channels == -1:
544
+ dim_head = ch // num_heads
545
+ else:
546
+ num_heads = ch // num_head_channels
547
+ dim_head = num_head_channels
548
+ if legacy:
549
+ #num_heads = 1
550
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
551
+ layers.append(
552
+ AttentionBlock(
553
+ ch,
554
+ use_checkpoint=use_checkpoint,
555
+ num_heads=num_heads,
556
+ num_head_channels=dim_head,
557
+ use_new_attention_order=use_new_attention_order,
558
+ ) if not use_spatial_transformer else SpatialTransformer(
559
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
560
+ )
561
+ )
562
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
563
+ self._feature_size += ch
564
+ input_block_chans.append(ch)
565
+ if level != len(channel_mult) - 1:
566
+ out_ch = ch
567
+ self.input_blocks.append(
568
+ TimestepEmbedSequential(
569
+ ResBlock(
570
+ ch,
571
+ time_embed_dim,
572
+ dropout,
573
+ out_channels=out_ch,
574
+ dims=dims,
575
+ use_checkpoint=use_checkpoint,
576
+ use_scale_shift_norm=use_scale_shift_norm,
577
+ down=True,
578
+ )
579
+ if resblock_updown
580
+ else Downsample(
581
+ ch, conv_resample, dims=dims, out_channels=out_ch
582
+ )
583
+ )
584
+ )
585
+ ch = out_ch
586
+ input_block_chans.append(ch)
587
+ ds *= 2
588
+ self._feature_size += ch
589
+
590
+ if num_head_channels == -1:
591
+ dim_head = ch // num_heads
592
+ else:
593
+ num_heads = ch // num_head_channels
594
+ dim_head = num_head_channels
595
+ if legacy:
596
+ #num_heads = 1
597
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
598
+ self.middle_block = TimestepEmbedSequential(
599
+ ResBlock(
600
+ ch,
601
+ time_embed_dim,
602
+ dropout,
603
+ dims=dims,
604
+ use_checkpoint=use_checkpoint,
605
+ use_scale_shift_norm=use_scale_shift_norm,
606
+ ),
607
+ AttentionBlock(
608
+ ch,
609
+ use_checkpoint=use_checkpoint,
610
+ num_heads=num_heads,
611
+ num_head_channels=dim_head,
612
+ use_new_attention_order=use_new_attention_order,
613
+ ) if not use_spatial_transformer else SpatialTransformer(
614
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
615
+ ),
616
+ ResBlock(
617
+ ch,
618
+ time_embed_dim,
619
+ dropout,
620
+ dims=dims,
621
+ use_checkpoint=use_checkpoint,
622
+ use_scale_shift_norm=use_scale_shift_norm,
623
+ ),
624
+ )
625
+ self._feature_size += ch
626
+
627
+ self.output_blocks = nn.ModuleList([])
628
+ for level, mult in list(enumerate(channel_mult))[::-1]:
629
+ for i in range(num_res_blocks + 1):
630
+ ich = input_block_chans.pop()
631
+ layers = [
632
+ ResBlock(
633
+ ch + ich,
634
+ time_embed_dim,
635
+ dropout,
636
+ out_channels=model_channels * mult,
637
+ dims=dims,
638
+ use_checkpoint=use_checkpoint,
639
+ use_scale_shift_norm=use_scale_shift_norm,
640
+ )
641
+ ]
642
+ ch = model_channels * mult
643
+ if ds in attention_resolutions:
644
+ if num_head_channels == -1:
645
+ dim_head = ch // num_heads
646
+ else:
647
+ num_heads = ch // num_head_channels
648
+ dim_head = num_head_channels
649
+ if legacy:
650
+ #num_heads = 1
651
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
652
+ layers.append(
653
+ AttentionBlock(
654
+ ch,
655
+ use_checkpoint=use_checkpoint,
656
+ num_heads=num_heads_upsample,
657
+ num_head_channels=dim_head,
658
+ use_new_attention_order=use_new_attention_order,
659
+ ) if not use_spatial_transformer else SpatialTransformer(
660
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
661
+ )
662
+ )
663
+ if level and i == num_res_blocks:
664
+ out_ch = ch
665
+ layers.append(
666
+ ResBlock(
667
+ ch,
668
+ time_embed_dim,
669
+ dropout,
670
+ out_channels=out_ch,
671
+ dims=dims,
672
+ use_checkpoint=use_checkpoint,
673
+ use_scale_shift_norm=use_scale_shift_norm,
674
+ up=True,
675
+ )
676
+ if resblock_updown
677
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
678
+ )
679
+ ds //= 2
680
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
681
+ self._feature_size += ch
682
+
683
+ self.out = nn.Sequential(
684
+ normalization(ch),
685
+ nn.SiLU(),
686
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
687
+ )
688
+ if self.predict_codebook_ids:
689
+ self.id_predictor = nn.Sequential(
690
+ normalization(ch),
691
+ conv_nd(dims, model_channels, n_embed, 1),
692
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
693
+ )
694
+
695
+ def convert_to_fp16(self):
696
+ """
697
+ Convert the torso of the model to float16.
698
+ """
699
+ self.input_blocks.apply(convert_module_to_f16)
700
+ self.middle_block.apply(convert_module_to_f16)
701
+ self.output_blocks.apply(convert_module_to_f16)
702
+
703
+ def convert_to_fp32(self):
704
+ """
705
+ Convert the torso of the model to float32.
706
+ """
707
+ self.input_blocks.apply(convert_module_to_f32)
708
+ self.middle_block.apply(convert_module_to_f32)
709
+ self.output_blocks.apply(convert_module_to_f32)
710
+
711
+ def forward(self, x, t=None, condition=None, context=None, **kwargs):
712
+ """
713
+ Apply the model to an input batch.
714
+ :param x: an [N x C x ...] Tensor of inputs.
715
+ :param t: a 1-D batch of timesteps.
716
+ :param context: conditioning plugged in via crossattn
717
+ :param condition: an [N] Tensor of labels, if class-conditional.
718
+ :return: an [N x C x ...] Tensor of outputs.
719
+ """
720
+ condition = None # --------------------- WARNING: only for testing ---------------------
721
+ assert (condition is not None) == (
722
+ self.num_classes is not None
723
+ ), "must specify y if and only if the model is class-conditional"
724
+ hs = []
725
+ t_emb = timestep_embedding(t, self.model_channels, repeat_only=False)
726
+ emb = self.time_embed(t_emb)
727
+
728
+ if self.num_classes is not None:
729
+ assert condition.shape == (x.shape[0],)
730
+ emb = emb + self.label_emb(condition)
731
+
732
+ h = x.type(self.dtype)
733
+ for module in self.input_blocks:
734
+ h = module(h, emb, context)
735
+ hs.append(h)
736
+ h = self.middle_block(h, emb, context)
737
+ for module in self.output_blocks:
738
+ h = th.cat([h, hs.pop()], dim=1)
739
+ h = module(h, emb, context)
740
+ h = h.type(x.dtype)
741
+ if self.predict_codebook_ids:
742
+ return self.id_predictor(h)
743
+ else:
744
+ return self.out(h), []
745
+
746
+
747
+ class EncoderUNetModel(nn.Module):
748
+ """
749
+ The half UNet model with attention and timestep embedding.
750
+ For usage, see UNet.
751
+ """
752
+
753
+ def __init__(
754
+ self,
755
+ image_size,
756
+ in_channels,
757
+ model_channels,
758
+ out_channels,
759
+ num_res_blocks,
760
+ attention_resolutions,
761
+ dropout=0,
762
+ channel_mult=(1, 2, 4, 8),
763
+ conv_resample=True,
764
+ dims=2,
765
+ use_checkpoint=False,
766
+ use_fp16=False,
767
+ num_heads=1,
768
+ num_head_channels=-1,
769
+ num_heads_upsample=-1,
770
+ use_scale_shift_norm=False,
771
+ resblock_updown=False,
772
+ use_new_attention_order=False,
773
+ pool="adaptive",
774
+ *args,
775
+ **kwargs
776
+ ):
777
+ super().__init__()
778
+
779
+ if num_heads_upsample == -1:
780
+ num_heads_upsample = num_heads
781
+
782
+ self.in_channels = in_channels
783
+ self.model_channels = model_channels
784
+ self.out_channels = out_channels
785
+ self.num_res_blocks = num_res_blocks
786
+ self.attention_resolutions = attention_resolutions
787
+ self.dropout = dropout
788
+ self.channel_mult = channel_mult
789
+ self.conv_resample = conv_resample
790
+ self.use_checkpoint = use_checkpoint
791
+ self.dtype = th.float16 if use_fp16 else th.float32
792
+ self.num_heads = num_heads
793
+ self.num_head_channels = num_head_channels
794
+ self.num_heads_upsample = num_heads_upsample
795
+
796
+ time_embed_dim = model_channels * 4
797
+ self.time_embed = nn.Sequential(
798
+ linear(model_channels, time_embed_dim),
799
+ nn.SiLU(),
800
+ linear(time_embed_dim, time_embed_dim),
801
+ )
802
+
803
+ self.input_blocks = nn.ModuleList(
804
+ [
805
+ TimestepEmbedSequential(
806
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
807
+ )
808
+ ]
809
+ )
810
+ self._feature_size = model_channels
811
+ input_block_chans = [model_channels]
812
+ ch = model_channels
813
+ ds = 1
814
+ for level, mult in enumerate(channel_mult):
815
+ for _ in range(num_res_blocks):
816
+ layers = [
817
+ ResBlock(
818
+ ch,
819
+ time_embed_dim,
820
+ dropout,
821
+ out_channels=mult * model_channels,
822
+ dims=dims,
823
+ use_checkpoint=use_checkpoint,
824
+ use_scale_shift_norm=use_scale_shift_norm,
825
+ )
826
+ ]
827
+ ch = mult * model_channels
828
+ if ds in attention_resolutions:
829
+ layers.append(
830
+ AttentionBlock(
831
+ ch,
832
+ use_checkpoint=use_checkpoint,
833
+ num_heads=num_heads,
834
+ num_head_channels=num_head_channels,
835
+ use_new_attention_order=use_new_attention_order,
836
+ )
837
+ )
838
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
839
+ self._feature_size += ch
840
+ input_block_chans.append(ch)
841
+ if level != len(channel_mult) - 1:
842
+ out_ch = ch
843
+ self.input_blocks.append(
844
+ TimestepEmbedSequential(
845
+ ResBlock(
846
+ ch,
847
+ time_embed_dim,
848
+ dropout,
849
+ out_channels=out_ch,
850
+ dims=dims,
851
+ use_checkpoint=use_checkpoint,
852
+ use_scale_shift_norm=use_scale_shift_norm,
853
+ down=True,
854
+ )
855
+ if resblock_updown
856
+ else Downsample(
857
+ ch, conv_resample, dims=dims, out_channels=out_ch
858
+ )
859
+ )
860
+ )
861
+ ch = out_ch
862
+ input_block_chans.append(ch)
863
+ ds *= 2
864
+ self._feature_size += ch
865
+
866
+ self.middle_block = TimestepEmbedSequential(
867
+ ResBlock(
868
+ ch,
869
+ time_embed_dim,
870
+ dropout,
871
+ dims=dims,
872
+ use_checkpoint=use_checkpoint,
873
+ use_scale_shift_norm=use_scale_shift_norm,
874
+ ),
875
+ AttentionBlock(
876
+ ch,
877
+ use_checkpoint=use_checkpoint,
878
+ num_heads=num_heads,
879
+ num_head_channels=num_head_channels,
880
+ use_new_attention_order=use_new_attention_order,
881
+ ),
882
+ ResBlock(
883
+ ch,
884
+ time_embed_dim,
885
+ dropout,
886
+ dims=dims,
887
+ use_checkpoint=use_checkpoint,
888
+ use_scale_shift_norm=use_scale_shift_norm,
889
+ ),
890
+ )
891
+ self._feature_size += ch
892
+ self.pool = pool
893
+ if pool == "adaptive":
894
+ self.out = nn.Sequential(
895
+ normalization(ch),
896
+ nn.SiLU(),
897
+ nn.AdaptiveAvgPool2d((1, 1)),
898
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
899
+ nn.Flatten(),
900
+ )
901
+ elif pool == "attention":
902
+ assert num_head_channels != -1
903
+ self.out = nn.Sequential(
904
+ normalization(ch),
905
+ nn.SiLU(),
906
+ AttentionPool2d(
907
+ (image_size // ds), ch, num_head_channels, out_channels
908
+ ),
909
+ )
910
+ elif pool == "spatial":
911
+ self.out = nn.Sequential(
912
+ nn.Linear(self._feature_size, 2048),
913
+ nn.ReLU(),
914
+ nn.Linear(2048, self.out_channels),
915
+ )
916
+ elif pool == "spatial_v2":
917
+ self.out = nn.Sequential(
918
+ nn.Linear(self._feature_size, 2048),
919
+ normalization(2048),
920
+ nn.SiLU(),
921
+ nn.Linear(2048, self.out_channels),
922
+ )
923
+ else:
924
+ raise NotImplementedError(f"Unexpected {pool} pooling")
925
+
926
+ def convert_to_fp16(self):
927
+ """
928
+ Convert the torso of the model to float16.
929
+ """
930
+ self.input_blocks.apply(convert_module_to_f16)
931
+ self.middle_block.apply(convert_module_to_f16)
932
+
933
+ def convert_to_fp32(self):
934
+ """
935
+ Convert the torso of the model to float32.
936
+ """
937
+ self.input_blocks.apply(convert_module_to_f32)
938
+ self.middle_block.apply(convert_module_to_f32)
939
+
940
+ def forward(self, x, timesteps):
941
+ """
942
+ Apply the model to an input batch.
943
+ :param x: an [N x C x ...] Tensor of inputs.
944
+ :param timesteps: a 1-D batch of timesteps.
945
+ :return: an [N x K] Tensor of outputs.
946
+ """
947
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
948
+
949
+ results = []
950
+ h = x.type(self.dtype)
951
+ for module in self.input_blocks:
952
+ h = module(h, emb)
953
+ if self.pool.startswith("spatial"):
954
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
955
+ h = self.middle_block(h, emb)
956
+ if self.pool.startswith("spatial"):
957
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
958
+ h = th.cat(results, axis=-1)
959
+ return self.out(h)
960
+ else:
961
+ h = h.type(x.dtype)
962
+ return self.out(h)
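Taken as a whole, `UNetModel` maps a noisy input and a batch of timesteps to a prediction of the same spatial size and returns an empty list as a placeholder second output. A rough smoke-test sketch, assuming the package layout matches the file paths in this commit; the shapes and hyperparameters are chosen only for illustration:

import torch
from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel

unet = UNetModel(image_size=32, in_channels=4, model_channels=64,
                 out_channels=4, num_res_blocks=1, attention_resolutions=[2])

x = torch.randn(2, 4, 32, 32)         # noisy latents, [N, in_channels, H, W]
t = torch.randint(0, 1000, (2,))      # one timestep index per sample
pred, _ = unet(x, t=t)                # forward returns (prediction, [])
print(pred.shape)                     # torch.Size([2, 4, 32, 32])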
medical_diffusion/external/stable_diffusion/util.py ADDED
@@ -0,0 +1,284 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ #--------------- Added ----------------
19
+ import importlib
20
+ def get_obj_from_str(string, reload=False):
21
+ module, cls = string.rsplit(".", 1)
22
+ if reload:
23
+ module_imp = importlib.import_module(module)
24
+ importlib.reload(module_imp)
25
+ return getattr(importlib.import_module(module, package=None), cls)
26
+
27
+ def instantiate_from_config(config):
28
+ if not "target" in config:
29
+ if config == '__is_first_stage__':
30
+ return None
31
+ elif config == "__is_unconditional__":
32
+ return None
33
+ raise KeyError("Expected key `target` to instantiate.")
34
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
35
+
36
+ #--------------------------------
37
+
38
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
39
+ if schedule == "linear":
40
+ betas = (
41
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
42
+ )
43
+
44
+ elif schedule == "cosine":
45
+ timesteps = (
46
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
47
+ )
48
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
49
+ alphas = torch.cos(alphas).pow(2)
50
+ alphas = alphas / alphas[0]
51
+ betas = 1 - alphas[1:] / alphas[:-1]
52
+ betas = np.clip(betas, a_min=0, a_max=0.999)
53
+
54
+ elif schedule == "sqrt_linear":
55
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
56
+ elif schedule == "sqrt":
57
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
58
+ else:
59
+ raise ValueError(f"schedule '{schedule}' unknown.")
60
+ return betas.numpy()
61
+
62
+
63
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
64
+ if ddim_discr_method == 'uniform':
65
+ c = num_ddpm_timesteps // num_ddim_timesteps
66
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
67
+ elif ddim_discr_method == 'quad':
68
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
69
+ else:
70
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
71
+
72
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
73
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
74
+ steps_out = ddim_timesteps + 1
75
+ if verbose:
76
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
77
+ return steps_out
78
+
79
+
80
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
81
+ # select alphas for computing the variance schedule
82
+ alphas = alphacums[ddim_timesteps]
83
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
84
+
85
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
86
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
87
+ if verbose:
88
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
89
+ print(f'For the chosen value of eta, which is {eta}, '
90
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
91
+ return sigmas, alphas, alphas_prev
92
+
93
+
94
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
95
+ """
96
+ Create a beta schedule that discretizes the given alpha_t_bar function,
97
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
98
+ :param num_diffusion_timesteps: the number of betas to produce.
99
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
100
+ produces the cumulative product of (1-beta) up to that
101
+ part of the diffusion process.
102
+ :param max_beta: the maximum beta to use; use values lower than 1 to
103
+ prevent singularities.
104
+ """
105
+ betas = []
106
+ for i in range(num_diffusion_timesteps):
107
+ t1 = i / num_diffusion_timesteps
108
+ t2 = (i + 1) / num_diffusion_timesteps
109
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
110
+ return np.array(betas)
111
+
112
+
113
+ def extract_into_tensor(a, t, x_shape):
114
+ b, *_ = t.shape
115
+ out = a.gather(-1, t)
116
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
117
+
118
+
119
+ def checkpoint(func, inputs, params, flag):
120
+ """
121
+ Evaluate a function without caching intermediate activations, allowing for
122
+ reduced memory at the expense of extra compute in the backward pass.
123
+ :param func: the function to evaluate.
124
+ :param inputs: the argument sequence to pass to `func`.
125
+ :param params: a sequence of parameters `func` depends on but does not
126
+ explicitly take as arguments.
127
+ :param flag: if False, disable gradient checkpointing.
128
+ """
129
+ if flag:
130
+ args = tuple(inputs) + tuple(params)
131
+ return CheckpointFunction.apply(func, len(inputs), *args)
132
+ else:
133
+ return func(*inputs)
134
+
135
+
136
+ class CheckpointFunction(torch.autograd.Function):
137
+ @staticmethod
138
+ def forward(ctx, run_function, length, *args):
139
+ ctx.run_function = run_function
140
+ ctx.input_tensors = list(args[:length])
141
+ ctx.input_params = list(args[length:])
142
+
143
+ with torch.no_grad():
144
+ output_tensors = ctx.run_function(*ctx.input_tensors)
145
+ return output_tensors
146
+
147
+ @staticmethod
148
+ def backward(ctx, *output_grads):
149
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
150
+ with torch.enable_grad():
151
+ # Fixes a bug where the first op in run_function modifies the
152
+ # Tensor storage in place, which is not allowed for detach()'d
153
+ # Tensors.
154
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
155
+ output_tensors = ctx.run_function(*shallow_copies)
156
+ input_grads = torch.autograd.grad(
157
+ output_tensors,
158
+ ctx.input_tensors + ctx.input_params,
159
+ output_grads,
160
+ allow_unused=True,
161
+ )
162
+ del ctx.input_tensors
163
+ del ctx.input_params
164
+ del output_tensors
165
+ return (None, None) + input_grads
166
+
167
+
168
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
169
+ """
170
+ Create sinusoidal timestep embeddings.
171
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
172
+ These may be fractional.
173
+ :param dim: the dimension of the output.
174
+ :param max_period: controls the minimum frequency of the embeddings.
175
+ :return: an [N x dim] Tensor of positional embeddings.
176
+ """
177
+ if not repeat_only:
178
+ half = dim // 2
179
+ freqs = torch.exp(
180
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
181
+ ).to(device=timesteps.device)
182
+ args = timesteps[:, None].float() * freqs[None]
183
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
184
+ if dim % 2:
185
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
186
+ else:
187
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
188
+ return embedding
189
+
190
+
191
+ def zero_module(module):
192
+ """
193
+ Zero out the parameters of a module and return it.
194
+ """
195
+ for p in module.parameters():
196
+ p.detach().zero_()
197
+ return module
198
+
199
+
200
+ def scale_module(module, scale):
201
+ """
202
+ Scale the parameters of a module and return it.
203
+ """
204
+ for p in module.parameters():
205
+ p.detach().mul_(scale)
206
+ return module
207
+
208
+
209
+ def mean_flat(tensor):
210
+ """
211
+ Take the mean over all non-batch dimensions.
212
+ """
213
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
214
+
215
+
216
+ def normalization(channels):
217
+ """
218
+ Make a standard normalization layer.
219
+ :param channels: number of input channels.
220
+ :return: an nn.Module for normalization.
221
+ """
222
+ return GroupNorm32(32, channels)
223
+
224
+
225
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
226
+ class SiLU(nn.Module):
227
+ def forward(self, x):
228
+ return x * torch.sigmoid(x)
229
+
230
+
231
+ class GroupNorm32(nn.GroupNorm):
232
+ def forward(self, x):
233
+ return super().forward(x.float()).type(x.dtype)
234
+
235
+ def conv_nd(dims, *args, **kwargs):
236
+ """
237
+ Create a 1D, 2D, or 3D convolution module.
238
+ """
239
+ if dims == 1:
240
+ return nn.Conv1d(*args, **kwargs)
241
+ elif dims == 2:
242
+ return nn.Conv2d(*args, **kwargs)
243
+ elif dims == 3:
244
+ return nn.Conv3d(*args, **kwargs)
245
+ raise ValueError(f"unsupported dimensions: {dims}")
246
+
247
+
248
+ def linear(*args, **kwargs):
249
+ """
250
+ Create a linear module.
251
+ """
252
+ return nn.Linear(*args, **kwargs)
253
+
254
+
255
+ def avg_pool_nd(dims, *args, **kwargs):
256
+ """
257
+ Create a 1D, 2D, or 3D average pooling module.
258
+ """
259
+ if dims == 1:
260
+ return nn.AvgPool1d(*args, **kwargs)
261
+ elif dims == 2:
262
+ return nn.AvgPool2d(*args, **kwargs)
263
+ elif dims == 3:
264
+ return nn.AvgPool3d(*args, **kwargs)
265
+ raise ValueError(f"unsupported dimensions: {dims}")
266
+
267
+
268
+ class HybridConditioner(nn.Module):
269
+
270
+ def __init__(self, c_concat_config, c_crossattn_config):
271
+ super().__init__()
272
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
273
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
274
+
275
+ def forward(self, c_concat, c_crossattn):
276
+ c_concat = self.concat_conditioner(c_concat)
277
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
278
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
279
+
280
+
281
+ def noise_like(shape, device, repeat=False):
282
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
283
+ noise = lambda: torch.randn(shape, device=device)
284
+ return repeat_noise() if repeat else noise()
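Most of these helpers are small shape utilities. As a quick illustration of the two that carry the diffusion math, `make_beta_schedule` returns a numpy array of betas spanning `linear_start` to `linear_end`, and `timestep_embedding` returns one sinusoidal vector per timestep (a sketch, assuming the package is importable under the paths shown in this commit):

import torch
from medical_diffusion.external.stable_diffusion.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of shape (1000,)
print(betas[0], betas[-1])                              # ~1e-4 ... ~2e-2

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=128)                    # [N x dim] sinusoidal embeddings
print(emb.shape)                                        # torch.Size([3, 128])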
medical_diffusion/external/stable_diffusion/util_attention.py ADDED
@@ -0,0 +1,56 @@
1
+
2
+ import os
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from einops import repeat
8
+
9
+ def checkpoint(func, inputs, params, flag):
10
+ """
11
+ Evaluate a function without caching intermediate activations, allowing for
12
+ reduced memory at the expense of extra compute in the backward pass.
13
+ :param func: the function to evaluate.
14
+ :param inputs: the argument sequence to pass to `func`.
15
+ :param params: a sequence of parameters `func` depends on but does not
16
+ explicitly take as arguments.
17
+ :param flag: if False, disable gradient checkpointing.
18
+ """
19
+ if flag:
20
+ args = tuple(inputs) + tuple(params)
21
+ return CheckpointFunction.apply(func, len(inputs), *args)
22
+ else:
23
+ return func(*inputs)
24
+
25
+
26
+ class CheckpointFunction(torch.autograd.Function):
27
+ @staticmethod
28
+ def forward(ctx, run_function, length, *args):
29
+ ctx.run_function = run_function
30
+ ctx.input_tensors = list(args[:length])
31
+ ctx.input_params = list(args[length:])
32
+
33
+ with torch.no_grad():
34
+ output_tensors = ctx.run_function(*ctx.input_tensors)
35
+ return output_tensors
36
+
37
+ @staticmethod
38
+ def backward(ctx, *output_grads):
39
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
40
+ with torch.enable_grad():
41
+ # Fixes a bug where the first op in run_function modifies the
42
+ # Tensor storage in place, which is not allowed for detach()'d
43
+ # Tensors.
44
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
45
+ output_tensors = ctx.run_function(*shallow_copies)
46
+ input_grads = torch.autograd.grad(
47
+ output_tensors,
48
+ ctx.input_tensors + ctx.input_params,
49
+ output_grads,
50
+ allow_unused=True,
51
+ )
52
+ del ctx.input_tensors
53
+ del ctx.input_params
54
+ del output_tensors
55
+ return (None, None) + input_grads
56
+
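This file repeats the `checkpoint` helper from `util.py` so the attention code can trade compute for memory. A hedged sketch of how the helper is driven; the wrapped function below is an arbitrary stand-in, not part of this repo:

import torch
from medical_diffusion.external.stable_diffusion.util_attention import checkpoint

layer = torch.nn.Linear(16, 16)

def run(x):
    # stand-in for an expensive sub-module call
    return layer(x).relu()

x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(run, (x,), tuple(layer.parameters()), True)   # activations recomputed in backward
y.sum().backward()
print(x.grad.shape)                                          # torch.Size([4, 16])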
medical_diffusion/external/unet_lucidrains.py ADDED
@@ -0,0 +1,332 @@
1
+ from torch import nn, einsum
2
+ from einops import rearrange, reduce
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+ import math
7
+
8
+ # -------------------------------- Embeddings ------------------------------------------------------
9
+ class SinusoidalPosEmb(nn.Module):
10
+ def __init__(self, dim):
11
+ super().__init__()
12
+ self.dim = dim
13
+
14
+ def forward(self, x):
15
+ device = x.device
16
+ half_dim = self.dim // 2
17
+ emb = math.log(10000) / (half_dim - 1)
18
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
19
+ emb = x[:, None] * emb[None, :]
20
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
21
+ return emb
22
+
23
+ class LearnedSinusoidalPosEmb(nn.Module):
24
+ """ following @crowsonkb 's lead with learned sinusoidal pos emb """
25
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
26
+
27
+ def __init__(self, dim):
28
+ super().__init__()
29
+ assert (dim % 2) == 0
30
+ half_dim = dim // 2
31
+ self.weights = nn.Parameter(torch.randn(half_dim))
32
+
33
+ def forward(self, x):
34
+ x = rearrange(x, 'b -> b 1')
35
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
36
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
37
+ fouriered = torch.cat((x, fouriered), dim = -1)
38
+ return fouriered
39
+
40
+ # -------------------------------------------------------------------
41
+
42
+ def exists(x):
43
+ return x is not None
44
+
45
+ def default(val, d):
46
+ if exists(val):
47
+ return val
48
+ return d() if callable(d) else d
49
+
50
+ def l2norm(t):
51
+ return F.normalize(t, dim = -1)
52
+
53
+ class Residual(nn.Module):
54
+ def __init__(self, fn):
55
+ super().__init__()
56
+ self.fn = fn
57
+
58
+ def forward(self, x, *args, **kwargs):
59
+ return self.fn(x, *args, **kwargs) + x
60
+
61
+ def Upsample(dim, dim_out = None):
62
+ return nn.Sequential(
63
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
64
+ nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
65
+ )
66
+
67
+ def Downsample(dim, dim_out = None):
68
+ return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
69
+
70
+ class WeightStandardizedConv2d(nn.Conv2d):
71
+ """
72
+ https://arxiv.org/abs/1903.10520
73
+ weight standardization purportedly works synergistically with group normalization
74
+ """
75
+ def forward(self, x):
76
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
77
+
78
+ weight = self.weight
79
+ mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
80
+ var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
81
+ normalized_weight = (weight - mean) * (var + eps).rsqrt()
82
+
83
+ return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
84
+
85
+
86
+ class LayerNorm(nn.Module):
87
+ def __init__(self, dim):
88
+ super().__init__()
89
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
90
+
91
+ def forward(self, x):
92
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
93
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
94
+ mean = torch.mean(x, dim = 1, keepdim = True)
95
+ return (x - mean) * (var + eps).rsqrt() * self.g
96
+
97
+ class PreNorm(nn.Module):
98
+ def __init__(self, dim, fn):
99
+ super().__init__()
100
+ self.fn = fn
101
+ self.norm = LayerNorm(dim)
102
+
103
+ def forward(self, x):
104
+ x = self.norm(x)
105
+ return self.fn(x)
106
+
107
+ class Block(nn.Module):
108
+ def __init__(self, dim, dim_out, groups = 8):
109
+ super().__init__()
110
+ self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
111
+ self.norm = nn.GroupNorm(groups, dim_out)
112
+ self.act = nn.SiLU()
113
+
114
+ def forward(self, x, scale_shift = None):
115
+ x = self.proj(x)
116
+ x = self.norm(x)
117
+
118
+ if exists(scale_shift):
119
+ scale, shift = scale_shift
120
+ x = x * (scale + 1) + shift
121
+
122
+ x = self.act(x)
123
+ return x
124
+
125
+ class ResnetBlock(nn.Module):
126
+ def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
127
+ super().__init__()
128
+ self.mlp = nn.Sequential(
129
+ nn.SiLU(),
130
+ nn.Linear(time_emb_dim, dim_out * 2)
131
+ ) if exists(time_emb_dim) else None
132
+
133
+ self.block1 = Block(dim, dim_out, groups = groups)
134
+ self.block2 = Block(dim_out, dim_out, groups = groups)
135
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
136
+
137
+ def forward(self, x, time_emb = None):
138
+
139
+ scale_shift = None
140
+ if exists(self.mlp) and exists(time_emb):
141
+ time_emb = self.mlp(time_emb)
142
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
143
+ scale_shift = time_emb.chunk(2, dim = 1)
144
+
145
+ h = self.block1(x, scale_shift = scale_shift)
146
+
147
+ h = self.block2(h)
148
+
149
+ return h + self.res_conv(x)
150
+
151
+ class LinearAttention(nn.Module):
152
+ def __init__(self, dim, heads = 4, dim_head = 32):
153
+ super().__init__()
154
+ self.scale = dim_head ** -0.5
155
+ self.heads = heads
156
+ hidden_dim = dim_head * heads
157
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
158
+
159
+ self.to_out = nn.Sequential(
160
+ nn.Conv2d(hidden_dim, dim, 1),
161
+ LayerNorm(dim)
162
+ )
163
+
164
+ def forward(self, x):
165
+ b, c, h, w = x.shape
166
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
167
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
168
+
169
+ q = q.softmax(dim = -2)
170
+ k = k.softmax(dim = -1)
171
+
172
+ q = q * self.scale
173
+ v = v / (h * w)
174
+
175
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
176
+
177
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
178
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
179
+ return self.to_out(out)
180
+
181
+ class Attention(nn.Module):
182
+ def __init__(self, dim, heads = 4, dim_head = 32, scale = 10):
183
+ super().__init__()
184
+ self.scale = scale
185
+ self.heads = heads
186
+ hidden_dim = dim_head * heads
187
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
188
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
189
+
190
+ def forward(self, x):
191
+ b, c, h, w = x.shape
192
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
193
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
194
+
195
+ q, k = map(l2norm, (q, k))
196
+
197
+ sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
198
+ attn = sim.softmax(dim = -1)
199
+ out = einsum('b h i j, b h d j -> b h i d', attn, v)
200
+ out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
201
+ return self.to_out(out)
202
+
203
+
204
+
205
+ class UNet(nn.Module):
206
+ def __init__(
207
+ self,
208
+ dim=32,
209
+ init_dim = None,
210
+ out_dim = None,
211
+ dim_mults=(1, 2, 4, 8),
212
+ channels = 3,
213
+ self_condition = False,
214
+ resnet_block_groups = 8,
215
+ learned_variance = False,
216
+ learned_sinusoidal_cond = False,
217
+ learned_sinusoidal_dim = 16,
218
+ **kwargs
219
+ ):
220
+ super().__init__()
221
+
222
+ # determine dimensions
223
+
224
+ self.channels = channels
225
+ self.self_condition = self_condition
226
+ input_channels = channels * (2 if self_condition else 1)
227
+
228
+ init_dim = default(init_dim, dim)
229
+ self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
230
+
231
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
232
+ in_out = list(zip(dims[:-1], dims[1:]))
233
+
234
+ block_klass = partial(ResnetBlock, groups = resnet_block_groups)
235
+
236
+ # time embeddings
237
+
238
+ time_dim = dim * 4
239
+
240
+ self.learned_sinusoidal_cond = learned_sinusoidal_cond
241
+
242
+ if learned_sinusoidal_cond:
243
+ sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
244
+ fourier_dim = learned_sinusoidal_dim + 1
245
+ else:
246
+ sinu_pos_emb = SinusoidalPosEmb(dim)
247
+ fourier_dim = dim
248
+
249
+ self.time_mlp = nn.Sequential(
250
+ sinu_pos_emb,
251
+ nn.Linear(fourier_dim, time_dim),
252
+ nn.GELU(),
253
+ nn.Linear(time_dim, time_dim)
254
+ )
255
+
256
+ # layers
257
+
258
+ self.downs = nn.ModuleList([])
259
+ self.ups = nn.ModuleList([])
260
+ num_resolutions = len(in_out)
261
+
262
+ for ind, (dim_in, dim_out) in enumerate(in_out):
263
+ is_last = ind >= (num_resolutions - 1)
264
+
265
+ self.downs.append(nn.ModuleList([
266
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
267
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
268
+ Residual(PreNorm(dim_in, LinearAttention(dim_in))),
269
+ Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
270
+ ]))
271
+
272
+ mid_dim = dims[-1]
273
+ self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
274
+ self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
275
+ self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
276
+
277
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
278
+ is_last = ind == (len(in_out) - 1)
279
+
280
+ self.ups.append(nn.ModuleList([
281
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
282
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
283
+ Residual(PreNorm(dim_out, LinearAttention(dim_out))),
284
+ Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
285
+ ]))
286
+
287
+ default_out_dim = channels * (1 if not learned_variance else 2)
288
+ self.out_dim = default(out_dim, default_out_dim)
289
+
290
+ self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
291
+ self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
292
+
293
+ def forward(self, x, time, condition=None, self_cond=None):
294
+ if self.self_condition:
295
+ x_self_cond = default(self_cond, lambda: torch.zeros_like(x))
296
+ x = torch.cat((x_self_cond, x), dim = 1)
297
+
298
+ x = self.init_conv(x)
299
+ r = x.clone()
300
+
301
+ t = self.time_mlp(time)
302
+
303
+ h = []
304
+
305
+ for block1, block2, attn, downsample in self.downs:
306
+ x = block1(x, t)
307
+ h.append(x)
308
+
309
+ x = block2(x, t)
310
+ x = attn(x)
311
+ h.append(x)
312
+
313
+ x = downsample(x)
314
+
315
+ x = self.mid_block1(x, t)
316
+ x = self.mid_attn(x)
317
+ x = self.mid_block2(x, t)
318
+
319
+ for block1, block2, attn, upsample in self.ups:
320
+ x = torch.cat((x, h.pop()), dim = 1)
321
+ x = block1(x, t)
322
+
323
+ x = torch.cat((x, h.pop()), dim = 1)
324
+ x = block2(x, t)
325
+ x = attn(x)
326
+
327
+ x = upsample(x)
328
+
329
+ x = torch.cat((x, r), dim = 1)
330
+
331
+ x = self.final_res_block(x, t)
332
+ return self.final_conv(x), []
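Like the other denoisers vendored here, this UNet maps a noisy image and a timestep batch to a prediction of the same shape and returns an empty list as a placeholder second output. A minimal smoke-test sketch with arbitrary shapes:

import torch
from medical_diffusion.external.unet_lucidrains import UNet

unet = UNet(dim=32, channels=3, dim_mults=(1, 2, 4))

x = torch.randn(2, 3, 64, 64)              # [B, C, H, W]
t = torch.randint(0, 1000, (2,)).float()   # one timestep per sample
pred, _ = unet(x, t)                       # forward(x, time, ...) -> (prediction, [])
print(pred.shape)                          # torch.Size([2, 3, 64, 64])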
medical_diffusion/loss/gan_losses.py ADDED
@@ -0,0 +1,22 @@
1
+
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ def exp_d_loss(logits_real, logits_fake):
7
+ loss_real = torch.mean(torch.exp(-logits_real))
8
+ loss_fake = torch.mean(torch.exp(logits_fake))
9
+ d_loss = 0.5 * (loss_real + loss_fake)
10
+ return d_loss
11
+
12
+ def hinge_d_loss(logits_real, logits_fake):
13
+ loss_real = torch.mean(F.relu(1. - logits_real))
14
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
15
+ d_loss = 0.5 * (loss_real + loss_fake)
16
+ return d_loss
17
+
18
+ def vanilla_d_loss(logits_real, logits_fake):
19
+ d_loss = 0.5 * (
20
+ torch.mean(F.softplus(-logits_real)) +
21
+ torch.mean(F.softplus(logits_fake)))
22
+ return d_loss
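All three discriminator losses take raw discriminator logits for a real and a fake batch and reduce them to a scalar; the hinge variant saturates once real logits exceed +1 and fake logits fall below -1, which is why it is a common default for adversarial autoencoder training. A toy numerical check with random logits:

import torch
from medical_diffusion.loss.gan_losses import hinge_d_loss, vanilla_d_loss, exp_d_loss

logits_real = torch.randn(8, 1)   # D(x) for real images
logits_fake = torch.randn(8, 1)   # D(G(z)) for generated images

for fn in (hinge_d_loss, vanilla_d_loss, exp_d_loss):
    print(fn.__name__, fn(logits_real, logits_fake).item())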
medical_diffusion/loss/perceivers.py ADDED
@@ -0,0 +1,27 @@
1
+
2
+
3
+ import lpips
4
+ import torch
5
+
6
+ class LPIPS(torch.nn.Module):
7
+ """Learned Perceptual Image Patch Similarity (LPIPS)"""
8
+ def __init__(self, linear_calibration=False, normalize=False):
9
+ super().__init__()
10
+ self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration) # Note: only 'vgg' valid as loss
11
+ self.normalize = normalize # If true, normalize [0, 1] to [-1, 1]
12
+
13
+
14
+ def forward(self, pred, target):
15
+ # No need to do that because ScalingLayer was introduced in version 0.1 which does this indirectly
16
+ # if pred.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB
17
+ # pred = torch.concat([pred, pred, pred], dim=1)
18
+ # if target.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB
19
+ # target = torch.concat([target, target, target], dim=1)
20
+
21
+ if pred.ndim == 5: # 3D Image: Just use 2D model and compute average over slices
22
+ depth = pred.shape[2]
23
+ losses = torch.stack([self.loss_fn(pred[:,:,d], target[:,:,d], normalize=self.normalize) for d in range(depth)], dim=2)
24
+ return torch.mean(losses, dim=2, keepdim=True)
25
+ else:
26
+ return self.loss_fn(pred, target, normalize=self.normalize)
27
+
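A hedged usage sketch for the LPIPS wrapper above (not part of the commit; requires the lpips package, which downloads VGG weights on first use). It expects [B, C, H, W] tensors, or [B, C, D, H, W] volumes that are averaged over the depth axis, and with normalize=True inputs scaled to [0, 1]:

    import torch
    from medical_diffusion.loss.perceivers import LPIPS

    perceiver = LPIPS(normalize=True).eval()
    pred   = torch.rand(2, 3, 64, 64)   # reconstruction in [0, 1]
    target = torch.rand(2, 3, 64, 64)   # reference image in [0, 1]
    with torch.no_grad():
        loss = perceiver(pred, target)  # per-sample loss, shape [2, 1, 1, 1]
    print(loss.mean())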
medical_diffusion/metrics/__init__.py ADDED
File without changes
medical_diffusion/metrics/torchmetrics_pr_recall.py ADDED
@@ -0,0 +1,170 @@
1
+ from typing import Optional, List
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torchmetrics import Metric
6
+ import torchvision.models as models
7
+ from torchvision import transforms
8
+
9
+
10
+
11
+ from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
12
+
13
+ if _TORCH_FIDELITY_AVAILABLE:
14
+ from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
15
+ else:
16
+ class FeatureExtractorInceptionV3(torch.nn.Module): # type: ignore
17
+ pass
18
+ __doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"]
19
+
20
+ class NoTrainInceptionV3(FeatureExtractorInceptionV3):
21
+ def __init__(
22
+ self,
23
+ name: str,
24
+ features_list: List[str],
25
+ feature_extractor_weights_path: Optional[str] = None,
26
+ ) -> None:
27
+ super().__init__(name, features_list, feature_extractor_weights_path)
28
+ # put into evaluation mode
29
+ self.eval()
30
+
31
+ def train(self, mode: bool) -> "NoTrainInceptionV3":
32
+ """the inception network should not be able to be switched away from evaluation mode."""
33
+ return super().train(False)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ out = super().forward(x)
37
+ return out[0].reshape(x.shape[0], -1)
38
+
39
+
40
+ # -------------------------- VGG Trans ---------------------------
41
+ # class Normalize(object):
42
+ # """Rescale the image from 0-255 (uint8) to [0,1] (float32).
43
+ # Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""
44
+
45
+ # def __call__(self, image):
46
+ # return image/255
47
+
48
+ # # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
49
+ # VGG_Trans = transforms.Compose([
50
+ # transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
51
+ # # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
52
+ # # transforms.CenterCrop(224),
53
+ # Normalize(), # scale to [0, 1]
54
+ # transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
55
+ # ])
56
+
57
+
58
+
59
+ class ImprovedPrecessionRecall(Metric):
60
+ is_differentiable: bool = False
61
+ higher_is_better: bool = True
62
+ full_state_update: bool = False
63
+
64
+
65
+ def __init__(self, feature=2048, knn=3, splits_real=1, splits_fake=5):
66
+ super().__init__()
67
+
68
+
69
+ # ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
70
+ # Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40
71
+ # Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574
72
+ # TODO: Add option to switch between Inception and VGG feature extractor
73
+ # self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
74
+ # self.feature_extractor = transforms.Compose([
75
+ # VGG_Trans,
76
+ # self.vgg_model.features,
77
+ # transforms.Lambda(lambda x: torch.flatten(x, 1)),
78
+ # self.vgg_model.classifier[:4] # [:4] corresponds to 4096 features
79
+ # ])
80
+
81
+ if isinstance(feature, int):
82
+ if not _TORCH_FIDELITY_AVAILABLE:
83
+ raise ModuleNotFoundError(
84
+ "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
85
+ " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
86
+ )
87
+ valid_int_input = [64, 192, 768, 2048]
88
+ if feature not in valid_int_input:
89
+ raise ValueError(
90
+ f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
91
+ )
92
+
93
+ self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
94
+ elif isinstance(feature, torch.nn.Module):
95
+ self.feature_extractor = feature
96
+ else:
97
+ raise TypeError("Got unknown input to argument `feature`")
98
+
99
+ # --------------------------- End Feature Extractor ---------------------------------------------------------------
100
+
101
+ self.knn = knn
102
+ self.splits_real = splits_real
103
+ self.splits_fake = splits_fake
104
+ self.add_state("real_features", [], dist_reduce_fx=None)
105
+ self.add_state("fake_features", [], dist_reduce_fx=None)
106
+
107
+
108
+
109
+ def update(self, imgs: Tensor, real: bool) -> None: # type: ignore
110
+ """Update the state with extracted features.
111
+
112
+ Args:
113
+ imgs: tensor with images fed to the feature extractor
114
+ real: bool indicating if ``imgs`` belong to the real or the fake distribution
115
+ """
116
+ assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8'
117
+
118
+ features = self.feature_extractor(imgs).view(imgs.shape[0], -1)
119
+
120
+ if real:
121
+ self.real_features.append(features)
122
+ else:
123
+ self.fake_features.append(features)
124
+
125
+ def compute(self):
126
+ real_features = torch.concat(self.real_features)
127
+ fake_features = torch.concat(self.fake_features)
128
+
129
+ real_distances = _compute_pairwise_distances(real_features, self.splits_real)
130
+ real_radii = _distances2radii(real_distances, self.knn)
131
+
132
+ fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
133
+ fake_radii = _distances2radii(fake_distances, self.knn)
134
+
135
+ precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
136
+ recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)
137
+
138
+ return precision, recall
139
+
140
+ def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
141
+ dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
142
+ num_feat = pred_features.shape[0]
143
+ count = 0
144
+ for i in range(num_feat):
145
+ count += (dist[:, i] < ref_radii).any()
146
+ return count / num_feat
147
+
148
+ def _distances2radii(distances, knn):
149
+ return torch.topk(distances, knn+1, dim=1, largest=False)[0].max(dim=1)[0]
150
+
151
+ def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
152
+ # X = [B, features]
153
+ # Y = [B', features]
154
+ Y = X if Y is None else Y
155
+ # X = X.double()
156
+ # Y = Y.double()
157
+ splits_y = splits_x if splits_y is None else splits_y
158
+ dist = torch.concat([
159
+ torch.concat([
160
+ (torch.sum(X_batch**2, dim=1, keepdim=True) +
161
+ torch.sum(Y_batch**2, dim=1, keepdim=True).t() -
162
+ 2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
163
+ for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
164
+ for X_batch in X.chunk(splits_x, dim=0)])
165
+
166
+ # dist = torch.maximum(dist, torch.zeros_like(dist))
167
+ dist[dist<0] = 0
168
+ return torch.sqrt(dist)
169
+
170
+
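A hedged usage sketch for the metric above (not part of the commit; needs torch-fidelity for the Inception backbone, and the random uint8 images are illustrative only). Features are accumulated with update(), and (precision, recall) is returned by compute():

    import torch
    from medical_diffusion.metrics.torchmetrics_pr_recall import ImprovedPrecessionRecall

    metric = ImprovedPrecessionRecall(feature=2048, knn=3)
    real = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
    fake = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
    metric.update(real, real=True)
    metric.update(fake, real=False)
    precision, recall = metric.compute()
    print(float(precision), float(recall))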
medical_diffusion/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_base import BasicModel
medical_diffusion/models/embedders/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .time_embedder import TimeEmbbeding, LearnedSinusoidalPosEmb, SinusoidalPosEmb
2
+ from .cond_embedders import LabelEmbedder
medical_diffusion/models/embedders/cond_embedders.py ADDED
@@ -0,0 +1,27 @@
1
+
2
+ import torch.nn as nn
3
+ import torch
4
+ from monai.networks.layers.utils import get_act_layer
5
+
6
+ class LabelEmbedder(nn.Module):
7
+ def __init__(self, emb_dim=32, num_classes=2, act_name=("SWISH", {})):
8
+ super().__init__()
9
+ self.emb_dim = emb_dim
10
+ self.embedding = nn.Embedding(num_classes, emb_dim)
11
+
12
+ # self.embedding = nn.Embedding(num_classes, emb_dim//4)
13
+ # self.emb_net = nn.Sequential(
14
+ # nn.Linear(1, emb_dim),
15
+ # get_act_layer(act_name),
16
+ # nn.Linear(emb_dim, emb_dim)
17
+ # )
18
+
19
+ def forward(self, condition):
20
+ c = self.embedding(condition) #[B,] -> [B, C]
21
+ # c = self.emb_net(c)
22
+ # c = self.emb_net(condition[:,None].float())
23
+ # c = (2*condition-1)[:, None].expand(-1, self.emb_dim).type(torch.float32)
24
+ return c
25
+
26
+
27
+
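A hedged usage sketch (not part of the commit; assumes MONAI is installed, since the module imports it): LabelEmbedder maps integer class labels of shape [B,] to dense condition embeddings [B, emb_dim] via nn.Embedding.

    import torch
    from medical_diffusion.models.embedders.cond_embedders import LabelEmbedder

    embedder = LabelEmbedder(emb_dim=32, num_classes=2)
    labels = torch.tensor([0, 1, 1, 0])  # one class label per sample
    cond_emb = embedder(labels)          # [4, 32]
    print(cond_emb.shape)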
medical_diffusion/models/embedders/latent_embedders.py ADDED
@@ -0,0 +1,1065 @@
1
+
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision.utils import save_image
8
+ from monai.networks.blocks import UnetOutBlock
9
+
10
+
11
+ from medical_diffusion.models.utils.conv_blocks import DownBlock, UpBlock, BasicBlock, BasicResBlock, UnetResBlock, UnetBasicBlock
12
+ from medical_diffusion.loss.gan_losses import hinge_d_loss
13
+ from medical_diffusion.loss.perceivers import LPIPS
14
+ from medical_diffusion.models.model_base import BasicModel, VeryBasicModel
15
+
16
+
17
+ from pytorch_msssim import SSIM, ssim
18
+
19
+
20
+ class DiagonalGaussianDistribution(nn.Module):
21
+
22
+ def forward(self, x):
23
+ mean, logvar = torch.chunk(x, 2, dim=1)
24
+ logvar = torch.clamp(logvar, -30.0, 20.0)
25
+ std = torch.exp(0.5 * logvar)
26
+ sample = torch.randn(mean.shape, generator=None, device=x.device)
27
+ z = mean + std * sample
28
+
29
+ batch_size = x.shape[0]
30
+ var = torch.exp(logvar)
31
+ kl = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar)/batch_size
32
+
33
+ return z, kl
34
+
35
+
36
+
37
+
38
+
39
+
40
+ class VectorQuantizer(nn.Module):
41
+ def __init__(self, num_embeddings, emb_channels, beta=0.25):
42
+ super().__init__()
43
+ self.num_embeddings = num_embeddings
44
+ self.emb_channels = emb_channels
45
+ self.beta = beta
46
+
47
+ self.embedder = nn.Embedding(num_embeddings, emb_channels)
48
+ self.embedder.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)
49
+
50
+ def forward(self, z):
51
+ assert z.shape[1] == self.emb_channels, "Channels of z and codebook don't match"
52
+ z_ch = torch.moveaxis(z, 1, -1) # [B, C, *] -> [B, *, C]
53
+ z_flattened = z_ch.reshape(-1, self.emb_channels) # [B, *, C] -> [Bx*, C], Note: or use contiguous() and view()
54
+
55
+ # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z
56
+ dist = ( torch.sum(z_flattened**2, dim=1, keepdim=True)
57
+ + torch.sum(self.embedder.weight**2, dim=1)
58
+ -2* torch.einsum("bd,dn->bn", z_flattened, self.embedder.weight.t())
59
+ ) # [Bx*, num_embeddings]
60
+
61
+ min_encoding_indices = torch.argmin(dist, dim=1) # [Bx*]
62
+ z_q = self.embedder(min_encoding_indices) # [Bx*, C]
63
+ z_q = z_q.view(z_ch.shape) # [Bx*, C] -> [B, *, C]
64
+ z_q = torch.moveaxis(z_q, -1, 1) # [B, *, C] -> [B, C, *]
65
+
66
+ # Compute Embedding Loss
67
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
68
+
69
+ # preserve gradients
70
+ z_q = z + (z_q - z).detach()
71
+
72
+ return z_q, loss
73
+
74
+
75
+
76
+ class Discriminator(nn.Module):
77
+ def __init__(self,
78
+ in_channels=1,
79
+ spatial_dims = 3,
80
+ hid_chs = [32, 64, 128, 256, 512],
81
+ kernel_sizes=[(1,3,3), (1,3,3), (1,3,3), 3, 3],
82
+ strides = [ 1, (1,2,2), (1,2,2), 2, 2],
83
+ act_name=("Swish", {}),
84
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
85
+ dropout=None
86
+ ):
87
+ super().__init__()
88
+
89
+ self.inc = BasicBlock(
90
+ spatial_dims=spatial_dims,
91
+ in_channels=in_channels,
92
+ out_channels=hid_chs[0],
93
+ kernel_size=kernel_sizes[0], # 2*pad = kernel-stride -> kernel = 2*pad + stride => 1 = 2*0+1, 3 = 2*1+1, 2 = 2*0+2, 4 = 2*1+2
94
+ stride=strides[0],
95
+ norm_name=norm_name,
96
+ act_name=act_name,
97
+ dropout=dropout,
98
+ )
99
+
100
+ self.encoder = nn.Sequential(*[
101
+ BasicBlock(
102
+ spatial_dims=spatial_dims,
103
+ in_channels=hid_chs[i-1],
104
+ out_channels=hid_chs[i],
105
+ kernel_size=kernel_sizes[i],
106
+ stride=strides[i],
107
+ act_name=act_name,
108
+ norm_name=norm_name,
109
+ dropout=dropout)
110
+ for i in range(1, len(hid_chs))
111
+ ])
112
+
113
+
114
+ self.outc = BasicBlock(
115
+ spatial_dims=spatial_dims,
116
+ in_channels=hid_chs[-1],
117
+ out_channels=1,
118
+ kernel_size=3,
119
+ stride=1,
120
+ act_name=None,
121
+ norm_name=None,
122
+ dropout=None,
123
+ zero_conv=True
124
+ )
125
+
126
+
127
+
128
+ def forward(self, x):
129
+ x = self.inc(x)
130
+ x = self.encoder(x)
131
+ return self.outc(x)
132
+
133
+
134
+ class NLayerDiscriminator(nn.Module):
135
+ def __init__(self,
136
+ in_channels=1,
137
+ spatial_dims = 3,
138
+ hid_chs = [64, 128, 256, 512, 512],
139
+ kernel_sizes=[4, 4, 4, 4, 4],
140
+ strides = [2, 2, 2, 1, 1],
141
+ act_name=("LeakyReLU", {'negative_slope': 0.2}),
142
+ norm_name = ("BATCH", {}),
143
+ dropout=None
144
+ ):
145
+ super().__init__()
146
+
147
+ self.inc = BasicBlock(
148
+ spatial_dims=spatial_dims,
149
+ in_channels=in_channels,
150
+ out_channels=hid_chs[0],
151
+ kernel_size=kernel_sizes[0],
152
+ stride=strides[0],
153
+ norm_name=None,
154
+ act_name=act_name,
155
+ dropout=dropout,
156
+ )
157
+
158
+ self.encoder = nn.Sequential(*[
159
+ BasicBlock(
160
+ spatial_dims=spatial_dims,
161
+ in_channels=hid_chs[i-1],
162
+ out_channels=hid_chs[i],
163
+ kernel_size=kernel_sizes[i],
164
+ stride=strides[i],
165
+ act_name=act_name,
166
+ norm_name=norm_name,
167
+ dropout=dropout)
168
+ for i in range(1, len(strides))
169
+ ])
170
+
171
+
172
+ self.outc = BasicBlock(
173
+ spatial_dims=spatial_dims,
174
+ in_channels=hid_chs[-1],
175
+ out_channels=1,
176
+ kernel_size=4,
177
+ stride=1,
178
+ norm_name=None,
179
+ act_name=None,
180
+ dropout=False,
181
+ )
182
+
183
+ def forward(self, x):
184
+ x = self.inc(x)
185
+ x = self.encoder(x)
186
+ return self.outc(x)
187
+
188
+
189
+
190
+
191
+ class VQVAE(BasicModel):
192
+ def __init__(
193
+ self,
194
+ in_channels=3,
195
+ out_channels=3,
196
+ spatial_dims = 2,
197
+ emb_channels = 4,
198
+ num_embeddings = 8192,
199
+ hid_chs = [32, 64, 128, 256],
200
+ kernel_sizes=[ 3, 3, 3, 3],
201
+ strides = [ 1, 2, 2, 2],
202
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
203
+ act_name=("Swish", {}),
204
+ dropout=0.0,
205
+ use_res_block=True,
206
+ deep_supervision=False,
207
+ learnable_interpolation=True,
208
+ use_attention='none',
209
+ beta = 0.25,
210
+ embedding_loss_weight=1.0,
211
+ perceiver = LPIPS,
212
+ perceiver_kwargs = {},
213
+ perceptual_loss_weight = 1.0,
214
+
215
+
216
+ optimizer=torch.optim.Adam,
217
+ optimizer_kwargs={'lr':1e-4},
218
+ lr_scheduler= None,
219
+ lr_scheduler_kwargs={},
220
+ loss = torch.nn.L1Loss,
221
+ loss_kwargs={'reduction': 'none'},
222
+
223
+ sample_every_n_steps = 1000
224
+
225
+ ):
226
+ super().__init__(
227
+ optimizer=optimizer,
228
+ optimizer_kwargs=optimizer_kwargs,
229
+ lr_scheduler=lr_scheduler,
230
+ lr_scheduler_kwargs=lr_scheduler_kwargs
231
+ )
232
+ self.sample_every_n_steps=sample_every_n_steps
233
+ self.loss_fct = loss(**loss_kwargs)
234
+ self.embedding_loss_weight = embedding_loss_weight
235
+ self.perceiver = perceiver(**perceiver_kwargs).eval() if perceiver is not None else None
236
+ self.perceptual_loss_weight = perceptual_loss_weight
237
+ use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
238
+ self.depth = len(strides)
239
+ self.deep_supervision = deep_supervision
240
+
241
+ # ----------- In-Convolution ------------
242
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
243
+ self.inc = ConvBlock(spatial_dims, in_channels, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
244
+ act_name=act_name, norm_name=norm_name)
245
+
246
+ # ----------- Encoder ----------------
247
+ self.encoders = nn.ModuleList([
248
+ DownBlock(
249
+ spatial_dims,
250
+ hid_chs[i-1],
251
+ hid_chs[i],
252
+ kernel_sizes[i],
253
+ strides[i],
254
+ kernel_sizes[i],
255
+ norm_name,
256
+ act_name,
257
+ dropout,
258
+ use_res_block,
259
+ learnable_interpolation,
260
+ use_attention[i])
261
+ for i in range(1, self.depth)
262
+ ])
263
+
264
+ # ----------- Out-Encoder ------------
265
+ self.out_enc = BasicBlock(spatial_dims, hid_chs[-1], emb_channels, 1)
266
+
267
+
268
+ # ----------- Quantizer --------------
269
+ self.quantizer = VectorQuantizer(
270
+ num_embeddings=num_embeddings,
271
+ emb_channels=emb_channels,
272
+ beta=beta
273
+ )
274
+
275
+ # ----------- In-Decoder ------------
276
+ self.inc_dec = ConvBlock(spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name)
277
+
278
+ # ------------ Decoder ----------
279
+ self.decoders = nn.ModuleList([
280
+ UpBlock(
281
+ spatial_dims,
282
+ hid_chs[i+1],
283
+ hid_chs[i],
284
+ kernel_size=kernel_sizes[i+1],
285
+ stride=strides[i+1],
286
+ upsample_kernel_size=strides[i+1],
287
+ norm_name=norm_name,
288
+ act_name=act_name,
289
+ dropout=dropout,
290
+ use_res_block=use_res_block,
291
+ learnable_interpolation=learnable_interpolation,
292
+ use_attention=use_attention[i],
293
+ skip_channels=0)
294
+ for i in range(self.depth-1)
295
+ ])
296
+
297
+ # --------------- Out-Convolution ----------------
298
+ self.outc = BasicBlock(spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True)
299
+ if isinstance(deep_supervision, bool):
300
+ deep_supervision = self.depth-1 if deep_supervision else 0
301
+ self.outc_ver = nn.ModuleList([
302
+ BasicBlock(spatial_dims, hid_chs[i], out_channels, 1, zero_conv=True)
303
+ for i in range(1, deep_supervision+1)
304
+ ])
305
+
306
+
307
+ def encode(self, x):
308
+ h = self.inc(x)
309
+ for i in range(len(self.encoders)):
310
+ h = self.encoders[i](h)
311
+ z = self.out_enc(h)
312
+ return z
313
+
314
+ def decode(self, z):
315
+ z, _ = self.quantizer(z)
316
+ h = self.inc_dec(z)
317
+ for i in range(len(self.decoders), 0, -1):
318
+ h = self.decoders[i-1](h)
319
+ x = self.outc(h)
320
+ return x
321
+
322
+ def forward(self, x_in):
323
+ # --------- Encoder --------------
324
+ h = self.inc(x_in)
325
+ for i in range(len(self.encoders)):
326
+ h = self.encoders[i](h)
327
+ z = self.out_enc(h)
328
+
329
+ # --------- Quantizer --------------
330
+ z_q, emb_loss = self.quantizer(z)
331
+
332
+ # -------- Decoder -----------
333
+ out_hor = []
334
+ h = self.inc_dec(z_q)
335
+ for i in range(len(self.decoders)-1, -1, -1):
336
+ out_hor.append(self.outc_ver[i](h)) if i < len(self.outc_ver) else None
337
+ h = self.decoders[i](h)
338
+ out = self.outc(h)
339
+
340
+ return out, out_hor[::-1], emb_loss
341
+
342
+ def perception_loss(self, pred, target, depth=0):
343
+ if (self.perceiver is not None) and (depth<2):
344
+ self.perceiver.eval()
345
+ return self.perceiver(pred, target)*self.perceptual_loss_weight
346
+ else:
347
+ return 0
348
+
349
+ def ssim_loss(self, pred, target):
350
+ return 1-ssim(((pred+1)/2).clamp(0,1), (target.type(pred.dtype)+1)/2, data_range=1, size_average=False,
351
+ nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1))
352
+
353
+
354
+ def rec_loss(self, pred, pred_vertical, target):
355
+ interpolation_mode = 'nearest-exact'
356
+ weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal (equal) + vertical (reducing with every step down)
357
+ tot_weight = sum(weights)
358
+ weights = [w/tot_weight for w in weights]
359
+
360
+ # Loss
361
+ loss = 0
362
+ loss += torch.mean(self.loss_fct(pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target))*weights[0]
363
+
364
+ for i, pred_i in enumerate(pred_vertical):
365
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
366
+ loss += torch.mean(self.loss_fct(pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i))*weights[i+1]
367
+
368
+ return loss
369
+
370
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
371
+ # ------------------------- Get Source/Target ---------------------------
372
+ x = batch['source']
373
+ target = x
374
+
375
+ # ------------------------- Run Model ---------------------------
376
+ pred, pred_vertical, emb_loss = self(x)
377
+
378
+ # ------------------------- Compute Loss ---------------------------
379
+ loss = self.rec_loss(pred, pred_vertical, target)
380
+ loss += emb_loss*self.embedding_loss_weight
381
+
382
+ # --------------------- Compute Metrics -------------------------------
383
+ with torch.no_grad():
384
+ logging_dict = {'loss':loss, 'emb_loss': emb_loss}
385
+ logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target)
386
+ logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target)
387
+ logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
388
+
389
+ # ----------------- Log Scalars ----------------------
390
+ for metric_name, metric_val in logging_dict.items():
391
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
392
+
393
+ # ----------------- Save Image ------------------------------
394
+ if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
395
+ log_step = self.global_step // self.sample_every_n_steps
396
+ path_out = Path(self.logger.log_dir)/'images'
397
+ path_out.mkdir(parents=True, exist_ok=True)
398
+ # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
399
+ def depth2batch(image):
400
+ return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
401
+ images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
402
+ save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
403
+
404
+ return loss
405
+
406
+
407
+
408
+ class VQGAN(VeryBasicModel):
409
+ def __init__(
410
+ self,
411
+ in_channels=3,
412
+ out_channels=3,
413
+ spatial_dims = 2,
414
+ emb_channels = 4,
415
+ num_embeddings = 8192,
416
+ hid_chs = [ 64, 128, 256, 512],
417
+ kernel_sizes=[ 3, 3, 3, 3],
418
+ strides = [ 1, 2, 2, 2],
419
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
420
+ act_name=("Swish", {}),
421
+ dropout=0.0,
422
+ use_res_block=True,
423
+ deep_supervision=False,
424
+ learnable_interpolation=True,
425
+ use_attention='none',
426
+ beta = 0.25,
427
+ embedding_loss_weight=1.0,
428
+ perceiver = LPIPS,
429
+ perceiver_kwargs = {},
430
+ perceptual_loss_weight: float = 1.0,
431
+
432
+
433
+ start_gan_train_step = 50000, # NOTE step increase with each optimizer
434
+ gan_loss_weight: float = 1.0, # = discriminator
435
+
436
+ optimizer_vqvae=torch.optim.Adam,
437
+ optimizer_gan=torch.optim.Adam,
438
+ optimizer_vqvae_kwargs={'lr':1e-6},
439
+ optimizer_gan_kwargs={'lr':1e-6},
440
+ lr_scheduler_vqvae= None,
441
+ lr_scheduler_vqvae_kwargs={},
442
+ lr_scheduler_gan= None,
443
+ lr_scheduler_gan_kwargs={},
444
+
445
+ pixel_loss = torch.nn.L1Loss,
446
+ pixel_loss_kwargs={'reduction':'none'},
447
+ gan_loss_fct = hinge_d_loss,
448
+
449
+ sample_every_n_steps = 1000
450
+
451
+ ):
452
+ super().__init__()
453
+ self.sample_every_n_steps=sample_every_n_steps
454
+ self.start_gan_train_step = start_gan_train_step
455
+ self.gan_loss_weight = gan_loss_weight
456
+ self.embedding_loss_weight = embedding_loss_weight
457
+
458
+ self.optimizer_vqvae = optimizer_vqvae
459
+ self.optimizer_gan = optimizer_gan
460
+ self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs
461
+ self.optimizer_gan_kwargs = optimizer_gan_kwargs
462
+ self.lr_scheduler_vqvae = lr_scheduler_vqvae
463
+ self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs
464
+ self.lr_scheduler_gan = lr_scheduler_gan
465
+ self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs
466
+
467
+ self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs)
468
+ self.gan_loss_fct = gan_loss_fct
469
+
470
+ self.vqvae = VQVAE(in_channels, out_channels, spatial_dims, emb_channels, num_embeddings, hid_chs, kernel_sizes,
471
+ strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention,
472
+ beta, embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight)
473
+
474
+ self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides,
475
+ act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)])
476
+
477
+
478
+ # self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims)
479
+ # for _ in range(len(self.vqvae.decoder.outc_ver)+1)])
480
+
481
+
482
+
483
+ def encode(self, x):
484
+ return self.vqvae.encode(x)
485
+
486
+ def decode(self, z):
487
+ return self.vqvae.decode(z)
488
+
489
+ def forward(self, x):
490
+ return self.vqvae.forward(x)
491
+
492
+
493
+ def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0):
494
+ # ------ VQVAE -------
495
+ rec_loss = self.vqvae.rec_loss(pred, [], target)
496
+
497
+ # ------- GAN -----
498
+ if step > self.start_gan_train_step:
499
+ gan_loss = -torch.mean(discriminator[depth](pred))
500
+ lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer)
501
+ gan_loss = gan_loss*lambda_weight
502
+
503
+ with torch.no_grad():
504
+ self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True)
505
+ self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True)
506
+ else:
507
+ gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device)
508
+
509
+ return self.gan_loss_weight*gan_loss+rec_loss
510
+
511
+
512
+ def gan_img_loss(self, pred, target, step, discriminators, depth):
513
+ if (step > self.start_gan_train_step) and (depth<len(discriminators)):
514
+ logits_real = discriminators[depth](target.detach())
515
+ logits_fake = discriminators[depth](pred.detach())
516
+ loss = self.gan_loss_fct(logits_real, logits_fake)
517
+ else:
518
+ loss = torch.tensor(0.0, requires_grad=True, device=target.device)
519
+
520
+ with torch.no_grad():
521
+ self.log(f"train/loss_1_{depth}", loss, on_step=True, on_epoch=True)
522
+ return loss
523
+
524
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
525
+ # ------------------------- Get Source/Target ---------------------------
526
+ x = batch['source']
527
+ target = x
528
+
529
+ # ------------------------- Run Model ---------------------------
530
+ pred, pred_vertical, emb_loss = self(x)
531
+
532
+ # ------------------------- Compute Loss ---------------------------
533
+ interpolation_mode = 'area'
534
+ weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal + vertical (reducing with every step down)
535
+ tot_weight = sum(weights)
536
+ weights = [w/tot_weight for w in weights]
537
+ logging_dict = {}
538
+
539
+ if optimizer_idx == 0:
540
+ # Horizontal/Top Layer
541
+ img_loss = self.vae_img_loss(pred, target, self.vqvae.outc.conv, step, self.discriminator, 0)*weights[0]
542
+
543
+ # Vertical/Deep Layer
544
+ for i, pred_i in enumerate(pred_vertical):
545
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
546
+ img_loss += self.vae_img_loss(pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1)*weights[i+1]
547
+ loss = img_loss+self.embedding_loss_weight*emb_loss
548
+
549
+ with torch.no_grad():
550
+ logging_dict[f'img_loss'] = img_loss
551
+ logging_dict[f'emb_loss'] = emb_loss
552
+ logging_dict['loss_0'] = loss
553
+
554
+ elif optimizer_idx == 1:
555
+ # Horizontal/Top Layer
556
+ loss = self.gan_img_loss(pred, target, step, self.discriminator, 0)*weights[0]
557
+
558
+ # Vertical/Deep Layer
559
+ for i, pred_i in enumerate(pred_vertical):
560
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
561
+ loss += self.gan_img_loss(pred_i, target_i, step, self.discriminator, i+1)*weights[i+1]
562
+
563
+ with torch.no_grad():
564
+ logging_dict['loss_1'] = loss
565
+
566
+
567
+ # --------------------- Compute Metrics -------------------------------
568
+ with torch.no_grad():
569
+ logging_dict['loss'] = loss
570
+ logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x)
571
+ logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x)
572
+ logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
573
+
574
+ # ----------------- Log Scalars ----------------------
575
+ for metric_name, metric_val in logging_dict.items():
576
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
577
+
578
+ # ----------------- Save Image ------------------------------
579
+ if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
580
+
581
+ log_step = self.global_step // self.sample_every_n_steps
582
+ path_out = Path(self.logger.log_dir)/'images'
583
+ path_out.mkdir(parents=True, exist_ok=True)
584
+ # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
585
+ def depth2batch(image):
586
+ return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
587
+ images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
588
+ save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
589
+
590
+ return loss
591
+
592
+ def configure_optimizers(self):
593
+ opt_vqvae = self.optimizer_vqvae(self.vqvae.parameters(), **self.optimizer_vqvae_kwargs)
594
+ opt_gan = self.optimizer_gan(self.discriminator.parameters(), **self.optimizer_gan_kwargs)
595
+ schedulers = []
596
+ if self.lr_scheduler_vqvae is not None:
597
+ schedulers.append({
598
+ 'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs),
599
+ 'interval': 'step',
600
+ 'frequency': 1
601
+ })
602
+ if self.lr_scheduler_gan is not None:
603
+ schedulers.append({
604
+ 'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs),
605
+ 'interval': 'step',
606
+ 'frequency': 1
607
+ })
608
+ return [opt_vqvae, opt_gan], schedulers
609
+
610
+ def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4):
611
+ """Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
612
+ rec_grads = torch.autograd.grad(rec_loss, dec_out_layer.weight, retain_graph=True)[0]
613
+ gan_grads = torch.autograd.grad(gan_loss, dec_out_layer.weight, retain_graph=True)[0]
614
+ d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
615
+ d_weight = torch.clamp(d_weight, 0.0, 1e4)
616
+ return d_weight.detach()
617
+
618
+
619
+
620
+ class VAE(BasicModel):
621
+ def __init__(
622
+ self,
623
+ in_channels=3,
624
+ out_channels=3,
625
+ spatial_dims = 2,
626
+ emb_channels = 4,
627
+ hid_chs = [ 64, 128, 256, 512],
628
+ kernel_sizes=[ 3, 3, 3, 3],
629
+ strides = [ 1, 2, 2, 2],
630
+ norm_name = ("GROUP", {'num_groups':8, "affine": True}),
631
+ act_name=("Swish", {}),
632
+ dropout=None,
633
+ use_res_block=True,
634
+ deep_supervision=False,
635
+ learnable_interpolation=True,
636
+ use_attention='none',
637
+ embedding_loss_weight=1e-6,
638
+ perceiver = LPIPS,
639
+ perceiver_kwargs = {},
640
+ perceptual_loss_weight = 1.0,
641
+
642
+
643
+ optimizer=torch.optim.Adam,
644
+ optimizer_kwargs={'lr':1e-4},
645
+ lr_scheduler= None,
646
+ lr_scheduler_kwargs={},
647
+ loss = torch.nn.L1Loss,
648
+ loss_kwargs={'reduction': 'none'},
649
+
650
+ sample_every_n_steps = 1000
651
+
652
+ ):
653
+ super().__init__(
654
+ optimizer=optimizer,
655
+ optimizer_kwargs=optimizer_kwargs,
656
+ lr_scheduler=lr_scheduler,
657
+ lr_scheduler_kwargs=lr_scheduler_kwargs
658
+ )
659
+ self.sample_every_n_steps=sample_every_n_steps
660
+ self.loss_fct = loss(**loss_kwargs)
661
+ # self.ssim_fct = SSIM(data_range=1, size_average=False, channel=out_channels, spatial_dims=spatial_dims, nonnegative_ssim=True)
662
+ self.embedding_loss_weight = embedding_loss_weight
663
+ self.perceiver = perceiver(**perceiver_kwargs).eval() if perceiver is not None else None
664
+ self.perceptual_loss_weight = perceptual_loss_weight
665
+ use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
666
+ self.depth = len(strides)
667
+ self.deep_supervision = deep_supervision
668
+ downsample_kernel_sizes = kernel_sizes
669
+ upsample_kernel_sizes = strides
670
+
671
+ # -------- Loss-Reg---------
672
+ # self.logvar = nn.Parameter(torch.zeros(size=()) )
673
+
674
+ # ----------- In-Convolution ------------
675
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
676
+ self.inc = ConvBlock(
677
+ spatial_dims,
678
+ in_channels,
679
+ hid_chs[0],
680
+ kernel_size=kernel_sizes[0],
681
+ stride=strides[0],
682
+ act_name=act_name,
683
+ norm_name=norm_name,
684
+ emb_channels=None
685
+ )
686
+
687
+ # ----------- Encoder ----------------
688
+ self.encoders = nn.ModuleList([
689
+ DownBlock(
690
+ spatial_dims = spatial_dims,
691
+ in_channels = hid_chs[i-1],
692
+ out_channels = hid_chs[i],
693
+ kernel_size = kernel_sizes[i],
694
+ stride = strides[i],
695
+ downsample_kernel_size = downsample_kernel_sizes[i],
696
+ norm_name = norm_name,
697
+ act_name = act_name,
698
+ dropout = dropout,
699
+ use_res_block = use_res_block,
700
+ learnable_interpolation = learnable_interpolation,
701
+ use_attention = use_attention[i],
702
+ emb_channels = None
703
+ )
704
+ for i in range(1, self.depth)
705
+ ])
706
+
707
+ # ----------- Out-Encoder ------------
708
+ self.out_enc = nn.Sequential(
709
+ BasicBlock(spatial_dims, hid_chs[-1], 2*emb_channels, 3),
710
+ BasicBlock(spatial_dims, 2*emb_channels, 2*emb_channels, 1)
711
+ )
712
+
713
+
714
+ # ----------- Reparameterization --------------
715
+ self.quantizer = DiagonalGaussianDistribution()
716
+
717
+
718
+ # ----------- In-Decoder ------------
719
+ self.inc_dec = ConvBlock(spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name)
720
+
721
+ # ------------ Decoder ----------
722
+ self.decoders = nn.ModuleList([
723
+ UpBlock(
724
+ spatial_dims = spatial_dims,
725
+ in_channels = hid_chs[i+1],
726
+ out_channels = hid_chs[i],
727
+ kernel_size=kernel_sizes[i+1],
728
+ stride=strides[i+1],
729
+ upsample_kernel_size=upsample_kernel_sizes[i+1],
730
+ norm_name=norm_name,
731
+ act_name=act_name,
732
+ dropout=dropout,
733
+ use_res_block=use_res_block,
734
+ learnable_interpolation=learnable_interpolation,
735
+ use_attention=use_attention[i],
736
+ emb_channels=None,
737
+ skip_channels=0
738
+ )
739
+ for i in range(self.depth-1)
740
+ ])
741
+
742
+ # --------------- Out-Convolution ----------------
743
+ self.outc = BasicBlock(spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True)
744
+ if isinstance(deep_supervision, bool):
745
+ deep_supervision = self.depth-1 if deep_supervision else 0
746
+ self.outc_ver = nn.ModuleList([
747
+ BasicBlock(spatial_dims, hid_chs[i], out_channels, 1, zero_conv=True)
748
+ for i in range(1, deep_supervision+1)
749
+ ])
750
+ # self.logvar_ver = nn.ParameterList([
751
+ # nn.Parameter(torch.zeros(size=()) )
752
+ # for _ in range(1, deep_supervision+1)
753
+ # ])
754
+
755
+
756
+ def encode(self, x):
757
+ h = self.inc(x)
758
+ for i in range(len(self.encoders)):
759
+ h = self.encoders[i](h)
760
+ z = self.out_enc(h)
761
+ z, _ = self.quantizer(z)
762
+ return z
763
+
764
+ def decode(self, z):
765
+ h = self.inc_dec(z)
766
+ for i in range(len(self.decoders), 0, -1):
767
+ h = self.decoders[i-1](h)
768
+ x = self.outc(h)
769
+ return x
770
+
771
+ def forward(self, x_in):
772
+ # --------- Encoder --------------
773
+ h = self.inc(x_in)
774
+ for i in range(len(self.encoders)):
775
+ h = self.encoders[i](h)
776
+ z = self.out_enc(h)
777
+
778
+ # --------- Quantizer --------------
779
+ z_q, emb_loss = self.quantizer(z)
780
+
781
+ # -------- Decoder -----------
782
+ out_hor = []
783
+ h = self.inc_dec(z_q)
784
+ for i in range(len(self.decoders)-1, -1, -1):
785
+ out_hor.append(self.outc_ver[i](h)) if i < len(self.outc_ver) else None
786
+ h = self.decoders[i](h)
787
+ out = self.outc(h)
788
+
789
+ return out, out_hor[::-1], emb_loss
790
+
791
+ def perception_loss(self, pred, target, depth=0):
792
+ if (self.perceiver is not None) and (depth<2):
793
+ self.perceiver.eval()
794
+ return self.perceiver(pred, target)*self.perceptual_loss_weight
795
+ else:
796
+ return 0
797
+
798
+ def ssim_loss(self, pred, target):
799
+ return 1-ssim(((pred+1)/2).clamp(0,1), (target.type(pred.dtype)+1)/2, data_range=1, size_average=False,
800
+ nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1))
801
+
802
+ def rec_loss(self, pred, pred_vertical, target):
803
+ interpolation_mode = 'nearest-exact'
804
+
805
+ # Loss
806
+ loss = 0
807
+ rec_loss = self.loss_fct(pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target)
808
+ # rec_loss = rec_loss/ torch.exp(self.logvar) + self.logvar # Note: this is included in Stable-Diffusion, but logvar is not used in the optimizer
809
+ loss += torch.sum(rec_loss)/pred.shape[0]
810
+
811
+
812
+ for i, pred_i in enumerate(pred_vertical):
813
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
814
+ rec_loss_i = self.loss_fct(pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i)
815
+ # rec_loss_i = rec_loss_i/ torch.exp(self.logvar_ver[i]) + self.logvar_ver[i]
816
+ loss += torch.sum(rec_loss_i)/pred.shape[0]
817
+
818
+ return loss
819
+
820
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
821
+ # ------------------------- Get Source/Target ---------------------------
822
+ x = batch['source']
823
+ target = x
824
+
825
+ # ------------------------- Run Model ---------------------------
826
+ pred, pred_vertical, emb_loss = self(x)
827
+
828
+ # ------------------------- Compute Loss ---------------------------
829
+ loss = self.rec_loss(pred, pred_vertical, target)
830
+ loss += emb_loss*self.embedding_loss_weight
831
+
832
+ # --------------------- Compute Metrics -------------------------------
833
+ with torch.no_grad():
834
+ logging_dict = {'loss':loss, 'emb_loss': emb_loss}
835
+ logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target)
836
+ logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target)
837
+ logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
838
+ # logging_dict['logvar'] = self.logvar
839
+
840
+ # ----------------- Log Scalars ----------------------
841
+ for metric_name, metric_val in logging_dict.items():
842
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
843
+
844
+ # ----------------- Save Image ------------------------------
845
+ if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
846
+ log_step = self.global_step // self.sample_every_n_steps
847
+ path_out = Path(self.logger.log_dir)/'images'
848
+ path_out.mkdir(parents=True, exist_ok=True)
849
+ # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
850
+ def depth2batch(image):
851
+ return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
852
+ images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
853
+ save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
854
+
855
+ return loss
856
+
857
+
858
+
859
+
860
+ class VAEGAN(VeryBasicModel):
861
+ def __init__(
862
+ self,
863
+ in_channels=3,
864
+ out_channels=3,
865
+ spatial_dims = 2,
866
+ emb_channels = 4,
867
+ hid_chs = [ 64, 128, 256, 512],
868
+ kernel_sizes=[ 3, 3, 3, 3],
869
+ strides = [ 1, 2, 2, 2],
870
+ norm_name = ("GROUP", {'num_groups':8, "affine": True}),
871
+ act_name=("Swish", {}),
872
+ dropout=0.0,
873
+ use_res_block=True,
874
+ deep_supervision=False,
875
+ learnable_interpolation=True,
876
+ use_attention='none',
877
+ embedding_loss_weight=1e-6,
878
+ perceiver = LPIPS,
879
+ perceiver_kwargs = {},
880
+ perceptual_loss_weight = 1.0,
881
+
882
+
883
+ start_gan_train_step = 50000, # NOTE step increase with each optimizer
884
+ gan_loss_weight: float = 1.0, # = discriminator
885
+
886
+ optimizer_vqvae=torch.optim.Adam,
887
+ optimizer_gan=torch.optim.Adam,
888
+ optimizer_vqvae_kwargs={'lr':1e-6}, # 'weight_decay':1e-2, {'lr':1e-6, 'betas':(0.5, 0.9)}
889
+ optimizer_gan_kwargs={'lr':1e-6}, # 'weight_decay':1e-2,
890
+ lr_scheduler_vqvae= None,
891
+ lr_scheduler_vqvae_kwargs={},
892
+ lr_scheduler_gan= None,
893
+ lr_scheduler_gan_kwargs={},
894
+
895
+ pixel_loss = torch.nn.L1Loss,
896
+ pixel_loss_kwargs={'reduction':'none'},
897
+ gan_loss_fct = hinge_d_loss,
898
+
899
+ sample_every_n_steps = 1000
900
+
901
+ ):
902
+ super().__init__()
903
+ self.sample_every_n_steps=sample_every_n_steps
904
+ self.start_gan_train_step = start_gan_train_step
905
+ self.gan_loss_weight = gan_loss_weight
906
+ self.embedding_loss_weight = embedding_loss_weight
907
+
908
+ self.optimizer_vqvae = optimizer_vqvae
909
+ self.optimizer_gan = optimizer_gan
910
+ self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs
911
+ self.optimizer_gan_kwargs = optimizer_gan_kwargs
912
+ self.lr_scheduler_vqvae = lr_scheduler_vqvae
913
+ self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs
914
+ self.lr_scheduler_gan = lr_scheduler_gan
915
+ self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs
916
+
917
+ self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs)
918
+ self.gan_loss_fct = gan_loss_fct
919
+
920
+ self.vqvae = VAE(in_channels, out_channels, spatial_dims, emb_channels, hid_chs, kernel_sizes,
921
+ strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention,
922
+ embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight)
923
+
924
+ self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides,
925
+ act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)])
926
+
927
+
928
+ # self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims)
929
+ # for _ in range(len(self.vqvae.outc_ver)+1)])
930
+
931
+
932
+
933
+ def encode(self, x):
934
+ return self.vqvae.encode(x)
935
+
936
+ def decode(self, z):
937
+ return self.vqvae.decode(z)
938
+
939
+ def forward(self, x):
940
+ return self.vqvae.forward(x)
941
+
942
+
943
+ def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0):
944
+ # ------ VQVAE -------
945
+ rec_loss = self.vqvae.rec_loss(pred, [], target)
946
+
947
+ # ------- GAN -----
948
+ if (step > self.start_gan_train_step) and (depth<2):
949
+ gan_loss = -torch.sum(discriminator[depth](pred)) # clamp(..., None, 0) => only punish areas that were rated as fake (<0) by discriminator => ensures loss >0 and +- don't cancel out in sum
950
+ lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer)
951
+ gan_loss = gan_loss*lambda_weight
952
+
953
+ with torch.no_grad():
954
+ self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True)
955
+ self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True)
956
+ else:
957
+ gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device)
958
+
959
+
960
+
961
+ return self.gan_loss_weight*gan_loss+rec_loss
962
+
963
+ def gan_img_loss(self, pred, target, step, discriminators, depth):
964
+ if (step > self.start_gan_train_step) and (depth<len(discriminators)):
965
+ logits_real = discriminators[depth](target.detach())
966
+ logits_fake = discriminators[depth](pred.detach())
967
+ loss = self.gan_loss_fct(logits_real, logits_fake)
968
+ else:
969
+ loss = torch.tensor(0.0, requires_grad=True, device=target.device)
970
+
971
+ with torch.no_grad():
972
+ self.log(f"train/loss_1_{depth}", loss, on_step=True, on_epoch=True)
973
+ return loss
974
+
975
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
976
+ # ------------------------- Get Source/Target ---------------------------
977
+ x = batch['source']
978
+ target = x
979
+
980
+ # ------------------------- Run Model ---------------------------
981
+ pred, pred_vertical, emb_loss = self(x)
982
+
983
+ # ------------------------- Compute Loss ---------------------------
984
+ interpolation_mode = 'area'
985
+ logging_dict = {}
986
+
987
+ if optimizer_idx == 0:
988
+ # Horizontal/Top Layer
989
+ img_loss = self.vae_img_loss(pred, target, self.vqvae.outc.conv, step, self.discriminator, 0)
990
+
991
+ # Vertical/Deep Layer
992
+ for i, pred_i in enumerate(pred_vertical):
993
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
994
+ img_loss += self.vae_img_loss(pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1)
995
+ loss = img_loss+self.embedding_loss_weight*emb_loss
996
+
997
+ with torch.no_grad():
998
+ logging_dict[f'img_loss'] = img_loss
999
+ logging_dict[f'emb_loss'] = emb_loss
1000
+ logging_dict['loss_0'] = loss
1001
+
1002
+ elif optimizer_idx == 1:
1003
+ # Horizontal/Top Layer
1004
+ loss = self.gan_img_loss(pred, target, step, self.discriminator, 0)
1005
+
1006
+ # Vertical/Deep Layer
1007
+ for i, pred_i in enumerate(pred_vertical):
1008
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
1009
+ loss += self.gan_img_loss(pred_i, target_i, step, self.discriminator, i+1)
1010
+
1011
+ with torch.no_grad():
1012
+ logging_dict['loss_1'] = loss
1013
+
1014
+
1015
+ # --------------------- Compute Metrics -------------------------------
1016
+ with torch.no_grad():
1017
+ logging_dict['loss'] = loss
1018
+ logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x)
1019
+ logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x)
1020
+ logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
1021
+ # logging_dict['logvar'] = self.vqvae.logvar
1022
+
1023
+ # ----------------- Log Scalars ----------------------
1024
+ for metric_name, metric_val in logging_dict.items():
1025
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
1026
+
1027
+ # ----------------- Save Image ------------------------------
1028
+ if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
1029
+
1030
+ log_step = self.global_step // self.sample_every_n_steps
1031
+ path_out = Path(self.logger.log_dir)/'images'
1032
+ path_out.mkdir(parents=True, exist_ok=True)
1033
+ # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
1034
+ def depth2batch(image):
1035
+ return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
1036
+ images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
1037
+ save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
1038
+
1039
+ return loss
1040
+
1041
+ def configure_optimizers(self):
1042
+ opt_vqvae = self.optimizer_vqvae(self.vqvae.parameters(), **self.optimizer_vqvae_kwargs)
1043
+ opt_gan = self.optimizer_gan(self.discriminator.parameters(), **self.optimizer_gan_kwargs)
1044
+ schedulers = []
1045
+ if self.lr_scheduler_vqvae is not None:
1046
+ schedulers.append({
1047
+ 'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs),
1048
+ 'interval': 'step',
1049
+ 'frequency': 1
1050
+ })
1051
+ if self.lr_scheduler_gan is not None:
1052
+ schedulers.append({
1053
+ 'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs),
1054
+ 'interval': 'step',
1055
+ 'frequency': 1
1056
+ })
1057
+ return [opt_vqvae, opt_gan], schedulers
1058
+
1059
+ def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4):
1060
+ """Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
1061
+ rec_grads = torch.autograd.grad(rec_loss, dec_out_layer.weight, retain_graph=True)[0]
1062
+ gan_grads = torch.autograd.grad(gan_loss, dec_out_layer.weight, retain_graph=True)[0]
1063
+ d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
1064
+ d_weight = torch.clamp(d_weight, 0.0, 1e4)
1065
+ return d_weight.detach()
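A hedged round-trip sketch for the latent embedders above (not part of the commit; shapes assume the default strides [1, 2, 2, 2], which reduce each spatial dimension by a factor of 8, and perceiver=None avoids the LPIPS dependency):

    import torch
    from medical_diffusion.models.embedders.latent_embedders import VAE

    vae = VAE(in_channels=3, out_channels=3, spatial_dims=2, emb_channels=4, perceiver=None)
    x = torch.randn(1, 3, 128, 128)   # [B, C, H, W]
    with torch.no_grad():
        z = vae.encode(x)             # expected latent: [1, 4, 16, 16]
        x_hat = vae.decode(z)         # expected reconstruction: [1, 3, 128, 128]
    print(z.shape, x_hat.shape)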
medical_diffusion/models/embedders/time_embedder.py ADDED
@@ -0,0 +1,75 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from monai.networks.layers.utils import get_act_layer
6
+
7
+ class SinusoidalPosEmb(nn.Module):
8
+ def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
9
+ super().__init__()
10
+ self.emb_dim = emb_dim
11
+ self.downscale_freq_shift = downscale_freq_shift
12
+ self.max_period = max_period
13
+ self.flip_sin_to_cos=flip_sin_to_cos
14
+
15
+ def forward(self, x):
16
+ device = x.device
17
+ half_dim = self.emb_dim // 2
18
+ emb = math.log(self.max_period) / (half_dim - self.downscale_freq_shift)
19
+ emb = torch.exp(-emb*torch.arange(half_dim, device=device))
20
+ emb = x[:, None] * emb[None, :]
21
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
22
+
23
+ if self.flip_sin_to_cos:
24
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
25
+
26
+ if self.emb_dim % 2 == 1:
27
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
28
+ return emb
29
+
30
+
31
+ class LearnedSinusoidalPosEmb(nn.Module):
32
+ """ following @crowsonkb 's lead with learned sinusoidal pos emb """
33
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
34
+
35
+ def __init__(self, emb_dim):
36
+ super().__init__()
37
+ self.emb_dim = emb_dim
38
+ half_dim = emb_dim // 2
39
+ self.weights = nn.Parameter(torch.randn(half_dim))
40
+
41
+ def forward(self, x):
42
+ x = x[:, None]
43
+ freqs = x * self.weights[None, :] * 2 * math.pi
44
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
45
+ fouriered = torch.cat((x, fouriered), dim = -1)
46
+ if self.emb_dim % 2 == 1:
47
+ fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0))
48
+ return fouriered
49
+
50
+
51
+
52
+ class TimeEmbbeding(nn.Module):
53
+ def __init__(
54
+ self,
55
+ emb_dim = 64,
56
+ pos_embedder = SinusoidalPosEmb,
57
+ pos_embedder_kwargs = {},
58
+ act_name=("SWISH", {}) # Swish = SiLU
59
+ ):
60
+ super().__init__()
61
+ self.emb_dim = emb_dim
62
+ self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
63
+ pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
64
+ self.pos_embedder = pos_embedder(**pos_embedder_kwargs)
65
+
66
+
67
+ self.time_emb = nn.Sequential(
68
+ self.pos_embedder,
69
+ nn.Linear(self.pos_emb_dim, self.emb_dim),
70
+ get_act_layer(act_name),
71
+ nn.Linear(self.emb_dim, self.emb_dim)
72
+ )
73
+
74
+ def forward(self, time):
75
+ return self.time_emb(time)
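A hedged usage sketch (not part of the commit): TimeEmbbeding turns a batch of scalar diffusion timesteps [B,] into embeddings [B, emb_dim]; by default the sinusoidal positional encoding uses emb_dim//4 channels before the two-layer MLP.

    import torch
    from medical_diffusion.models.embedders.time_embedder import TimeEmbbeding

    time_embedder = TimeEmbbeding(emb_dim=64)
    t = torch.randint(0, 1000, (8,)).float()  # sampled diffusion timesteps
    emb = time_embedder(t)                    # [8, 64]
    print(emb.shape)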
medical_diffusion/models/estimators/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .unet2 import UNet
medical_diffusion/models/estimators/unet.py ADDED
@@ -0,0 +1,186 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from monai.networks.blocks import UnetOutBlock
5
+
6
+ from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add
7
+ from medical_diffusion.models.embedders import TimeEmbbeding
8
+ from medical_diffusion.models.utils.attention_blocks import SpatialTransformer, LinearTransformer
9
+
10
+
11
+
12
+
13
+
14
+
15
+ class UNet(nn.Module):
16
+
17
+ def __init__(self,
18
+ in_ch=1,
19
+ out_ch=1,
20
+ spatial_dims = 3,
21
+ hid_chs = [32, 64, 128, 256],
22
+ kernel_sizes=[ 1, 3, 3, 3],
23
+ strides = [ 1, 2, 2, 2],
24
+ downsample_kernel_sizes = None,
25
+ upsample_kernel_sizes = None,
26
+ act_name=("SWISH", {}),
27
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
28
+ time_embedder=TimeEmbbeding,
29
+ time_embedder_kwargs={},
30
+ cond_embedder=None,
31
+ cond_embedder_kwargs={},
32
+ deep_supervision=True, # True = all but last layer, 0/False=disable, 1=only first layer, ...
33
+ use_res_block=True,
34
+ estimate_variance=False ,
35
+ use_self_conditioning = False,
36
+ dropout=0.0,
37
+ learnable_interpolation=True,
38
+ use_attention='none',
39
+ ):
40
+ super().__init__()
41
+ use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
42
+ self.use_self_conditioning = use_self_conditioning
43
+ self.use_res_block = use_res_block
44
+ self.depth = len(strides)
45
+ if downsample_kernel_sizes is None:
46
+ downsample_kernel_sizes = kernel_sizes
47
+ if upsample_kernel_sizes is None:
48
+ upsample_kernel_sizes = strides
49
+
50
+
51
+ # ------------- Time-Embedder-----------
52
+ if time_embedder is not None:
53
+ self.time_embedder=time_embedder(**time_embedder_kwargs)
54
+ time_emb_dim = self.time_embedder.emb_dim
55
+ else:
56
+ self.time_embedder = None
+ time_emb_dim = None # no time conditioning; keeps emb_channels=time_emb_dim defined below
57
+
58
+ # ------------- Condition-Embedder-----------
59
+ if cond_embedder is not None:
60
+ self.cond_embedder=cond_embedder(**cond_embedder_kwargs)
61
+ else:
62
+ self.cond_embedder = None
63
+
64
+ # ----------- In-Convolution ------------
65
+ in_ch = in_ch*2 if self.use_self_conditioning else in_ch
66
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
67
+ self.inc = ConvBlock(
68
+ spatial_dims = spatial_dims,
69
+ in_channels = in_ch,
70
+ out_channels = hid_chs[0],
71
+ kernel_size=kernel_sizes[0],
72
+ stride=strides[0],
73
+ act_name=act_name,
74
+ norm_name=norm_name,
75
+ emb_channels=time_emb_dim
76
+ )
77
+
78
+
79
+ # ----------- Encoder ----------------
80
+ self.encoders = nn.ModuleList([
81
+ DownBlock(
82
+ spatial_dims = spatial_dims,
83
+ in_channels = hid_chs[i-1],
84
+ out_channels = hid_chs[i],
85
+ kernel_size = kernel_sizes[i],
86
+ stride = strides[i],
87
+ downsample_kernel_size = downsample_kernel_sizes[i],
88
+ norm_name = norm_name,
89
+ act_name = act_name,
90
+ dropout = dropout,
91
+ use_res_block = use_res_block,
92
+ learnable_interpolation = learnable_interpolation,
93
+ use_attention = use_attention[i],
94
+ emb_channels = time_emb_dim
95
+ )
96
+ for i in range(1, self.depth)
97
+ ])
98
+
99
+
100
+
101
+ # ------------ Decoder ----------
102
+ self.decoders = nn.ModuleList([
103
+ UpBlock(
104
+ spatial_dims = spatial_dims,
105
+ in_channels = hid_chs[i+1],
106
+ out_channels = hid_chs[i],
107
+ kernel_size=kernel_sizes[i+1],
108
+ stride=strides[i+1],
109
+ upsample_kernel_size=upsample_kernel_sizes[i+1],
110
+ norm_name=norm_name,
111
+ act_name=act_name,
112
+ dropout=dropout,
113
+ use_res_block=use_res_block,
114
+ learnable_interpolation=learnable_interpolation,
115
+ use_attention=use_attention[i],
116
+ emb_channels=time_emb_dim,
117
+ skip_channels=hid_chs[i]
118
+ )
119
+ for i in range(self.depth-1)
120
+ ])
121
+
122
+
123
+ # --------------- Out-Convolution ----------------
124
+ out_ch_hor = out_ch*2 if estimate_variance else out_ch
125
+ self.outc = UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
126
+ if isinstance(deep_supervision, bool):
127
+ deep_supervision = self.depth-1 if deep_supervision else 0
128
+ self.outc_ver = nn.ModuleList([
129
+ UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
130
+ for i in range(1, deep_supervision+1)
131
+ ])
132
+
133
+
134
+ def forward(self, x_t, t=None, condition=None, self_cond=None):
135
+ # x_t [B, C, *]
136
+ # t [B,]
137
+ # condition [B,]
138
+ # self_cond [B, C, *]
139
+ x = [ None for _ in range(len(self.encoders)+1) ]
140
+
141
+ # -------- Time Embedding (Global) -----------
142
+ if t is None:
143
+ time_emb = None
144
+ else:
145
+ time_emb = self.time_embedder(t) # [B, C]
146
+
147
+ # -------- Condition Embedding (Global) -----------
148
+ if (condition is None) or (self.cond_embedder is None):
149
+ cond_emb = None
150
+ else:
151
+ cond_emb = self.cond_embedder(condition) # [B, C]
152
+
153
+ # ----------- Embedding Summation --------
154
+ emb = save_add(time_emb, cond_emb)
155
+
156
+ # ---------- Self-conditioning-----------
157
+ if self.use_self_conditioning:
158
+ self_cond = torch.zeros_like(x_t) if self_cond is None else self_cond
159
+ x_t = torch.cat([x_t, self_cond], dim=1)
160
+
161
+ # -------- In-Convolution --------------
162
+ x[0] = self.inc(x_t, emb)
163
+
164
+ # --------- Encoder --------------
165
+ for i in range(len(self.encoders)):
166
+ x[i+1] = self.encoders[i](x[i], emb)
167
+
168
+ # -------- Decoder -----------
169
+ for i in range(len(self.decoders), 0, -1):
170
+ x[i-1] = self.decoders[i-1](x[i], x[i-1], emb)
171
+
172
+ # ---------Out-Convolution ------------
173
+ y = self.outc(x[0])
174
+ y_ver = [outc_ver_i(x[i+1]) for i, outc_ver_i in enumerate(self.outc_ver)]
175
+
176
+ return y, y_ver
177
+
178
+
179
+
180
+
181
+ if __name__=='__main__':
182
+ model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
183
+ input = torch.randn((1,3,16,128,128))
184
+ time = torch.randn((1,))
185
+ out_hor, out_ver = model(input, time)
186
+ print(out_hor[0].shape)
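As in the smoke test above, the module can be instantiated directly; a 2D variant with the same flags is sketched below (all other arguments are the defaults of this file). With strides=[1, 2, 2, 2] the encoder keeps the input resolution in the first stage and halves it in each of the three DownBlocks, the decoder mirrors this, so the horizontal output y matches the input resolution while y_ver holds the lower-resolution deep-supervision outputs.

import torch
from medical_diffusion.models.estimators.unet import UNet

model = UNet(in_ch=1, out_ch=1, spatial_dims=2, use_res_block=False, learnable_interpolation=False)
x_t = torch.randn(2, 1, 128, 128)        # noisy input [B, C, H, W]
t = torch.randn(2,)                      # random timesteps (the scheduler supplies integer steps during training)
y, y_ver = model(x_t, t)
print(y.shape)                           # -> torch.Size([2, 1, 128, 128])
print([v.shape[-1] for v in y_ver])      # lower-resolution deep-supervision outputs (64, 32, 16 here)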
medical_diffusion/models/estimators/unet2.py ADDED
@@ -0,0 +1,279 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from monai.networks.blocks import UnetOutBlock
5
+
6
+ from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add, BasicDown, BasicUp, SequentialEmb
7
+ from medical_diffusion.models.embedders import TimeEmbbeding
8
+ from medical_diffusion.models.utils.attention_blocks import Attention, zero_module
9
+
10
+
11
+
12
+
13
+
14
+
15
+ class UNet(nn.Module):
16
+
17
+ def __init__(self,
18
+ in_ch=1,
19
+ out_ch=1,
20
+ spatial_dims = 3,
21
+ hid_chs = [256, 256, 512, 1024],
22
+ kernel_sizes=[ 3, 3, 3, 3],
23
+ strides = [ 1, 2, 2, 2], # WARNING, last stride is ignored (follows OpenAI)
24
+ act_name=("SWISH", {}),
25
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
26
+ time_embedder=TimeEmbbeding,
27
+ time_embedder_kwargs={},
28
+ cond_embedder=None,
29
+ cond_embedder_kwargs={},
30
+ deep_supervision=True, # True = all but last layer, 0/False=disable, 1=only first layer, ...
31
+ use_res_block=True,
32
+ estimate_variance=False ,
33
+ use_self_conditioning = False,
34
+ dropout=0.0,
35
+ learnable_interpolation=True,
36
+ use_attention='none',
37
+ num_res_blocks=2,
38
+ ):
39
+ super().__init__()
40
+ use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
41
+ self.use_self_conditioning = use_self_conditioning
42
+ self.use_res_block = use_res_block
43
+ self.depth = len(strides)
44
+ self.num_res_blocks = num_res_blocks
45
+
46
+ # ------------- Time-Embedder-----------
47
+ if time_embedder is not None:
48
+ self.time_embedder=time_embedder(**time_embedder_kwargs)
49
+ time_emb_dim = self.time_embedder.emb_dim
50
+ else:
51
+ self.time_embedder = None
52
+ time_emb_dim = None
53
+
54
+ # ------------- Condition-Embedder-----------
55
+ if cond_embedder is not None:
56
+ self.cond_embedder=cond_embedder(**cond_embedder_kwargs)
57
+ cond_emb_dim = self.cond_embedder.emb_dim
58
+ else:
59
+ self.cond_embedder = None
60
+ cond_emb_dim = None
61
+
62
+
63
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
64
+
65
+ # ----------- In-Convolution ------------
66
+ in_ch = in_ch*2 if self.use_self_conditioning else in_ch
67
+ self.in_conv = BasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0])
68
+
69
+
70
+ # ----------- Encoder ------------
71
+ in_blocks = []
72
+ for i in range(1, self.depth):
73
+ for k in range(num_res_blocks):
74
+ seq_list = []
75
+ seq_list.append(
76
+ ConvBlock(
77
+ spatial_dims=spatial_dims,
78
+ in_channels=hid_chs[i-1 if k==0 else i],
79
+ out_channels=hid_chs[i],
80
+ kernel_size=kernel_sizes[i],
81
+ stride=1,
82
+ norm_name=norm_name,
83
+ act_name=act_name,
84
+ dropout=dropout,
85
+ emb_channels=time_emb_dim
86
+ )
87
+ )
88
+
89
+ seq_list.append(
90
+ Attention(
91
+ spatial_dims=spatial_dims,
92
+ in_channels=hid_chs[i],
93
+ out_channels=hid_chs[i],
94
+ num_heads=8,
95
+ ch_per_head=hid_chs[i]//8,
96
+ depth=1,
97
+ norm_name=norm_name,
98
+ dropout=dropout,
99
+ emb_dim=time_emb_dim,
100
+ attention_type=use_attention[i]
101
+ )
102
+ )
103
+ in_blocks.append(SequentialEmb(*seq_list))
104
+
105
+ if i < self.depth-1:
106
+ in_blocks.append(
107
+ BasicDown(
108
+ spatial_dims=spatial_dims,
109
+ in_channels=hid_chs[i],
110
+ out_channels=hid_chs[i],
111
+ kernel_size=kernel_sizes[i],
112
+ stride=strides[i],
113
+ learnable_interpolation=learnable_interpolation
114
+ )
115
+ )
116
+
117
+
118
+ self.in_blocks = nn.ModuleList(in_blocks)
119
+
120
+ # ----------- Middle ------------
121
+ self.middle_block = SequentialEmb(
122
+ ConvBlock(
123
+ spatial_dims=spatial_dims,
124
+ in_channels=hid_chs[-1],
125
+ out_channels=hid_chs[-1],
126
+ kernel_size=kernel_sizes[-1],
127
+ stride=1,
128
+ norm_name=norm_name,
129
+ act_name=act_name,
130
+ dropout=dropout,
131
+ emb_channels=time_emb_dim
132
+ ),
133
+ Attention(
134
+ spatial_dims=spatial_dims,
135
+ in_channels=hid_chs[-1],
136
+ out_channels=hid_chs[-1],
137
+ num_heads=8,
138
+ ch_per_head=hid_chs[-1]//8,
139
+ depth=1,
140
+ norm_name=norm_name,
141
+ dropout=dropout,
142
+ emb_dim=time_emb_dim,
143
+ attention_type=use_attention[-1]
144
+ ),
145
+ ConvBlock(
146
+ spatial_dims=spatial_dims,
147
+ in_channels=hid_chs[-1],
148
+ out_channels=hid_chs[-1],
149
+ kernel_size=kernel_sizes[-1],
150
+ stride=1,
151
+ norm_name=norm_name,
152
+ act_name=act_name,
153
+ dropout=dropout,
154
+ emb_channels=time_emb_dim
155
+ )
156
+ )
157
+
158
+
159
+
160
+ # ------------ Decoder ----------
161
+ out_blocks = []
162
+ for i in range(1, self.depth):
163
+ for k in range(num_res_blocks+1):
164
+ seq_list = []
165
+ out_channels=hid_chs[i-1 if k==0 else i]
166
+ seq_list.append(
167
+ ConvBlock(
168
+ spatial_dims=spatial_dims,
169
+ in_channels=hid_chs[i]+hid_chs[i-1 if k==0 else i],
170
+ out_channels=out_channels,
171
+ kernel_size=kernel_sizes[i],
172
+ stride=1,
173
+ norm_name=norm_name,
174
+ act_name=act_name,
175
+ dropout=dropout,
176
+ emb_channels=time_emb_dim
177
+ )
178
+ )
179
+
180
+ seq_list.append(
181
+ Attention(
182
+ spatial_dims=spatial_dims,
183
+ in_channels=out_channels,
184
+ out_channels=out_channels,
185
+ num_heads=8,
186
+ ch_per_head=out_channels//8,
187
+ depth=1,
188
+ norm_name=norm_name,
189
+ dropout=dropout,
190
+ emb_dim=time_emb_dim,
191
+ attention_type=use_attention[i]
192
+ )
193
+ )
194
+
195
+ if (i >1) and k==0:
196
+ seq_list.append(
197
+ BasicUp(
198
+ spatial_dims=spatial_dims,
199
+ in_channels=out_channels,
200
+ out_channels=out_channels,
201
+ kernel_size=strides[i],
202
+ stride=strides[i],
203
+ learnable_interpolation=learnable_interpolation
204
+ )
205
+ )
206
+
207
+ out_blocks.append(SequentialEmb(*seq_list))
208
+ self.out_blocks = nn.ModuleList(out_blocks)
209
+
210
+
211
+ # --------------- Out-Convolution ----------------
212
+ out_ch_hor = out_ch*2 if estimate_variance else out_ch
213
+ self.outc = zero_module(UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None))
214
+ if isinstance(deep_supervision, bool):
215
+ deep_supervision = self.depth-2 if deep_supervision else 0
216
+ self.outc_ver = nn.ModuleList([
217
+ zero_module(UnetOutBlock(spatial_dims, hid_chs[i]+hid_chs[i-1], out_ch, dropout=None) )
218
+ for i in range(2, deep_supervision+2)
219
+ ])
220
+
221
+
222
+ def forward(self, x_t, t=None, condition=None, self_cond=None):
223
+ # x_t [B, C, *]
224
+ # t [B,]
225
+ # condition [B,]
226
+ # self_cond [B, C, *]
227
+
228
+
229
+ # -------- Time Embedding (Global) -----------
230
+ if t is None:
231
+ time_emb = None
232
+ else:
233
+ time_emb = self.time_embedder(t) # [B, C]
234
+
235
+ # -------- Condition Embedding (Global) -----------
236
+ if (condition is None) or (self.cond_embedder is None):
237
+ cond_emb = None
238
+ else:
239
+ cond_emb = self.cond_embedder(condition) # [B, C]
240
+
241
+ emb = save_add(time_emb, cond_emb)
242
+
243
+ # ---------- Self-conditioning-----------
244
+ if self.use_self_conditioning:
245
+ self_cond = torch.zeros_like(x_t) if self_cond is None else self_cond
246
+ x_t = torch.cat([x_t, self_cond], dim=1)
247
+
248
+ # --------- Encoder --------------
249
+ x = [self.in_conv(x_t)]
250
+ for i in range(len(self.in_blocks)):
251
+ x.append(self.in_blocks[i](x[i], emb))
252
+
253
+ # ---------- Middle --------------
254
+ h = self.middle_block(x[-1], emb)
255
+
256
+ # -------- Decoder -----------
257
+ y_ver = []
258
+ for i in range(len(self.out_blocks), 0, -1):
259
+ h = torch.cat([h, x.pop()], dim=1)
260
+
261
+ depth, j = i//(self.num_res_blocks+1), i%(self.num_res_blocks+1)-1
262
+ y_ver.append(self.outc_ver[depth-1](h)) if (len(self.outc_ver)>=depth>0) and (j==0) else None
263
+
264
+ h = self.out_blocks[i-1](h, emb)
265
+
266
+ # ---------Out-Convolution ------------
267
+ y = self.outc(h)
268
+
269
+ return y, y_ver[::-1]
270
+
271
+
272
+
273
+
274
+ if __name__=='__main__':
275
+ model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
276
+ input = torch.randn((1,3,16,32,32))
277
+ time = torch.randn((1,))
278
+ out_hor, out_ver = model(input, time)
279
+ print(out_hor[0].shape)
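The decoder loop runs over len(self.out_blocks) = (depth-1)*(num_res_blocks+1) blocks; the integer arithmetic depth, j = i//(num_res_blocks+1), i%(num_res_blocks+1)-1 recovers the resolution level and the position within that level, and a vertical (deep-supervision) output is taken from the concatenated skip tensor exactly once per level, at j == 0, for the levels covered by outc_ver. A small worked check of that arithmetic with the defaults (depth=4, num_res_blocks=2, hence 9 decoder blocks and len(outc_ver)=2 for deep_supervision=True):

num_res_blocks, n_outc_ver = 2, 2
for i in range(9, 0, -1):                                  # mirrors: for i in range(len(self.out_blocks), 0, -1)
    depth, j = i // (num_res_blocks + 1), i % (num_res_blocks + 1) - 1
    emit = (n_outc_ver >= depth > 0) and (j == 0)          # same condition as in forward() above
    print(f"block {i}: level {depth}, j={j}, vertical_output={emit}")
# vertical outputs are emitted at i=7 (level 2) and i=4 (level 1)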
medical_diffusion/models/model_base.py ADDED
@@ -0,0 +1,114 @@
1
+
2
+ from pathlib import Path
3
+ import json
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning.utilities.cloud_io import load as pl_load
10
+ from pytorch_lightning.utilities.migration import pl_legacy_patch
11
+
12
+ class VeryBasicModel(pl.LightningModule):
13
+ def __init__(self):
14
+ super().__init__()
15
+ self.save_hyperparameters()
16
+ self._step_train = 0
17
+ self._step_val = 0
18
+ self._step_test = 0
19
+
20
+
21
+ def forward(self, x_in):
22
+ raise NotImplementedError
23
+
24
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
25
+ raise NotImplementedError
26
+
27
+ def training_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0 ):
28
+ self._step_train += 1 # =self.global_step
29
+ return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx)
30
+
31
+ def validation_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
32
+ self._step_val += 1
33
+ return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx )
34
+
35
+ def test_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
36
+ self._step_test += 1
37
+ return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx)
38
+
39
+ def _epoch_end(self, outputs: list, state: str):
40
+ return
41
+
42
+ def training_epoch_end(self, outputs):
43
+ self._epoch_end(outputs, "train")
44
+
45
+ def validation_epoch_end(self, outputs):
46
+ self._epoch_end(outputs, "val")
47
+
48
+ def test_epoch_end(self, outputs):
49
+ self._epoch_end(outputs, "test")
50
+
51
+ @classmethod
52
+ def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path):
53
+ with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f:
54
+ json.dump({'best_model_epoch': Path(best_model_path).name}, f)
55
+
56
+ @classmethod
57
+ def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs):
58
+ path_version = 'lightning_logs/version_'+str(version)
59
+ with open(Path(path_checkpoint_dir) / path_version/ 'best_checkpoint.json', 'r') as f:
60
+ path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch'])
61
+ return Path(path_checkpoint_dir)/path_rel_best_checkpoint
62
+
63
+ @classmethod
64
+ def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs):
65
+ path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version)
66
+ return cls.load_from_checkpoint(path_best_checkpoint, **kwargs)
67
+
68
+ def load_pretrained(self, checkpoint_path, map_location=None, **kwargs):
69
+ if checkpoint_path.is_dir():
70
+ checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs)
71
+
72
+ with pl_legacy_patch():
73
+ if map_location is not None:
74
+ checkpoint = pl_load(checkpoint_path, map_location=map_location)
75
+ else:
76
+ checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
77
+ return self.load_weights(checkpoint["state_dict"], **kwargs)
78
+
79
+ def load_weights(self, pretrained_weights, strict=True, **kwargs):
80
+ filter = kwargs.get('filter', lambda key:key in pretrained_weights)
81
+ init_weights = self.state_dict()
82
+ pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)}
83
+ init_weights.update(pretrained_weights)
84
+ self.load_state_dict(init_weights, strict=strict)
85
+ return self
86
+
87
+
88
+
89
+
90
+ class BasicModel(VeryBasicModel):
91
+ def __init__(self,
92
+ optimizer=torch.optim.AdamW,
93
+ optimizer_kwargs={'lr':1e-3, 'weight_decay':1e-2},
94
+ lr_scheduler= None,
95
+ lr_scheduler_kwargs={},
96
+ ):
97
+ super().__init__()
98
+ self.save_hyperparameters()
99
+ self.optimizer = optimizer
100
+ self.optimizer_kwargs = optimizer_kwargs
101
+ self.lr_scheduler = lr_scheduler
102
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs
103
+
104
+ def configure_optimizers(self):
105
+ optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs)
106
+ if self.lr_scheduler is not None:
107
+ lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
108
+ return [optimizer], [lr_scheduler]
109
+ else:
110
+ return [optimizer]
111
+
112
+
113
+
114
+
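BasicModel only adds optimizer/scheduler wiring on top of VeryBasicModel; concrete models implement forward() and _step() and inherit the checkpoint helpers. A minimal, hypothetical subclass sketch (the layer and loss below are placeholders, not part of this repository):

import torch.nn as nn
import torch.nn.functional as F
from medical_diffusion.models import BasicModel

class ToyRegressor(BasicModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)                 # optimizer defaults to AdamW(lr=1e-3, weight_decay=1e-2)
        self.net = nn.Linear(16, 1)

    def forward(self, x_in):
        return self.net(x_in)

    def _step(self, batch, batch_idx, state, step, optimizer_idx):
        pred = self(batch['source'])
        loss = F.mse_loss(pred, batch['target'])
        self.log(f"{state}/loss", loss)
        return loss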
medical_diffusion/models/noise_schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .scheduler_base import BasicNoiseScheduler
2
+ from .gaussian_scheduler import GaussianNoiseScheduler
medical_diffusion/models/noise_schedulers/gaussian_scheduler.py ADDED
@@ -0,0 +1,154 @@
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler
7
+
8
+ class GaussianNoiseScheduler(BasicNoiseScheduler):
9
+ def __init__(
10
+ self,
11
+ timesteps=1000,
12
+ T = None,
13
+ schedule_strategy='cosine',
14
+ beta_start = 0.0001, # default 1e-4, stable-diffusion ~ 1e-3
15
+ beta_end = 0.02,
16
+ betas = None,
17
+ ):
18
+ super().__init__(timesteps, T)
19
+
20
+ self.schedule_strategy = schedule_strategy
21
+
22
+ if betas is not None:
23
+ betas = torch.as_tensor(betas, dtype = torch.float64)
24
+ elif schedule_strategy == "linear":
25
+ betas = torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
26
+ elif schedule_strategy == "scaled_linear": # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion
27
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64)**2
28
+ elif schedule_strategy == "cosine":
29
+ s = 0.008
30
+ x = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) # [0, T]
31
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
32
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
33
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
34
+ betas = torch.clip(betas, 0, 0.999)
35
+ else:
36
+ raise NotImplementedError(f"{schedule_strategy} is not implemented for {self.__class__}")
37
+
38
+
39
+ alphas = 1-betas
40
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
41
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
42
+
43
+
44
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
45
+
46
+ register_buffer('betas', betas) # (0 , 1)
47
+
48
+ register_buffer('alphas', alphas) # (1 , 0)
49
+ register_buffer('alphas_cumprod', alphas_cumprod)
50
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
51
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
52
+ register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
53
+ register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
54
+ register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
55
+
56
+ register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
57
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
58
+ register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod))
59
+
60
+
61
+ def estimate_x_t(self, x_0, t, x_T=None):
62
+ # NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108)
63
+ # NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment) https://github.com/arpitbansal297/Cold-Diffusion-Models/blob/c828140b7047ca22f995b99fbcda360bc30fc25d/denoising-diffusion-pytorch/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L361
64
+ x_T = self.x_final(x_0) if x_T is None else x_T
65
+ # ndim = x_0.ndim
66
+ # x_t = (self.extract(self.sqrt_alphas_cumprod, t, ndim)*x_0 +
67
+ # self.extract(self.sqrt_one_minus_alphas_cumprod, t, ndim)*x_T)
68
+ def clipper(b):
69
+ tb = t[b]
70
+ if tb<0:
71
+ return x_0[b]
72
+ elif tb>=self.T:
73
+ return x_T[b]
74
+ else:
75
+ return self.sqrt_alphas_cumprod[tb]*x_0[b]+self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b]
76
+ x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
77
+ return x_t
78
+
79
+
80
+ def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
81
+ x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0)
82
+ return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion)
83
+
84
+
85
+ def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
86
+ x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
87
+
88
+ if cold_diffusion: # see https://arxiv.org/abs/2208.09392
89
+ x_T_est = self.estimate_x_T(x_t, x_0, t) # or use x_T estimated by UNet if available?
90
+ x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est)
91
+ x_t_prior = self.estimate_x_t(x_0, t-1, x_T=x_T_est)
92
+ noise_t = x_t_est-x_t_prior
93
+ x_t_prior = x_t-noise_t
94
+ else:
95
+ mean = self.estimate_mean_t(x_t, x_0, t)
96
+ variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
97
+ std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance)
98
+ std[t==0] = 0.0
99
+ x_T = self.x_final(x_t)
100
+ x_t_prior = mean+std*x_T
101
+ return x_t_prior, x_0
102
+
103
+
104
+ def estimate_mean_t(self, x_t, x_0, t):
105
+ ndim = x_t.ndim
106
+ return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0+
107
+ self.extract(self.posterior_mean_coef2, t, ndim)*x_t)
108
+
109
+
110
+ def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
111
+ min_variance = self.extract(self.posterior_variance, t, ndim)
112
+ max_variance = self.extract(self.betas, t, ndim)
113
+ if log:
114
+ min_variance = torch.log(min_variance.clamp(min=eps))
115
+ max_variance = torch.log(max_variance.clamp(min=eps))
116
+ return var_scale * max_variance + (1 - var_scale) * min_variance
117
+
118
+
119
+ def estimate_x_0(self, x_t, x_T, t, clip_x0=True):
120
+ ndim = x_t.ndim
121
+ x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t -
122
+ self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T)
123
+ x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
124
+ return x_0
125
+
126
+
127
+ def estimate_x_T(self, x_t, x_0, t, clip_x0=True):
128
+ ndim = x_t.ndim
129
+ x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
130
+ return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0)/
131
+ self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))
132
+
133
+
134
+ @classmethod
135
+ def x_final(cls, x):
136
+ return torch.randn_like(x)
137
+
138
+ @classmethod
139
+ def _clip_x_0(cls, x_0):
140
+ # See "static/dynamic thresholding" in Imagen https://arxiv.org/abs/2205.11487
141
+
142
+ # "static thresholding"
143
+ m = 1 # Set this to about 4*sigma = 4 if latent diffusion is used
144
+ x_0 = x_0.clamp(-m, m)
145
+
146
+ # "dynamic thresholding"
147
+ # r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0])
148
+ # r = torch.maximum(r, torch.full_like(r,m))
149
+ # x_0 = torch.stack([x_0_b.clamp(-r_b, r_b)/r_b*m for x_0_b, r_b in zip(x_0, r) ] )
150
+
151
+ return x_0
152
+
153
+
154
+
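The buffers registered above implement the closed-form forward process x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * x_T with x_T ~ N(0, I), which is exactly what estimate_x_t computes for 0 <= t < T. A short sketch that reproduces sample() directly from the buffers:

import torch
from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler

scheduler = GaussianNoiseScheduler(timesteps=1000, schedule_strategy='cosine')
x_0 = torch.rand(4, 1, 64, 64) * 2 - 1                 # images assumed to be scaled to [-1, 1]
x_t, x_T, t = scheduler.sample(x_0)                    # forward diffusion at a random t per sample

x_t_manual = (scheduler.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) * x_0
              + scheduler.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) * x_T)
assert torch.allclose(x_t, x_t_manual, atol=1e-6)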
medical_diffusion/models/noise_schedulers/scheduler_base.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BasicNoiseScheduler(nn.Module):
8
+ def __init__(
9
+ self,
10
+ timesteps=1000,
11
+ T=None,
12
+ ):
13
+ super().__init__()
14
+ self.timesteps = timesteps
15
+ self.T = timesteps if T is None else T
16
+
17
+ self.register_buffer('timesteps_array', torch.linspace(0, self.T-1, self.timesteps, dtype=torch.long)) # NOTE: End is inclusive therefore use -1 to get [0, T-1]
18
+
19
+ def __len__(self):
20
+ return self.timesteps # timesteps is an int, len() would fail
21
+
22
+ def sample(self, x_0):
23
+ """Randomly sample t from [0,T] and return x_t and x_T based on x_0"""
24
+ t = torch.randint(0, self.T, (x_0.shape[0],), dtype=torch.long, device=x_0.device) # NOTE: High is exclusive, therefore [0, T-1]
25
+ x_T = self.x_final(x_0)
26
+ return self.estimate_x_t(x_0, t, x_T), x_T, t
27
+
28
+ def estimate_x_t_prior_from_x_T(self, x_T, t, **kwargs):
29
+ raise NotImplementedError
30
+
31
+ def estimate_x_t_prior_from_x_0(self, x_0, t, **kwargs):
32
+ raise NotImplementedError
33
+
34
+ def estimate_x_t(self, x_0, t, x_T=None, **kwargs):
35
+ """Get x_t at time t"""
36
+ raise NotImplementedError
37
+
38
+ @classmethod
39
+ def x_final(cls, x):
40
+ """Get noise that should be obtained for t->T """
41
+ raise NotImplementedError
42
+
43
+ @staticmethod
44
+ def extract(x, t, ndim):
45
+ """Extract values from x at t and reshape them to n-dim tensor"""
46
+ return x.gather(0, t).reshape(-1, *((1,)*(ndim-1)))
47
+
48
+
49
+
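extract() is the broadcasting helper used throughout the Gaussian scheduler: it gathers one scalar per batch element from a 1D buffer and reshapes the result so it multiplies [B, C, *] tensors. A two-line check:

import torch
from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler

coeffs = torch.linspace(0.1, 1.0, 10)                          # e.g. a buffer of length T
t = torch.tensor([0, 4, 9])                                    # one timestep per batch element
print(BasicNoiseScheduler.extract(coeffs, t, ndim=4).shape)    # -> torch.Size([3, 1, 1, 1])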
medical_diffusion/models/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .diffusion_pipeline import DiffusionPipeline
medical_diffusion/models/pipelines/diffusion_pipeline.py ADDED
@@ -0,0 +1,348 @@
1
+
2
+
3
+ from pathlib import Path
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torchvision.utils import save_image
9
+ import streamlit as st
10
+
11
+ from medical_diffusion.models import BasicModel
12
+ from medical_diffusion.utils.train_utils import EMAModel
13
+ from medical_diffusion.utils.math_utils import kl_gaussians
14
+
15
+
16
+
17
+
18
+
19
+
20
+ class DiffusionPipeline(BasicModel):
21
+ def __init__(self,
22
+ noise_scheduler,
23
+ noise_estimator,
24
+ latent_embedder=None,
25
+ noise_scheduler_kwargs={},
26
+ noise_estimator_kwargs={},
27
+ latent_embedder_checkpoint='',
28
+ estimator_objective = 'x_T', # 'x_T' or 'x_0'
29
+ estimate_variance=False,
30
+ use_self_conditioning=False,
31
+ classifier_free_guidance_dropout=0.5, # Probability to drop condition during training, has only an effect for label-conditioned training
32
+ num_samples = 4,
33
+ do_input_centering = True, # Only for training
34
+ clip_x0=True, # Only has an effect during training if use_self_conditioning=True; important for inference/sampling
35
+ use_ema = False,
36
+ ema_kwargs = {},
37
+ optimizer=torch.optim.AdamW,
38
+ optimizer_kwargs={'lr':1e-4}, # stable-diffusion ~ 1e-4
39
+ lr_scheduler= None, # stable-diffusion - LambdaLR
40
+ lr_scheduler_kwargs={},
41
+ loss=torch.nn.L1Loss,
42
+ loss_kwargs={},
43
+ sample_every_n_steps = 1000
44
+ ):
45
+ # self.save_hyperparameters(ignore=['noise_estimator', 'noise_scheduler'])
46
+ super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs)
47
+ self.loss_fct = loss(**loss_kwargs)
48
+ self.sample_every_n_steps=sample_every_n_steps
49
+
50
+ noise_estimator_kwargs['estimate_variance'] = estimate_variance
51
+ noise_estimator_kwargs['use_self_conditioning'] = use_self_conditioning
52
+
53
+ self.noise_scheduler = noise_scheduler(**noise_scheduler_kwargs)
54
+ self.noise_estimator = noise_estimator(**noise_estimator_kwargs)
55
+
56
+ with torch.no_grad():
57
+ if latent_embedder is not None:
58
+ self.latent_embedder = latent_embedder.load_from_checkpoint(latent_embedder_checkpoint)
59
+ for param in self.latent_embedder.parameters():
60
+ param.requires_grad = False
61
+ else:
62
+ self.latent_embedder = None
63
+
64
+ self.estimator_objective = estimator_objective
65
+ self.use_self_conditioning = use_self_conditioning
66
+ self.num_samples = num_samples
67
+ self.classifier_free_guidance_dropout = classifier_free_guidance_dropout
68
+ self.do_input_centering = do_input_centering
69
+ self.estimate_variance = estimate_variance
70
+ self.clip_x0 = clip_x0
71
+
72
+ self.use_ema = use_ema
73
+ if use_ema:
74
+ self.ema_model = EMAModel(self.noise_estimator, **ema_kwargs)
75
+
76
+
77
+
78
+ def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
79
+ results = {}
80
+ x_0 = batch['source']
81
+ condition = batch.get('target', None)
82
+
83
+ # Embed into latent space or normalize
84
+ if self.latent_embedder is not None:
85
+ self.latent_embedder.eval()
86
+ with torch.no_grad():
87
+ x_0 = self.latent_embedder.encode(x_0)
88
+
89
+ if self.do_input_centering:
90
+ x_0 = 2*x_0-1 # [0, 1] -> [-1, 1]
91
+
92
+ # if self.clip_x0:
93
+ # x_0 = torch.clamp(x_0, -1, 1)
94
+
95
+
96
+ # Sample Noise
97
+ with torch.no_grad():
98
+ # Randomly selecting t [0,T-1] and compute x_t (noisy version of x_0 at t)
99
+ x_t, x_T, t = self.noise_scheduler.sample(x_0)
100
+
101
+ # Use EMA Model
102
+ if self.use_ema and (state != 'train'):
103
+ noise_estimator = self.ema_model.averaged_model
104
+ else:
105
+ noise_estimator = self.noise_estimator
106
+
107
+ # Re-estimate x_T or x_0, self-conditioned on previous estimate
108
+ self_cond = None
109
+ if self.use_self_conditioning:
110
+ with torch.no_grad():
111
+ pred, pred_vertical = noise_estimator(x_t, t, condition, None)
112
+ if self.estimate_variance:
113
+ pred, _ = pred.chunk(2, dim = 1) # Separate actual prediction and variance estimation
114
+ if self.estimator_objective == "x_T": # self condition on x_0
115
+ self_cond = self.noise_scheduler.estimate_x_0(x_t, pred, t=t, clip_x0=self.clip_x0)
116
+ elif self.estimator_objective == "x_0": # self condition on x_T
117
+ self_cond = self.noise_scheduler.estimate_x_T(x_t, pred, t=t, clip_x0=self.clip_x0)
118
+ else:
119
+ raise NotImplementedError(f"Option estimator_target={self.estimator_objective} not supported.")
120
+
121
+ # Classifier free guidance
122
+ if torch.rand(1)<self.classifier_free_guidance_dropout:
123
+ condition = None
124
+
125
+ # Run Denoise
126
+ pred, pred_vertical = noise_estimator(x_t, t, condition, self_cond)
127
+
128
+ # Separate variance (scale) if it was learned
129
+ if self.estimate_variance:
130
+ pred, pred_var = pred.chunk(2, dim = 1) # Separate actual prediction and variance estimation
131
+
132
+ # Specify target
133
+ if self.estimator_objective == "x_T":
134
+ target = x_T
135
+ elif self.estimator_objective == "x_0":
136
+ target = x_0
137
+ else:
138
+ raise NotImplementedError(f"Option estimator_target={self.estimator_objective} not supported.")
139
+
140
+
141
+ # ------------------------- Compute Loss ---------------------------
142
+ interpolation_mode = 'area'
143
+ loss = 0
144
+ weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal (equal) + vertical (reducing with every step down)
145
+ tot_weight = sum(weights)
146
+ weights = [w/tot_weight for w in weights]
147
+
148
+ # ----------------- MSE/L1, ... ----------------------
149
+ loss += self.loss_fct(pred, target)*weights[0]
150
+
151
+ # ----------------- Variance Loss --------------
152
+ if self.estimate_variance:
153
+ # var_scale = var_scale.clamp(-1, 1) # Should not be necessary
154
+ var_scale = (pred_var+1)/2 # Assumed to be in [-1, 1] -> [0, 1]
155
+ pred_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=var_scale)
156
+ # pred_logvar = pred_var # If variance is estimated directly
157
+
158
+ if self.estimator_objective == 'x_T':
159
+ pred_x_0 = self.noise_scheduler.estimate_x_0(x_t, x_T, t, clip_x0=self.clip_x0)
160
+ elif self.estimator_objective == "x_0":
161
+ pred_x_0 = pred
162
+ else:
163
+ raise NotImplementedError()
164
+
165
+ with torch.no_grad():
166
+ pred_mean = self.noise_scheduler.estimate_mean_t(x_t, pred_x_0, t)
167
+ true_mean = self.noise_scheduler.estimate_mean_t(x_t, x_0, t)
168
+ true_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=0)
169
+
170
+ kl_loss = torch.mean(kl_gaussians(true_mean, true_logvar, pred_mean, pred_logvar), dim=list(range(1, x_0.ndim)))
171
+ nnl_loss = torch.mean(F.gaussian_nll_loss(pred_x_0, x_0, torch.exp(pred_logvar), reduction='none'), dim=list(range(1, x_0.ndim)))
172
+ var_loss = torch.mean(torch.where(t == 0, nnl_loss, kl_loss))
173
+ loss += var_loss
174
+
175
+ results['variance_scale'] = torch.mean(var_scale)
176
+ results['variance_loss'] = var_loss
177
+
178
+
179
+ # ----------------------------- Deep Supervision -------------------------
180
+ for i, pred_i in enumerate(pred_vertical):
181
+ target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
182
+ loss += self.loss_fct(pred_i, target_i)*weights[i+1]
183
+ results['loss'] = loss
184
+
185
+
186
+
187
+ # --------------------- Compute Metrics -------------------------------
188
+ with torch.no_grad():
189
+ results['L2'] = F.mse_loss(pred, target)
190
+ results['L1'] = F.l1_loss(pred, target)
191
+ # results['SSIM'] = SSIMMetric(data_range=pred.max()-pred.min(), spatial_dims=source.ndim-2)(pred, target)
192
+
193
+ # for i, pred_i in enumerate(pred_vertical):
194
+ # target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
195
+ # results[f'L1_{i}'] = F.l1_loss(pred_i, target_i).detach()
196
+
197
+
198
+
199
+ # ----------------- Log Scalars ----------------------
200
+ for metric_name, metric_val in results.items():
201
+ self.log(f"{state}/{metric_name}", metric_val, batch_size=x_0.shape[0], on_step=True, on_epoch=True)
202
+
203
+
204
+ #------------------ Log Image -----------------------
205
+ if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
206
+ dataformats = 'NHWC' if x_0.ndim == 5 else 'HWC'
207
+ def norm(x):
208
+ return (x-x.min())/(x.max()-x.min())
209
+
210
+ sample_cond = condition[0:self.num_samples] if condition is not None else None
211
+ sample_img = self.sample(num_samples=self.num_samples, img_size=x_0.shape[1:], condition=sample_cond).detach()
212
+
213
+ log_step = self.global_step // self.sample_every_n_steps
214
+ # self.logger.experiment.add_images("predict_img", norm(torch.moveaxis(pred[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats)
215
+ # self.logger.experiment.add_images("target_img", norm(torch.moveaxis(target[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats)
216
+
217
+ # self.logger.experiment.add_images("source_img", norm(torch.moveaxis(x_0[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats)
218
+ # self.logger.experiment.add_images("sample_img", norm(torch.moveaxis(sample_img[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats)
219
+
220
+ path_out = Path(self.logger.log_dir)/'images'
221
+ path_out.mkdir(parents=True, exist_ok=True)
222
+ # for 3D images use depth as batch :[D, C, H, W], never show more than 32 images
223
+ def depth2batch(image):
224
+ return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
225
+ images = depth2batch(sample_img)[:32]
226
+ save_image(images, path_out/f'sample_{log_step}.png', normalize=True)
227
+
228
+
229
+ return loss
230
+
231
+
232
+ def forward(self, x_t, t, condition=None, self_cond=None, guidance_scale=1.0, cold_diffusion=False, un_cond=None):
233
+ # Note: x_t expected to be in range ~ [-1, 1]
234
+ if self.use_ema:
235
+ noise_estimator = self.ema_model.averaged_model
236
+ else:
237
+ noise_estimator = self.noise_estimator
238
+
239
+ # Concatenate inputs for guided and unguided diffusion as proposed by classifier-free-guidance
240
+ if (condition is not None) and (guidance_scale != 1.0):
241
+ # Model prediction
242
+ pred_uncond, _ = noise_estimator(x_t, t, condition=un_cond, self_cond=self_cond)
243
+ pred_cond, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond)
244
+ pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
245
+
246
+ if self.estimate_variance:
247
+ pred_uncond, pred_var_uncond = pred_uncond.chunk(2, dim = 1)
248
+ pred_cond, pred_var_cond = pred_cond.chunk(2, dim = 1)
249
+ pred_var = pred_var_uncond + guidance_scale * (pred_var_cond - pred_var_uncond)
250
+ else:
251
+ pred, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond)
252
+ if self.estimate_variance:
253
+ pred, pred_var = pred.chunk(2, dim = 1)
254
+
255
+ if self.estimate_variance:
256
+ pred_var_scale = pred_var/2+0.5 # [-1, 1] -> [0, 1]
257
+ pred_var_value = pred_var
258
+ else:
259
+ pred_var_scale = 0
260
+ pred_var_value = None
261
+
262
+ # pred_var_scale = pred_var_scale.clamp(0, 1)
263
+
264
+ if self.estimator_objective == 'x_0':
265
+ x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_0(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion)
266
+ x_T = self.noise_scheduler.estimate_x_T(x_t, x_0=pred, t=t, clip_x0=self.clip_x0)
267
+ self_cond = x_T
268
+ elif self.estimator_objective == 'x_T':
269
+ x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion)
270
+ x_T = pred
271
+ self_cond = x_0
272
+ else:
273
+ raise ValueError("Unknown Objective")
274
+
275
+ return x_t_prior, x_0, x_T, self_cond
276
+
277
+
278
+ @torch.no_grad()
279
+ def denoise(self, x_t, steps=None, condition=None, use_ddim=True, **kwargs):
280
+ self_cond = None
281
+
282
+ # ---------- run denoise loop ---------------
283
+ if use_ddim:
284
+ steps = self.noise_scheduler.timesteps if steps is None else steps
285
+ timesteps_array = torch.linspace(0, self.noise_scheduler.T-1, steps, dtype=torch.long, device=x_t.device) # [0, 1, 2, ..., T-1] if steps = T
286
+ else:
287
+ timesteps_array = self.noise_scheduler.timesteps_array[slice(0, steps)] # [0, ...,T-1] (target time not time of x_t)
288
+
289
+ st_prog_bar = st.progress(0)
290
+ for i, t in tqdm(enumerate(reversed(timesteps_array))):
291
+ st_prog_bar.progress((i+1)/len(timesteps_array))
292
+
293
+ # UNet prediction
294
+ x_t, x_0, x_T, self_cond = self(x_t, t.expand(x_t.shape[0]), condition, self_cond=self_cond, **kwargs)
295
+ self_cond = self_cond if self.use_self_conditioning else None
296
+
297
+ if use_ddim and (steps-i-1>0):
298
+ t_next = timesteps_array[steps-i-2]
299
+ alpha = self.noise_scheduler.alphas_cumprod[t]
300
+ alpha_next = self.noise_scheduler.alphas_cumprod[t_next]
301
+ sigma = kwargs.get('eta', 1) * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
302
+ c = (1 - alpha_next - sigma ** 2).sqrt()
303
+ noise = torch.randn_like(x_t)
304
+ x_t = x_0 * alpha_next.sqrt() + c * x_T + sigma * noise
305
+
306
+ # ------ Eventually decode from latent space into image space--------
307
+ if self.latent_embedder is not None:
308
+ x_t = self.latent_embedder.decode(x_t)
309
+
310
+ return x_t # Should be x_0 in final step (t=0)
311
+
312
+ @torch.no_grad()
313
+ def sample(self, num_samples, img_size, condition=None, **kwargs):
314
+ template = torch.zeros((num_samples, *img_size), device=self.device)
315
+ x_T = self.noise_scheduler.x_final(template)
316
+ x_0 = self.denoise(x_T, condition=condition, **kwargs)
317
+ return x_0
318
+
319
+
320
+ @torch.no_grad()
321
+ def interpolate(self, img1, img2, i = None, condition=None, lam = 0.5, **kwargs):
322
+ assert img1.shape == img2.shape, "Image 1 and 2 must have equal shape"
323
+
324
+ t = self.noise_scheduler.T-1 if i is None else i
325
+ t = torch.full(img1.shape[:1], t, device=img1.device)
326
+
327
+ img1_t = self.noise_scheduler.estimate_x_t(img1, t=t, clip_x0=self.clip_x0)
328
+ img2_t = self.noise_scheduler.estimate_x_t(img2, t=t, clip_x0=self.clip_x0)
329
+
330
+ img = (1 - lam) * img1_t + lam * img2_t
331
+ img = self.denoise(img, i, condition, **kwargs)
332
+ return img
333
+
334
+ def on_train_batch_end(self, *args, **kwargs):
335
+ if self.use_ema:
336
+ self.ema_model.step(self.noise_estimator)
337
+
338
+ def configure_optimizers(self):
339
+ optimizer = self.optimizer(self.noise_estimator.parameters(), **self.optimizer_kwargs)
340
+ if self.lr_scheduler is not None:
341
+ lr_scheduler = {
342
+ 'scheduler': self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs),
343
+ 'interval': 'step',
344
+ 'frequency': 1
345
+ }
346
+ return [optimizer], [lr_scheduler]
347
+ else:
348
+ return [optimizer]
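At inference time only a trained pipeline is needed: sample() draws x_T from the scheduler, denoise() runs the reverse loop (optionally with DDIM spacing and classifier-free guidance via guidance_scale), and the latent embedder, if configured, decodes the result back to image space. A hedged usage sketch; the checkpoint path, latent size and label tensor below are placeholders:

import torch
from medical_diffusion.models.pipelines import DiffusionPipeline

pipeline = DiffusionPipeline.load_from_checkpoint('runs/diffusion/last.ckpt')   # hypothetical path
pipeline = pipeline.eval().to('cuda')

condition = torch.tensor([1, 1, 0, 0], device='cuda')    # only meaningful if a cond_embedder was trained
with torch.no_grad():
    samples = pipeline.sample(num_samples=4,
                              img_size=(8, 32, 32),       # [C, H, W] of the (latent) space being diffused
                              condition=condition,
                              steps=150, use_ddim=True,
                              guidance_scale=4.0)
print(samples.shape)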
medical_diffusion/models/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .attention_blocks import *
2
+ from .conv_blocks import *
medical_diffusion/models/utils/attention_blocks.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch.nn.functional as F
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+ from monai.networks.blocks import TransformerBlock
6
+ from monai.networks.layers.utils import get_norm_layer, get_dropout_layer
7
+ from monai.networks.layers.factories import Conv
8
+ from einops import rearrange
9
+
10
+
11
+ class GEGLU(nn.Module):
12
+ def __init__(self, in_channels, out_channels):
13
+ super().__init__()
14
+ self.norm = nn.LayerNorm(in_channels)
15
+ self.proj = nn.Linear(in_channels, out_channels*2, bias=True)
16
+
17
+ def forward(self, x):
18
+ # x expected to be [B, C, *]
19
+ # Workaround as layer norm can't currently be applied on arbitrary dimension: https://github.com/pytorch/pytorch/issues/71465
20
+ b, c, *spatial = x.shape
21
+ x = x.reshape(b, c, -1).transpose(1, 2) # -> [B, C, N] -> [B, N, C]
22
+ x = self.norm(x)
23
+ x, gate = self.proj(x).chunk(2, dim=-1)
24
+ x = x * F.gelu(gate)
25
+ return x.transpose(1, 2).reshape(b, -1, *spatial) # -> [B, C, N] -> [B, C, *]
26
+
27
+ def zero_module(module):
28
+ """
29
+ Zero out the parameters of a module and return it.
30
+ """
31
+ for p in module.parameters():
32
+ p.detach().zero_()
33
+ return module
34
+
35
+ def compute_attention(q,k,v , num_heads, scale):
36
+ q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v)) # [(BxHeads), Dim_per_head, N]
37
+
38
+ attn = (torch.einsum('b d i, b d j -> b i j', q*scale, k*scale)).softmax(dim=-1) # Matrix product = [(BxHeads), Dim_per_head, N] * [(BxHeads), Dim_per_head, N'] =[(BxHeads), N, N']
39
+
40
+ out = torch.einsum('b i j, b d j-> b d i', attn, v) # Matrix product: [(BxHeads), N, N'] * [(BxHeads), Dim_per_head, N'] = [(BxHeads), Dim_per_head, N]
41
+ out = rearrange(out, '(b h) d n-> b (h d) n', h=num_heads) # -> [B, (Heads x Dim_per_head), N]
42
+
43
+ return out
44
+
45
+
46
+ class LinearTransformerNd(nn.Module):
47
+ """ Combines multi-head self-attention and multi-head cross-attention.
48
+
49
+ Multi-Head Self-Attention:
50
+ Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare Monai TransformerBlock)
51
+ Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
52
+ Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278
53
+ Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80
54
+ Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209
55
+ Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150
56
+
57
+ CrossAttention:
58
+ Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152
59
+
60
+ """
61
+ def __init__(
62
+ self,
63
+ spatial_dims,
64
+ in_channels,
65
+ out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
66
+ num_heads=8,
67
+ ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
68
+ norm_name=("GROUP", {'num_groups':32, "affine": True}), # Or use LayerNorm but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1)
69
+ dropout=None,
70
+ emb_dim=None,
71
+ ):
72
+ super().__init__()
73
+ hid_channels = num_heads*ch_per_head
74
+ self.num_heads = num_heads
75
+ self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale
76
+
77
+ self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
78
+ emb_dim = in_channels if emb_dim is None else emb_dim
79
+
80
+ Convolution = Conv["conv", spatial_dims]
81
+ self.to_q = Convolution(in_channels, hid_channels, 1)
82
+ self.to_k = Convolution(emb_dim, hid_channels, 1)
83
+ self.to_v = Convolution(emb_dim, hid_channels, 1)
84
+
85
+ self.to_out = nn.Sequential(
86
+ zero_module(Convolution(hid_channels, out_channels, 1)),
87
+ nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
88
+ )
89
+
90
+ def forward(self, x, embedding=None):
91
+ # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
92
+ # if no embedding is given, cross-attention defaults to self-attention
93
+
94
+ # Normalize
95
+ b, c, *spatial = x.shape
96
+ x_n = self.norm_x(x)
97
+
98
+ # Attention: embedding (cross-attention) or x (self-attention)
99
+ if embedding is None:
100
+ embedding = x_n # WARNING: This assumes that emb_dim==in_channels
101
+ else:
102
+ if embedding.ndim == 2:
103
+ embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
104
+ # Why no normalization for embedding here?
105
+
106
+ # Convolution
107
+ q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), *]
108
+ k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), *]
109
+ v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), *]
110
+
111
+ # Flatten
112
+ q = q.reshape(*q.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N] (hid_channels may differ from c)
113
+ k = k.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']
114
+ v = v.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']
115
+
116
+ # Apply attention
117
+ out = compute_attention(q, k, v, self.num_heads, self.scale)
118
+
119
+ out = out.reshape(*out.shape[:2], *spatial) # -> [B, (Heads x Dim_per_head), *]
120
+ out = self.to_out(out) # -> [B, C', *]
121
+
122
+
123
+ if x.shape == out.shape:
124
+ out = x + out
125
+ return out # [B, C', *]
126
+
127
+
128
+ class LinearTransformer(nn.Module):
129
+ """ See LinearTransformer, however this implementation is fixed to Conv1d/Linear"""
130
+ def __init__(
131
+ self,
132
+ spatial_dims,
133
+ in_channels,
134
+ out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
135
+ num_heads,
136
+ ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
137
+ norm_name=("GROUP", {'num_groups':32, "affine": True}),
138
+ dropout=None,
139
+ emb_dim=None
140
+ ):
141
+ super().__init__()
142
+ hid_channels = num_heads*ch_per_head
143
+ self.num_heads = num_heads
144
+ self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale
145
+
146
+ self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
147
+ emb_dim = in_channels if emb_dim is None else emb_dim
148
+
149
+ # Note: Conv1d and Linear are interchangeable but order of input changes [B, C, N] <-> [B, N, C]
150
+ self.to_q = nn.Conv1d(in_channels, hid_channels, 1)
151
+ self.to_k = nn.Conv1d(emb_dim, hid_channels, 1)
152
+ self.to_v = nn.Conv1d(emb_dim, hid_channels, 1)
153
+ # self.to_qkv = nn.Conv1d(emb_dim, hid_channels*3, 1)
154
+
155
+ self.to_out = nn.Sequential(
156
+ zero_module(nn.Conv1d(hid_channels, out_channels, 1)),
157
+ nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
158
+ )
159
+
160
+ def forward(self, x, embedding=None):
161
+ # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
162
+ # if no embedding is given, cross-attention defaults to self-attention
163
+
164
+ # Normalize
165
+ b, c, *spatial = x.shape
166
+ x_n = self.norm_x(x)
167
+
168
+ # Attention: embedding (cross-attention) or x (self-attention)
169
+ if embedding is None:
170
+ embedding = x_n # WARNING: This assumes that emb_dim==in_channels
171
+ else:
172
+ if embedding.ndim == 2:
173
+ embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
174
+ # Why no normalization for embedding here?
175
+
176
+ # Flatten
177
+ x_n = x_n.reshape(b, c, -1) # [B, C, *] -> [B, C, N]
178
+ embedding = embedding.reshape(*embedding.shape[:2], -1) # [B, C*, *] -> [B, C*, N']
179
+
180
+ # Convolution
181
+ q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), N]
182
+ k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), N']
183
+ v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), N']
184
+ # qkv = self.to_qkv(x_n)
185
+ # q,k,v = qkv.split(qkv.shape[1]//3, dim=1)
186
+
187
+ # Apply attention
188
+ out = compute_attention(q, k, v, self.num_heads, self.scale)
189
+
190
+ out = self.to_out(out) # -> [B, C', N]
191
+ out = out.reshape(*out.shape[:2], *spatial) # -> [B, C', *]
192
+
193
+ if x.shape == out.shape:
194
+ out = x + out
195
+ return out # [B, C', *]
196
+
197
+
198
+
199
+
200
+ class BasicTransformerBlock(nn.Module):
201
+ def __init__(
202
+ self,
203
+ spatial_dims,
204
+ in_channels,
205
+ out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
206
+ num_heads,
207
+ ch_per_head=32,
208
+ norm_name=("GROUP", {'num_groups':32, "affine": True}),
209
+ dropout=None,
210
+ emb_dim=None
211
+ ):
212
+ super().__init__()
213
+ self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None)
214
+ if emb_dim is not None:
215
+ self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim)
216
+ self.proj_out = nn.Sequential(
217
+ GEGLU(in_channels, in_channels*4),
218
+ nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims),
219
+ Conv["conv", spatial_dims](in_channels*4, out_channels, 1, bias=True)
220
+ )
221
+
222
+
223
+ def forward(self, x, embedding=None):
224
+ # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
225
+ x = self.self_atn(x)
226
+ if embedding is not None:
227
+ x = self.cros_atn(x, embedding=embedding)
228
+ out = self.proj_out(x)
229
+ if out.shape[1] == x.shape[1]:
230
+ return out + x
231
+ return out
232
+
233
+ class SpatialTransformer(nn.Module):
234
+ """ Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218
235
+ Unrelated to: https://arxiv.org/abs/1506.02025
236
+ """
237
+ def __init__(
238
+ self,
239
+ spatial_dims,
240
+ in_channels,
241
+ out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
242
+ num_heads,
243
+ ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
244
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
245
+ dropout=None,
246
+ emb_dim=None,
247
+ depth=1
248
+ ):
249
+ super().__init__()
250
+ self.in_channels = in_channels
251
+ self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
252
+ conv_class = Conv["conv", spatial_dims]
253
+ hid_channels = num_heads*ch_per_head
254
+
255
+ self.proj_in = conv_class(
256
+ in_channels,
257
+ hid_channels,
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ )
262
+
263
+ self.transformer_blocks = nn.ModuleList([
264
+ BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim)
265
+ for _ in range(depth)]
266
+ )
267
+
268
+ self.proj_out = conv_class( # Note: zero_module is used in original code
269
+ hid_channels,
270
+ out_channels,
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ )
275
+
276
+ def forward(self, x, embedding=None):
277
+ # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
278
+ # Note: if no embedding is given, cross-attention is disabled
279
+ h = self.norm(x)
280
+ h = self.proj_in(h)
281
+
282
+ for block in self.transformer_blocks:
283
+ h = block(h, embedding=embedding)
284
+
285
+ h = self.proj_out(h) # -> [B, C'', *]
286
+ if h.shape == x.shape:
287
+ return h + x
288
+ return h
289
+
290
+
291
+ class Attention(nn.Module):
292
+ def __init__(
293
+ self,
294
+ spatial_dims,
295
+ in_channels,
296
+ out_channels,
297
+ num_heads=8,
298
+ ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
299
+ norm_name = ("GROUP", {'num_groups':32, "affine": True}),
300
+ dropout=0,
301
+ emb_dim=None,
302
+ depth=1,
303
+ attention_type='linear'
304
+ ) -> None:
305
+ super().__init__()
306
+ if attention_type == 'spatial':
307
+ self.attention = SpatialTransformer(
308
+ spatial_dims=spatial_dims,
309
+ in_channels=in_channels,
310
+ out_channels=out_channels,
311
+ num_heads=num_heads,
312
+ ch_per_head=ch_per_head,
313
+ depth=depth,
314
+ norm_name=norm_name,
315
+ dropout=dropout,
316
+ emb_dim=emb_dim
317
+ )
318
+ elif attention_type == 'linear':
319
+ self.attention = LinearTransformer(
320
+ spatial_dims=spatial_dims,
321
+ in_channels=in_channels,
322
+ out_channels=out_channels,
323
+ num_heads=num_heads,
324
+ ch_per_head=ch_per_head,
325
+ norm_name=norm_name,
326
+ dropout=dropout,
327
+ emb_dim=emb_dim
328
+ )
329
+
330
+
331
+ def forward(self, x, emb=None):
332
+ if hasattr(self, 'attention'):
333
+ return self.attention(x, emb)
334
+ else:
335
+ return x
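Attention is a thin selector: 'spatial' builds the full transformer stack, 'linear' a single (cross-)attention layer, and any other value (e.g. 'none') leaves the module without an attention member, so forward() returns the input unchanged; this is how attention is switched on or off per resolution level in the UNets above. A small shape-preserving sketch:

import torch
from medical_diffusion.models.utils.attention_blocks import Attention

x = torch.randn(2, 64, 32, 32)
emb = torch.randn(2, 128)                        # e.g. a global time/condition embedding

attn = Attention(spatial_dims=2, in_channels=64, out_channels=64,
                 num_heads=8, ch_per_head=8, emb_dim=128, attention_type='linear')
print(attn(x, emb).shape)                        # -> torch.Size([2, 64, 32, 32]), skip connection active

noop = Attention(spatial_dims=2, in_channels=64, out_channels=64, attention_type='none')
print(noop(x).shape)                             # input returned unchanged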
medical_diffusion/models/utils/conv_blocks.py ADDED
@@ -0,0 +1,528 @@
1
+ from typing import Optional, Sequence, Tuple, Union, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+
8
+
9
+ from monai.networks.blocks.dynunet_block import get_padding, get_output_padding
10
+ from monai.networks.layers import Pool, Conv
11
+ from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer
12
+ from monai.utils.misc import ensure_tuple_rep
13
+
14
+ from medical_diffusion.models.utils.attention_blocks import Attention, zero_module
15
+
16
+ def save_add(*args):
17
+ args = [arg for arg in args if arg is not None]
18
+ return sum(args) if len(args)>0 else None
19
+
20
+
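For clarity, save_add is a None-tolerant sum; a short hypothetical sanity check (values are made up):

import torch
from medical_diffusion.models.utils.conv_blocks import save_add

assert save_add(None, None) is None                                   # all-None propagates None
assert save_add(torch.tensor(1.0), None, torch.tensor(2.0)) == 3.0    # None entries are ignored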
21
+ class SequentialEmb(nn.Sequential):
22
+ def forward(self, input, emb):
23
+ for module in self:
24
+ input = module(input, emb)
25
+ return input
26
+
27
+
28
+ class BasicDown(nn.Module):
29
+ def __init__(
30
+ self,
31
+ spatial_dims,
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size=3,
35
+ stride=2,
36
+ learnable_interpolation=True,
37
+ use_res=False
38
+ ) -> None:
39
+ super().__init__()
40
+
41
+ if learnable_interpolation:
42
+ Convolution = Conv[Conv.CONV, spatial_dims]
43
+ self.down_op = Convolution(
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=get_padding(kernel_size, stride),
49
+ dilation=1,
50
+ groups=1,
51
+ bias=True,
52
+ )
53
+
54
+ if use_res:
55
+ self.down_skip = nn.PixelUnshuffle(2) # WARNING: only supports 2D and requires out_channels == 4*in_channels
56
+
57
+ else:
58
+ Pooling = Pool['avg', spatial_dims]
59
+ self.down_op = Pooling(
60
+ kernel_size=kernel_size,
61
+ stride=stride,
62
+ padding=get_padding(kernel_size, stride)
63
+ )
64
+
65
+
66
+ def forward(self, x, emb=None):
67
+ y = self.down_op(x)
68
+ if hasattr(self, 'down_skip'):
69
+ y = y+self.down_skip(x)
70
+ return y
71
+
72
+ class BasicUp(nn.Module):
73
+ def __init__(
74
+ self,
75
+ spatial_dims,
76
+ in_channels,
77
+ out_channels,
78
+ kernel_size=2,
79
+ stride=2,
80
+ learnable_interpolation=True,
81
+ use_res=False,
82
+ ) -> None:
83
+ super().__init__()
84
+ self.learnable_interpolation = learnable_interpolation
85
+ if learnable_interpolation:
86
+ # TransConvolution = Conv[Conv.CONVTRANS, spatial_dims]
87
+ # padding = get_padding(kernel_size, stride)
88
+ # output_padding = get_output_padding(kernel_size, stride, padding)
89
+ # self.up_op = TransConvolution(
90
+ # in_channels,
91
+ # out_channels,
92
+ # kernel_size=kernel_size,
93
+ # stride=stride,
94
+ # padding=padding,
95
+ # output_padding=output_padding,
96
+ # groups=1,
97
+ # bias=True,
98
+ # dilation=1
99
+ # )
100
+
101
+ self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size)
102
+ -2*np.atleast_1d(get_padding(kernel_size, stride)))
103
+ Convolution = Conv[Conv.CONV, spatial_dims]
104
+ self.up_op = Convolution(
105
+ in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1,
110
+ dilation=1,
111
+ groups=1,
112
+ bias=True,
113
+ )
114
+
115
+ if use_res:
116
+ self.up_skip = nn.PixelShuffle(2) # WARNING: Only supports 2D, out_channels == in_channels/4
117
+ else:
118
+ self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size)
119
+ -2*np.atleast_1d(get_padding(kernel_size, stride)))
120
+
121
+ def forward(self, x, emb=None):
122
+ if self.learnable_interpolation:
123
+ new_size = self.calc_shape(x.shape[2:])
124
+ x_res = F.interpolate(x, size=new_size, mode='nearest-exact')
125
+ y = self.up_op(x_res)
126
+ if hasattr(self, 'up_skip'):
127
+ y = y+self.up_skip(x)
128
+ return y
129
+ else:
130
+ new_size = self.calc_shape(x.shape[2:])
131
+ return F.interpolate(x, size=new_size, mode='nearest-exact')
132
+
133
+
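A hypothetical shape check for the two resampling blocks above (all sizes are assumptions): with stride 2, BasicDown halves the spatial resolution via a strided convolution (or average pooling when learnable_interpolation=False), while BasicUp computes the target size, interpolates with nearest-exact and refines with a 3x3 convolution.

import torch
from medical_diffusion.models.utils.conv_blocks import BasicDown, BasicUp

down = BasicDown(spatial_dims=2, in_channels=8, out_channels=16, kernel_size=3, stride=2)
up = BasicUp(spatial_dims=2, in_channels=16, out_channels=8, kernel_size=2, stride=2)

x = torch.randn(1, 8, 32, 32)
h = down(x)     # -> [1, 16, 16, 16]
y = up(h)       # -> [1, 8, 32, 32]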
134
+ class BasicBlock(nn.Module):
135
+ """
136
+ A block that consists of Conv-Norm-Drop-Act, similar to blocks.Convolution.
137
+
138
+ Args:
139
+ spatial_dims: number of spatial dimensions.
140
+ in_channels: number of input channels.
141
+ out_channels: number of output channels.
142
+ kernel_size: convolution kernel size.
143
+ stride: convolution stride.
144
+ norm_name: feature normalization type and arguments.
145
+ act_name: activation layer type and arguments.
146
+ dropout: dropout probability.
147
+ zero_conv: zero out the parameters of the convolution.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ spatial_dims: int,
153
+ in_channels: int,
154
+ out_channels: int,
155
+ kernel_size: Union[Sequence[int], int],
156
+ stride: Union[Sequence[int], int]=1,
157
+ norm_name: Union[Tuple, str, None]=None,
158
+ act_name: Union[Tuple, str, None] = None,
159
+ dropout: Optional[Union[Tuple, str, float]] = None,
160
+ zero_conv: bool = False,
161
+ ):
162
+ super().__init__()
163
+ Convolution = Conv[Conv.CONV, spatial_dims]
164
+ conv = Convolution(
165
+ in_channels,
166
+ out_channels,
167
+ kernel_size=kernel_size,
168
+ stride=stride,
169
+ padding=get_padding(kernel_size, stride),
170
+ dilation=1,
171
+ groups=1,
172
+ bias=True,
173
+ )
174
+ self.conv = zero_module(conv) if zero_conv else conv
175
+
176
+ if norm_name is not None:
177
+ self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
178
+ if dropout is not None:
179
+ self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
180
+ if act_name is not None:
181
+ self.act = get_act_layer(name=act_name)
182
+
183
+
184
+ def forward(self, inp):
185
+ out = self.conv(inp)
186
+ if hasattr(self, "norm"):
187
+ out = self.norm(out)
188
+ if hasattr(self, 'drop'):
189
+ out = self.drop(out)
190
+ if hasattr(self, "act"):
191
+ out = self.act(out)
192
+ return out
193
+
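A hypothetical instantiation of BasicBlock (the concrete norm/act/dropout choices are assumptions): norm_name and act_name follow MONAI's layer-factory convention, i.e. a plain name string or a (NAME, kwargs) tuple, and dropout may simply be a probability.

import torch
from medical_diffusion.models.utils.conv_blocks import BasicBlock

block = BasicBlock(
    spatial_dims=2,
    in_channels=8,
    out_channels=16,
    kernel_size=3,
    stride=1,
    norm_name=("GROUP", {"num_groups": 8, "affine": True}),
    act_name=("SWISH", {}),
    dropout=0.1,
)
y = block(torch.randn(1, 8, 32, 32))   # Conv -> GroupNorm -> Dropout -> Swish, output [1, 16, 32, 32]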
194
+ class BasicResBlock(nn.Module):
195
+ """
196
+ A block that consists of Conv-Norm-Drop-Act plus a residual skip connection (1x1 conv if the channel counts differ, identity otherwise).
197
+
198
+ Args:
199
+ spatial_dims: number of spatial dimensions.
200
+ in_channels: number of input channels.
201
+ out_channels: number of output channels.
202
+ kernel_size: convolution kernel size.
203
+ stride: convolution stride.
204
+ norm_name: feature normalization type and arguments.
205
+ act_name: activation layer type and arguments.
206
+ dropout: dropout probability.
207
+ zero_conv: zero out the parameters of the convolution.
208
+ """
209
+ def __init__(
210
+ self,
211
+ spatial_dims: int,
212
+ in_channels: int,
213
+ out_channels: int,
214
+ kernel_size: Union[Sequence[int], int],
215
+ stride: Union[Sequence[int], int]=1,
216
+ norm_name: Union[Tuple, str, None]=None,
217
+ act_name: Union[Tuple, str, None] = None,
218
+ dropout: Optional[Union[Tuple, str, float]] = None,
219
+ zero_conv: bool = False
220
+ ):
221
+ super().__init__()
222
+ self.basic_block = BasicBlock(spatial_dims, in_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, zero_conv)
223
+ Convolution = Conv[Conv.CONV, spatial_dims]
224
+ self.conv_res = Convolution(
225
+ in_channels,
226
+ out_channels,
227
+ kernel_size=1,
228
+ stride=stride,
229
+ padding=get_padding(1, stride),
230
+ dilation=1,
231
+ groups=1,
232
+ bias=True,
233
+ ) if in_channels != out_channels else nn.Identity()
234
+
235
+
236
+ def forward(self, inp):
237
+ out = self.basic_block(inp)
238
+ residual = self.conv_res(inp)
239
+ out = out+residual
240
+ return out
241
+
242
+
243
+
244
+ class UnetBasicBlock(nn.Module):
245
+ """
246
+ A modified version of monai.networks.blocks.UnetBasicBlock with an additional embedding input.
247
+
248
+ Args:
249
+ spatial_dims: number of spatial dimensions.
250
+ in_channels: number of input channels.
251
+ out_channels: number of output channels.
252
+ kernel_size: convolution kernel size.
253
+ stride: convolution stride.
254
+ norm_name: feature normalization type and arguments.
255
+ act_name: activation layer type and arguments.
256
+ dropout: dropout probability.
257
+ emb_channels: Number of embedding channels
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ spatial_dims: int,
263
+ in_channels: int,
264
+ out_channels: int,
265
+ kernel_size: Union[Sequence[int], int],
266
+ stride: Union[Sequence[int], int]=1,
267
+ norm_name: Union[Tuple, str]=None,
268
+ act_name: Union[Tuple, str]=None,
269
+ dropout: Optional[Union[Tuple, str, float]] = None,
270
+ emb_channels: int = None,
271
+ blocks = 2
272
+ ):
273
+ super().__init__()
274
+ self.block_seq = nn.ModuleList([
275
+ BasicBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1)
276
+ for i in range(blocks)
277
+ ])
278
+
279
+ if emb_channels is not None:
280
+ self.local_embedder = nn.Sequential(
281
+ get_act_layer(name=act_name),
282
+ nn.Linear(emb_channels, out_channels),
283
+ )
284
+
285
+ def forward(self, x, emb=None):
286
+ # ------------ Embedding ----------
287
+ if emb is not None:
288
+ emb = self.local_embedder(emb)
289
+ b,c, *_ = emb.shape
290
+ sp_dim = x.ndim-2
291
+ emb = emb.reshape(b, c, *((1,)*sp_dim) )
292
+ # scale, shift = emb.chunk(2, dim = 1)
293
+ # x = x * (scale + 1) + shift
294
+ # x = x+emb
295
+
296
+ # ----------- Convolution ---------
297
+ n_blocks = len(self.block_seq)
298
+ for i, block in enumerate(self.block_seq):
299
+ x = block(x)
300
+ if (emb is not None) and i<n_blocks:
301
+ x += emb
302
+ return x
303
+
304
+
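A hypothetical sketch of the embedding path (sizes are assumptions): when emb_channels is set, a per-sample embedding of shape [B, emb_channels] is projected to [B, out_channels], broadcast to [B, out_channels, 1, ..., 1] and added onto the feature map after each convolution block.

import torch
from medical_diffusion.models.utils.conv_blocks import UnetBasicBlock

block = UnetBasicBlock(
    spatial_dims=2,
    in_channels=8,
    out_channels=16,
    kernel_size=3,
    stride=1,
    norm_name=("GROUP", {"num_groups": 8, "affine": True}),
    act_name=("SWISH", {}),
    emb_channels=128,
)
x = torch.randn(2, 8, 32, 32)
emb = torch.randn(2, 128)      # e.g. a time/condition embedding
y = block(x, emb)              # -> [2, 16, 32, 32]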
305
+ class UnetResBlock(nn.Module):
306
+ """
307
+ A modified version of monai.networks.blocks.UnetResBlock with an additional skip connection and embedding input.
308
+
309
+ Args:
310
+ spatial_dims: number of spatial dimensions.
311
+ in_channels: number of input channels.
312
+ out_channels: number of output channels.
313
+ kernel_size: convolution kernel size.
314
+ stride: convolution stride.
315
+ norm_name: feature normalization type and arguments.
316
+ act_name: activation layer type and arguments.
317
+ dropout: dropout probability.
318
+ emb_channels: Number of embedding channels
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ spatial_dims: int,
324
+ in_channels: int,
325
+ out_channels: int,
326
+ kernel_size: Union[Sequence[int], int],
327
+ stride: Union[Sequence[int], int]=1,
328
+ norm_name: Union[Tuple, str]=None,
329
+ act_name: Union[Tuple, str]=None,
330
+ dropout: Optional[Union[Tuple, str, float]] = None,
331
+ emb_channels: int = None,
332
+ blocks = 2
333
+ ):
334
+ super().__init__()
335
+ self.block_seq = nn.ModuleList([
336
+ BasicResBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1)
337
+ for i in range(blocks)
338
+ ])
339
+
340
+ if emb_channels is not None:
341
+ self.local_embedder = nn.Sequential(
342
+ get_act_layer(name=act_name),
343
+ nn.Linear(emb_channels, out_channels),
344
+ )
345
+
346
+
347
+ def forward(self, x, emb=None):
348
+ # ------------ Embedding ----------
349
+ if emb is not None:
350
+ emb = self.local_embedder(emb)
351
+ b,c, *_ = emb.shape
352
+ sp_dim = x.ndim-2
353
+ emb = emb.reshape(b, c, *((1,)*sp_dim) )
354
+ # scale, shift = emb.chunk(2, dim = 1)
355
+ # x = x * (scale + 1) + shift
356
+ # x = x+emb
357
+
358
+ # ----------- Convolution ---------
359
+ n_blocks = len(self.block_seq)
360
+ for i, block in enumerate(self.block_seq):
361
+ x = block(x)
362
+ if (emb is not None) and i<n_blocks-1:
363
+ x += emb
364
+ return x
365
+
366
+
367
+
368
+ class DownBlock(nn.Module):
369
+ def __init__(
370
+ self,
371
+ spatial_dims: int,
372
+ in_channels: int,
373
+ out_channels: int,
374
+ kernel_size: Union[Sequence[int], int],
375
+ stride: Union[Sequence[int], int],
376
+ downsample_kernel_size: Union[Sequence[int], int],
377
+ norm_name: Union[Tuple, str],
378
+ act_name: Union[Tuple, str],
379
+ dropout: Optional[Union[Tuple, str, float]] = None,
380
+ use_res_block: bool = False,
381
+ learnable_interpolation: bool = True,
382
+ use_attention: str = 'none',
383
+ emb_channels: int = None
384
+ ):
385
+ super(DownBlock, self).__init__()
386
+ enable_down = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
387
+ down_out_channels = out_channels if learnable_interpolation and enable_down else in_channels
388
+
389
+ # -------------- Down ----------------------
390
+ self.down_op = BasicDown(
391
+ spatial_dims,
392
+ in_channels,
393
+ out_channels,
394
+ kernel_size=downsample_kernel_size,
395
+ stride=stride,
396
+ learnable_interpolation=learnable_interpolation,
397
+ use_res=False
398
+ ) if enable_down else nn.Identity()
399
+
400
+
401
+ # ---------------- Attention -------------
402
+ self.attention = Attention(
403
+ spatial_dims=spatial_dims,
404
+ in_channels=down_out_channels,
405
+ out_channels=down_out_channels,
406
+ num_heads=8,
407
+ ch_per_head=down_out_channels//8,
408
+ depth=1,
409
+ norm_name=norm_name,
410
+ dropout=dropout,
411
+ emb_dim=emb_channels,
412
+ attention_type=use_attention
413
+ )
414
+
415
+ # -------------- Convolution ----------------------
416
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
417
+ self.conv_block = ConvBlock(
418
+ spatial_dims,
419
+ down_out_channels,
420
+ out_channels,
421
+ kernel_size=kernel_size,
422
+ stride=1,
423
+ dropout=dropout,
424
+ norm_name=norm_name,
425
+ act_name=act_name,
426
+ emb_channels=emb_channels
427
+ )
428
+
429
+
430
+ def forward(self, x, emb=None):
431
+ # ----------- Down ---------
432
+ x = self.down_op(x)
433
+
434
+ # ----------- Attention -------------
435
+ if self.attention is not None:
436
+ x = self.attention(x, emb)
437
+
438
+ # ------------- Convolution --------------
439
+ x = self.conv_block(x, emb)
440
+
441
+ return x
442
+
443
+
444
+ class UpBlock(nn.Module):
445
+ def __init__(
446
+ self,
447
+ spatial_dims,
448
+ in_channels: int,
449
+ out_channels: int,
450
+ kernel_size: Union[Sequence[int], int],
451
+ stride: Union[Sequence[int], int],
452
+ upsample_kernel_size: Union[Sequence[int], int],
453
+ norm_name: Union[Tuple, str],
454
+ act_name: Union[Tuple, str],
455
+ dropout: Optional[Union[Tuple, str, float]] = None,
456
+ use_res_block: bool = False,
457
+ learnable_interpolation: bool = True,
458
+ use_attention: str = 'none',
459
+ emb_channels: int = None,
460
+ skip_channels: int = 0
461
+ ):
462
+ super(UpBlock, self).__init__()
463
+ enable_up = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
464
+ skip_out_channels = out_channels if learnable_interpolation and enable_up else in_channels+skip_channels
465
+ self.learnable_interpolation = learnable_interpolation
466
+
467
+
468
+ # -------------- Up ----------------------
469
+ self.up_op = BasicUp(
470
+ spatial_dims=spatial_dims,
471
+ in_channels=in_channels,
472
+ out_channels=out_channels,
473
+ kernel_size=upsample_kernel_size,
474
+ stride=stride,
475
+ learnable_interpolation=learnable_interpolation,
476
+ use_res=False
477
+ ) if enable_up else nn.Identity()
478
+
479
+ # ---------------- Attention -------------
480
+ self.attention = Attention(
481
+ spatial_dims=spatial_dims,
482
+ in_channels=skip_out_channels,
483
+ out_channels=skip_out_channels,
484
+ num_heads=8,
485
+ ch_per_head=skip_out_channels//8,
486
+ depth=1,
487
+ norm_name=norm_name,
488
+ dropout=dropout,
489
+ emb_dim=emb_channels,
490
+ attention_type=use_attention
491
+ )
492
+
493
+
494
+ # -------------- Convolution ----------------------
495
+ ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
496
+ self.conv_block = ConvBlock(
497
+ spatial_dims,
498
+ skip_out_channels,
499
+ out_channels,
500
+ kernel_size=kernel_size,
501
+ stride=1,
502
+ dropout=dropout,
503
+ norm_name=norm_name,
504
+ act_name=act_name,
505
+ emb_channels=emb_channels
506
+ )
507
+
508
+
509
+
510
+ def forward(self, x_enc, x_skip=None, emb=None):
511
+ # ----------- Up -------------
512
+ x = self.up_op(x_enc)
513
+
514
+ # ----------- Skip Connection ------------
515
+ if x_skip is not None:
516
+ if self.learnable_interpolation: # channels of x_enc and x_skip are equal, so summation is possible
517
+ x = x+x_skip
518
+ else:
519
+ x = torch.cat((x, x_skip), dim=1)
520
+
521
+ # ----------- Attention -------------
522
+ if self.attention is not None:
523
+ x = self.attention(x, emb)
524
+
525
+ # ----------- Convolution ------------
526
+ x = self.conv_block(x, emb)
527
+
528
+ return x
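A hypothetical encoder/decoder round trip with the two blocks above (all sizes are assumptions): DownBlock halves the resolution, UpBlock doubles it again, and because learnable_interpolation=True keeps the channel counts compatible, the skip connection is fused by summation rather than concatenation.

import torch
from medical_diffusion.models.utils.conv_blocks import DownBlock, UpBlock

down = DownBlock(
    spatial_dims=2, in_channels=32, out_channels=64,
    kernel_size=3, stride=2, downsample_kernel_size=3,
    norm_name=("GROUP", {"num_groups": 32, "affine": True}), act_name=("SWISH", {}),
    use_attention='none', emb_channels=None,
)
up = UpBlock(
    spatial_dims=2, in_channels=64, out_channels=32,
    kernel_size=3, stride=2, upsample_kernel_size=2,
    norm_name=("GROUP", {"num_groups": 32, "affine": True}), act_name=("SWISH", {}),
    use_attention='none', emb_channels=None, skip_channels=0,
)

x = torch.randn(1, 32, 64, 64)
h = down(x)                # -> [1, 64, 32, 32]
y = up(h, x_skip=x)        # -> [1, 32, 64, 64]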
medical_diffusion/utils/math_utils.py ADDED
@@ -0,0 +1,6 @@
1
+ import torch
2
+
3
+ def kl_gaussians(mean1, logvar1, mean2, logvar2):
4
+ """Compute the element-wise KL divergence between two (diagonal) Gaussians given means and log-variances."""
5
+ return 0.5 * (logvar2-logvar1 + torch.exp(logvar1 - logvar2) + torch.pow(mean1 - mean2, 2) * torch.exp(-logvar2)-1.0)
6
+
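For reference, the returned expression is the element-wise closed form of the KL divergence between two Gaussians, written with $\log\sigma_i^2 = \texttt{logvar}_i$:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
= \tfrac{1}{2}\left(\log\tfrac{\sigma_2^2}{\sigma_1^2} + \tfrac{\sigma_1^2}{\sigma_2^2} + \tfrac{(\mu_1-\mu_2)^2}{\sigma_2^2} - 1\right)
$$

which matches the code term by term (logvar2 - logvar1, exp(logvar1 - logvar2), (mean1 - mean2)^2 * exp(-logvar2), -1).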
medical_diffusion/utils/train_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class EMAModel(nn.Module):
6
+ # See: https://github.com/huggingface/diffusers/blob/3100bc967084964480628ae61210b7eaa7436f1d/src/diffusers/training_utils.py#L42
7
+ """
8
+ Exponential Moving Average of models weights
9
+ """
10
+
11
+ def __init__(
12
+ self,
13
+ model,
14
+ update_after_step=0,
15
+ inv_gamma=1.0,
16
+ power=2 / 3,
17
+ min_value=0.0,
18
+ max_value=0.9999,
19
+ ):
20
+ super().__init__()
21
+ """
22
+ @crowsonkb's notes on EMA Warmup:
23
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
24
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
25
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
26
+ at 215.4k steps).
27
+ Args:
28
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
29
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
30
+ min_value (float): The minimum EMA decay rate. Default: 0.
31
+ """
32
+
33
+ self.averaged_model = copy.deepcopy(model).eval()
34
+ self.averaged_model.requires_grad_(False)
35
+
36
+ self.update_after_step = update_after_step
37
+ self.inv_gamma = inv_gamma
38
+ self.power = power
39
+ self.min_value = min_value
40
+ self.max_value = max_value
41
+
42
+ self.averaged_model = self.averaged_model #.to(device=model.device)
43
+
44
+ self.decay = 0.0
45
+ self.optimization_step = 0
46
+
47
+ def get_decay(self, optimization_step):
48
+ """
49
+ Compute the decay factor for the exponential moving average.
50
+ """
51
+ step = max(0, optimization_step - self.update_after_step - 1)
52
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
53
+
54
+ if step <= 0:
55
+ return 0.0
56
+
57
+ return max(self.min_value, min(value, self.max_value))
58
+
59
+ @torch.no_grad()
60
+ def step(self, new_model):
61
+ ema_state_dict = {}
62
+ ema_params = self.averaged_model.state_dict()
63
+
64
+ self.decay = self.get_decay(self.optimization_step)
65
+
66
+ for key, param in new_model.named_parameters():
67
+ if isinstance(param, dict):
68
+ continue
69
+ try:
70
+ ema_param = ema_params[key]
71
+ except KeyError:
72
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
73
+ ema_params[key] = ema_param
74
+
75
+ if not param.requires_grad:
76
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
77
+ ema_param = ema_params[key]
78
+ else:
79
+ ema_param.mul_(self.decay)
80
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
81
+
82
+ ema_state_dict[key] = ema_param
83
+
84
+ for key, param in new_model.named_buffers():
85
+ ema_state_dict[key] = param
86
+
87
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
88
+ self.optimization_step += 1
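A hypothetical, self-contained usage sketch of EMAModel with a tiny dummy model (the model, optimizer and loop below are placeholders, not from this repository): step(model) is called once per optimizer update, and evaluation reads the smoothed weights from ema.averaged_model.

import torch
import torch.nn as nn
from medical_diffusion.utils.train_utils import EMAModel

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model, update_after_step=0, inv_gamma=1.0, power=2 / 3)

for _ in range(10):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()      # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model)                    # decay grows with the number of updates

# with inv_gamma=1 and power=2/3 the decay is roughly 1 - (1 + step)**(-2/3):
# ~0.999 after ~31.6k updates and ~0.9999 after 1M updates (see the docstring above)
print(ema.decay, ema.optimization_step)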
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ torch # pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
2
+ numpy
3
+ scikit-learn
4
+ pytorch-lightning
5
+ pytorch_msssim
6
+ monai
7
+ torchmetrics
8
+ torch-fidelity
9
+ torchio
10
+ pillow
11
+ einops
12
+ torchvision
13
+ matplotlib
14
+ pandas
15
+ lpips
16
+
17
+ streamlit
scripts/evaluate_images.py ADDED
@@ -0,0 +1,129 @@
1
+ from pathlib import Path
2
+ import logging
3
+ from datetime import datetime
4
+ from tqdm import tqdm
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from torchvision.datasets import ImageFolder
10
+ from torch.utils.data import TensorDataset, Subset
11
+ from torchmetrics.image.fid import FrechetInceptionDistance as FID
12
+ from torchmetrics.image.inception import InceptionScore as IS
13
+
14
+ from medical_diffusion.metrics.torchmetrics_pr_recall import ImprovedPrecessionRecall
15
+
16
+
17
+ # ----------------Settings --------------
18
+ batch_size = 100
19
+ max_samples = None # set to None for all
20
+ # path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
21
+ # path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
22
+ path_out = Path.cwd()/'results'/'CheXpert'/'metrics'
23
+ path_out.mkdir(parents=True, exist_ok=True)
24
+
25
+
26
+ # ----------------- Logging -----------
27
+ current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
28
+ logger = logging.getLogger()
29
+ logging.basicConfig(level=logging.INFO)
30
+ logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w'))
31
+
32
+ # -------------- Helpers ---------------------
33
+ pil2torch = lambda x: torch.as_tensor(np.array(x)).moveaxis(-1, 0) # Unlike ToTensor(), this keeps uint8 values in [0, 255] instead of casting to float [0, 1] (uint8 is required by the metrics below)
34
+
35
+
36
+ # ---------------- Dataset/Dataloader ----------------
37
+ # ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train', transform=pil2torch)
38
+ # ds_fake = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/SYNTH-CRC-10K/', transform=pil2torch)
39
+ # ds_fake = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_250', transform=pil2torch)
40
+
41
+ # ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
42
+ # ds_fake = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_generated_stylegan3/', transform=pil2torch)
43
+ # ds_fake = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion', transform=pil2torch)
44
+
45
+ ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference/', transform=pil2torch)
46
+ # ds_fake = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_progan/', transform=pil2torch)
47
+ ds_fake = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_250/', transform=pil2torch)
48
+
49
+ ds_real.samples = ds_real.samples[slice(max_samples)]
50
+ ds_fake.samples = ds_fake.samples[slice(max_samples)]
51
+
52
+
53
+ # --------- Select specific class ------------
54
+ # target_class = 'MSIH'
55
+ # ds_real = Subset(ds_real, [i for i in range(len(ds_real)) if ds_real.samples[i][1] == ds_real.class_to_idx[target_class]])
56
+ # ds_fake = Subset(ds_fake, [i for i in range(len(ds_fake)) if ds_fake.samples[i][1] == ds_fake.class_to_idx[target_class]])
57
+
58
+ # Only for testing metrics against OpenAI implementation
59
+ # ds_real = TensorDataset(torch.from_numpy(np.load('/home/gustav/Documents/code/guided-diffusion/data/VIRTUAL_imagenet64_labeled.npz')['arr_0']).swapaxes(1,-1))
60
+ # ds_fake = TensorDataset(torch.from_numpy(np.load('/home/gustav/Documents/code/guided-diffusion/data/biggan_deep_imagenet64.npz')['arr_0']).swapaxes(1,-1))
61
+
62
+
63
+ dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
64
+ dm_fake = DataLoader(ds_fake, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
65
+
66
+ logger.info(f"Samples Real: {len(ds_real)}")
67
+ logger.info(f"Samples Fake: {len(ds_fake)}")
68
+
69
+ # ------------- Init Metrics ----------------------
70
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
71
+ calc_fid = FID().to(device) # requires uint8
72
+ # calc_is = IS(splits=1).to(device) # requires uint8, features must be 1008 see https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L603
73
+ calc_pr = ImprovedPrecessionRecall(splits_real=1, splits_fake=1).to(device)
74
+
75
+
76
+
77
+
78
+ # --------------- Start Calculation -----------------
79
+ for real_batch in tqdm(dm_real):
80
+ imgs_real_batch = real_batch[0].to(device)
81
+
82
+ # -------------- FID -------------------
83
+ calc_fid.update(imgs_real_batch, real=True)
84
+
85
+ # ------ Improved Precision/Recall--------
86
+ calc_pr.update(imgs_real_batch, real=True)
87
+
88
+ # torch.save(torch.concat(calc_fid.real_features), 'real_fid.pt')
89
+ # torch.save(torch.concat(calc_pr.real_features), 'real_ipr.pt')
90
+
91
+
92
+ for fake_batch in tqdm(dm_fake):
93
+ imgs_fake_batch = fake_batch[0].to(device)
94
+
95
+ # -------------- FID -------------------
96
+ calc_fid.update(imgs_fake_batch, real=False)
97
+
98
+ # -------------- IS -------------------
99
+ # calc_is.update(imgs_fake_batch)
100
+
101
+ # ---- Improved Precision/Recall--------
102
+ calc_pr.update(imgs_fake_batch, real=False)
103
+
104
+ # torch.save(torch.concat(calc_fid.fake_features), 'fake_fid.pt')
105
+ # torch.save(torch.concat(calc_pr.fake_features), 'fake_ipr.pt')
106
+
107
+ # --------------- Load features --------------
108
+ # real_fid = torch.as_tensor(torch.load('real_fid.pt'), device=device)
109
+ # real_ipr = torch.as_tensor(torch.load('real_ipr.pt'), device=device)
110
+ # fake_fid = torch.as_tensor(torch.load('fake_fid.pt'), device=device)
111
+ # fake_ipr = torch.as_tensor(torch.load('fake_ipr.pt'), device=device)
112
+
113
+ # calc_fid.real_features = real_fid.chunk(batch_size)
114
+ # calc_pr.real_features = real_ipr.chunk(batch_size)
115
+ # calc_fid.fake_features = fake_fid.chunk(batch_size)
116
+ # calc_pr.fake_features = fake_ipr.chunk(batch_size)
117
+
118
+
119
+
120
+ # -------------- Summary -------------------
121
+ fid = calc_fid.compute()
122
+ logger.info(f"FID Score: {fid}")
123
+
124
+ # is_mean, is_std = calc_is.compute()
125
+ # logger.info(f"IS Score: mean {is_mean} std {is_std}")
126
+
127
+ precision, recall = calc_pr.compute()
128
+ logger.info(f"Precision: {precision}, Recall {recall} ")
129
+
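A hypothetical smoke test of the metric interface used above (random data, so the resulting value is meaningless): torchmetrics' FrechetInceptionDistance expects uint8 image batches of shape [B, 3, H, W] by default, which is exactly why pil2torch avoids ToTensor().

import torch
from torchmetrics.image.fid import FrechetInceptionDistance as FID

fid = FID(feature=2048)
real = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
fid.update(real, real=True)
fid.update(fake, real=False)
score = fid.compute()      # scalar tensor; only meaningful with real features and enough samples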
scripts/evaluate_latent_embedder.py ADDED
@@ -0,0 +1,98 @@
1
+ from pathlib import Path
2
+ import logging
3
+ from datetime import datetime
4
+ from tqdm import tqdm
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision.transforms.functional as tF
9
+ from torch.utils.data.dataloader import DataLoader
10
+ from torchvision.datasets import ImageFolder
11
+ from torch.utils.data import TensorDataset, Subset
12
+
13
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
14
+ from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim
15
+
16
+ from medical_diffusion.models.embedders.latent_embedders import VAE
17
+
18
+
19
+ # ----------------Settings --------------
20
+ batch_size = 100
21
+ max_samples = None # set to None for all
22
+ target_class = None # None for no specific class
23
+ # path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
24
+ # path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
25
+ path_out = Path.cwd()/'results'/'CheXpert'/'metrics'
26
+ path_out.mkdir(parents=True, exist_ok=True)
27
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
28
+
29
+ # ----------------- Logging -----------
30
+ current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
31
+ logger = logging.getLogger()
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w'))
34
+
35
+
36
+ # -------------- Helpers ---------------------
37
+ pil2torch = lambda x: torch.as_tensor(np.array(x)).moveaxis(-1, 0) # Unlike ToTensor(), this keeps uint8 values in [0, 255] instead of casting to float [0, 1]; the conversion to [-1, 1] is done explicitly below
38
+
39
+ # ---------------- Dataset/Dataloader ----------------
40
+ ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', transform=pil2torch)
41
+ # ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
42
+ # ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference_test/', transform=pil2torch)
43
+
44
+ # ---------- Limit Sample Size
45
+ ds_real.samples = ds_real.samples[slice(max_samples)]
46
+
47
+
48
+ # --------- Select specific class ------------
49
+ if target_class is not None:
50
+ ds_real = Subset(ds_real, [i for i in range(len(ds_real)) if ds_real.samples[i][1] == ds_real.class_to_idx[target_class]])
51
+ dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
52
+
53
+ logger.info(f"Samples Real: {len(ds_real)}")
54
+
55
+
56
+ # --------------- Load Model ------------------
57
+ model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')
58
+ model.to(device)
59
+
60
+ # from diffusers import StableDiffusionPipeline
61
+ # with open('auth_token.txt', 'r') as file:
62
+ # auth_token = file.read()
63
+ # pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
64
+ # model = pipe.vae
65
+ # model.to(device)
66
+
67
+
68
+ # ------------- Init Metrics ----------------------
69
+ calc_lpips = LPIPS().to(device)
70
+
71
+
72
+ # --------------- Start Calculation -----------------
73
+ mmssim_list, mse_list = [], []
74
+ for real_batch in tqdm(dm_real):
75
+ imgs_real_batch = real_batch[0].to(device)
76
+
77
+ imgs_real_batch = tF.normalize(imgs_real_batch/255, 0.5, 0.5) # [0, 255] -> [-1, 1]
78
+ with torch.no_grad():
79
+ imgs_fake_batch = model(imgs_real_batch)[0].clamp(-1, 1)
80
+
81
+ # -------------- LPIP -------------------
82
+ calc_lpips.update(imgs_real_batch, imgs_fake_batch) # expect input to be [-1, 1]
83
+
84
+ # -------------- MS-SSIM + MSE -------------------
85
+ for img_real, img_fake in zip(imgs_real_batch, imgs_fake_batch):
86
+ img_real, img_fake = (img_real+1)/2, (img_fake+1)/2 # [-1, 1] -> [0, 1]
87
+ mmssim_list.append(mmssim(img_real[None], img_fake[None], normalize='relu'))
88
+ mse_list.append(torch.mean(torch.square(img_real-img_fake)))
89
+
90
+
91
+ # -------------- Summary -------------------
92
+ mmssim_list = torch.stack(mmssim_list)
93
+ mse_list = torch.stack(mse_list)
94
+
95
+ lpips = 1-calc_lpips.compute() # reported as 1 - LPIPS, so higher values indicate better reconstructions
96
+ logger.info(f"LPIPS Score: {lpips}")
97
+ logger.info(f"MS-SSIM: {torch.mean(mmssim_list)} ± {torch.std(mmssim_list)}")
98
+ logger.info(f"MSE: {torch.mean(mse_list)} ± {torch.std(mse_list)}")
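A hypothetical one-liner check of the value-range handling used above: uint8 [0, 255] is first scaled to [0, 1], normalize(…, 0.5, 0.5) then maps it to [-1, 1] for the autoencoder, and (y + 1) / 2 maps the reconstruction back to [0, 1] for MS-SSIM/MSE.

import torch
import torchvision.transforms.functional as tF

x = torch.tensor([[[[0.0, 127.5, 255.0]]]])   # [B=1, C=1, H=1, W=3]
y = tF.normalize(x / 255, 0.5, 0.5)           # -> tensor([[[[-1., 0., 1.]]]])
back = (y + 1) / 2                            # -> tensor([[[[0.0, 0.5, 1.0]]]])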
scripts/helpers/dump_discrimnator.py ADDED
@@ -0,0 +1,26 @@
1
+ from pathlib import Path
2
+ import torch
3
+ from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
4
+ from pytorch_lightning.trainer import Trainer
5
+ from pytorch_lightning.callbacks import ModelCheckpoint
6
+
7
+ path_root = Path('runs/2022_12_01_210017_patho_vaegan')
8
+
9
+ # Load model
10
+ model = VAEGAN.load_from_checkpoint(path_root/'last.ckpt')
11
+ # model = torch.load(path_root/'last.ckpt')
12
+
13
+
14
+
15
+ # Save model-part
16
+ # torch.save(model.vqvae, path_root/'last_vae.ckpt') # Not working
17
+ # ------ Ugly workaround ----------
18
+ checkpointing = ModelCheckpoint()
19
+ trainer = Trainer(callbacks=[checkpointing])
20
+ trainer.strategy._lightning_module = model.vqvae
21
+ trainer.model = model.vqvae
22
+ trainer.save_checkpoint(path_root/'last_vae.ckpt')
23
+ # -----------------
24
+
25
+ model = VAE.load_from_checkpoint(path_root/'last_vae.ckpt')
26
+ # model = torch.load(path_root/'last_vae.ckpt') # load_state_dict