Fucius commited on
Commit
821f875
·
verified ·
1 Parent(s): e9bd2d4

Upload 49 files

Browse files
Files changed (49) hide show
  1. src/efficientvit/__init__.py +0 -0
  2. src/efficientvit/apps/__init__.py +0 -0
  3. src/efficientvit/apps/data_provider/__init__.py +7 -0
  4. src/efficientvit/apps/data_provider/augment/__init__.py +6 -0
  5. src/efficientvit/apps/data_provider/augment/bbox.py +30 -0
  6. src/efficientvit/apps/data_provider/augment/color_aug.py +84 -0
  7. src/efficientvit/apps/data_provider/base.py +223 -0
  8. src/efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
  9. src/efficientvit/apps/data_provider/random_resolution/_data_loader.py +1598 -0
  10. src/efficientvit/apps/data_provider/random_resolution/_data_worker.py +377 -0
  11. src/efficientvit/apps/data_provider/random_resolution/controller.py +94 -0
  12. src/efficientvit/apps/setup.py +141 -0
  13. src/efficientvit/apps/trainer/__init__.py +6 -0
  14. src/efficientvit/apps/trainer/base.py +297 -0
  15. src/efficientvit/apps/trainer/run_config.py +121 -0
  16. src/efficientvit/apps/utils/__init__.py +12 -0
  17. src/efficientvit/apps/utils/dist.py +73 -0
  18. src/efficientvit/apps/utils/ema.py +50 -0
  19. src/efficientvit/apps/utils/export.py +47 -0
  20. src/efficientvit/apps/utils/init.py +68 -0
  21. src/efficientvit/apps/utils/lr.py +48 -0
  22. src/efficientvit/apps/utils/metric.py +37 -0
  23. src/efficientvit/apps/utils/misc.py +111 -0
  24. src/efficientvit/apps/utils/opt.py +31 -0
  25. src/efficientvit/models/__init__.py +0 -0
  26. src/efficientvit/models/efficientvit/__init__.py +8 -0
  27. src/efficientvit/models/efficientvit/backbone.py +372 -0
  28. src/efficientvit/models/efficientvit/cls.py +174 -0
  29. src/efficientvit/models/efficientvit/sam.py +653 -0
  30. src/efficientvit/models/efficientvit/seg.py +355 -0
  31. src/efficientvit/models/nn/__init__.py +8 -0
  32. src/efficientvit/models/nn/act.py +30 -0
  33. src/efficientvit/models/nn/drop.py +98 -0
  34. src/efficientvit/models/nn/norm.py +157 -0
  35. src/efficientvit/models/nn/ops.py +585 -0
  36. src/efficientvit/models/utils/__init__.py +7 -0
  37. src/efficientvit/models/utils/list.py +57 -0
  38. src/efficientvit/models/utils/network.py +77 -0
  39. src/efficientvit/models/utils/random.py +73 -0
  40. src/efficientvit/sam_model_zoo.py +53 -0
  41. src/ip_adapter/attention_processor.py +424 -0
  42. src/ip_adapter/resampler.py +120 -0
  43. src/ip_adapter/utils.py +5 -0
  44. src/pipelines/instantid_pipeline.py +768 -0
  45. src/pipelines/instantid_single_pieline.py +772 -0
  46. src/pipelines/lora_pipeline.py +681 -0
  47. src/prompt_attention/p2p_attention.py +148 -0
  48. src/prompt_attention/p2p_utils.py +74 -0
  49. src/prompt_attention/seq_aligner.py +66 -0
src/efficientvit/__init__.py ADDED
File without changes
src/efficientvit/apps/__init__.py ADDED
File without changes
src/efficientvit/apps/data_provider/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .augment import *
6
+ from .base import *
7
+ from .random_resolution import *
src/efficientvit/apps/data_provider/augment/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .bbox import *
6
+ from .color_aug import *
src/efficientvit/apps/data_provider/augment/bbox.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+
7
+ __all__ = ["rand_bbox"]
8
+
9
+
10
+ def rand_bbox(
11
+ h: int,
12
+ w: int,
13
+ lam: float,
14
+ rand_func: callable = np.random.uniform,
15
+ ) -> tuple[int, int, int, int]:
16
+ """randomly sample bbox, used in cutmix"""
17
+ cut_rat = np.sqrt(1.0 - lam)
18
+ cut_w = w * cut_rat
19
+ cut_h = h * cut_rat
20
+
21
+ # uniform
22
+ cx = rand_func(0, w)
23
+ cy = rand_func(0, h)
24
+
25
+ bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
26
+ bby1 = int(np.clip(cy - cut_h / 2, 0, h))
27
+ bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
28
+ bby2 = int(np.clip(cy + cut_h / 2, 0, h))
29
+
30
+ return bbx1, bby1, bbx2, bby2
src/efficientvit/apps/data_provider/augment/color_aug.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ from timm.data.auto_augment import rand_augment_transform
9
+
10
+ __all__ = ["ColorAug", "RandAug"]
11
+
12
+
13
+ class ImageAug:
14
+ def aug_image(self, image: Image.Image) -> Image.Image:
15
+ raise NotImplementedError
16
+
17
+ def __call__(
18
+ self, feed_dict: dict or np.ndarray or Image.Image
19
+ ) -> dict or np.ndarray or Image.Image:
20
+ if isinstance(feed_dict, dict):
21
+ output_dict = feed_dict
22
+ image = feed_dict[self.key]
23
+ else:
24
+ output_dict = None
25
+ image = feed_dict
26
+ is_ndarray = isinstance(image, np.ndarray)
27
+ if is_ndarray:
28
+ image = Image.fromarray(image)
29
+
30
+ image = self.aug_image(image)
31
+
32
+ if is_ndarray:
33
+ image = np.array(image)
34
+
35
+ if output_dict is None:
36
+ return image
37
+ else:
38
+ output_dict[self.key] = image
39
+ return output_dict
40
+
41
+
42
+ class ColorAug(transforms.ColorJitter, ImageAug):
43
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
44
+ super().__init__(
45
+ brightness=brightness,
46
+ contrast=contrast,
47
+ saturation=saturation,
48
+ hue=hue,
49
+ )
50
+ self.key = key
51
+
52
+ def aug_image(self, image: Image.Image) -> Image.Image:
53
+ return transforms.ColorJitter.forward(self, image)
54
+
55
+ def forward(
56
+ self, feed_dict: dict or np.ndarray or Image.Image
57
+ ) -> dict or np.ndarray or Image.Image:
58
+ return ImageAug.__call__(self, feed_dict)
59
+
60
+
61
+ class RandAug(ImageAug):
62
+ def __init__(
63
+ self, config: dict[str, any], mean: tuple[float, float, float], key="data"
64
+ ):
65
+ n = config.get("n", 2)
66
+ m = config.get("m", 9)
67
+ mstd = config.get("mstd", 1.0)
68
+ inc = config.get("inc", 1)
69
+ tpct = config.get("tpct", 0.45)
70
+ config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
71
+
72
+ aa_params = dict(
73
+ translate_pct=tpct,
74
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
75
+ interpolation=Image.BICUBIC,
76
+ )
77
+ self.aug_op = rand_augment_transform(config_str, aa_params)
78
+ self.key = key
79
+
80
+ def aug_image(self, image: Image.Image) -> Image.Image:
81
+ return self.aug_op(image)
82
+
83
+ def __repr__(self):
84
+ return self.aug_op.__repr__()
src/efficientvit/apps/data_provider/base.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+ import warnings
7
+
8
+ import torch.utils.data
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from src.efficientvit.apps.data_provider.random_resolution import RRSController
12
+ from src.efficientvit.models.utils import val2tuple
13
+
14
+ __all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
15
+
16
+
17
+ def parse_image_size(size: int or str) -> tuple[int, int]:
18
+ if isinstance(size, str):
19
+ size = [int(val) for val in size.split("-")]
20
+ return size[0], size[1]
21
+ else:
22
+ return val2tuple(size, 2)
23
+
24
+
25
+ def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
26
+ g = torch.Generator()
27
+ g.manual_seed(seed) # set random seed before sampling validation set
28
+ rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
29
+
30
+ dropped_indexes = rand_indexes[:drop_size]
31
+ remaining_indexes = rand_indexes[drop_size:]
32
+
33
+ dropped_dataset = copy.deepcopy(dataset)
34
+ for key in keys:
35
+ setattr(
36
+ dropped_dataset,
37
+ key,
38
+ [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
39
+ )
40
+ setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
41
+ return dataset, dropped_dataset
42
+
43
+
44
+ class DataProvider:
45
+ data_keys = ("samples",)
46
+ mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
47
+ SUB_SEED = 937162211 # random seed for sampling subset
48
+ VALID_SEED = 2147483647 # random seed for the validation set
49
+
50
+ name: str
51
+
52
+ def __init__(
53
+ self,
54
+ train_batch_size: int,
55
+ test_batch_size: int or None,
56
+ valid_size: int or float or None,
57
+ n_worker: int,
58
+ image_size: int or list[int] or str or list[str],
59
+ num_replicas: int or None = None,
60
+ rank: int or None = None,
61
+ train_ratio: float or None = None,
62
+ drop_last: bool = False,
63
+ ):
64
+ warnings.filterwarnings("ignore")
65
+ super().__init__()
66
+
67
+ # batch_size & valid_size
68
+ self.train_batch_size = train_batch_size
69
+ self.test_batch_size = test_batch_size or self.train_batch_size
70
+ self.valid_size = valid_size
71
+
72
+ # image size
73
+ if isinstance(image_size, list):
74
+ self.image_size = [parse_image_size(size) for size in image_size]
75
+ self.image_size.sort() # e.g., 160 -> 224
76
+ RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
77
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
78
+ else:
79
+ self.image_size = parse_image_size(image_size)
80
+ RRSController.IMAGE_SIZE_LIST = [self.image_size]
81
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
82
+
83
+ # distributed configs
84
+ self.num_replicas = num_replicas
85
+ self.rank = rank
86
+
87
+ # build datasets
88
+ train_dataset, val_dataset, test_dataset = self.build_datasets()
89
+
90
+ if train_ratio is not None and train_ratio < 1.0:
91
+ assert 0 < train_ratio < 1
92
+ _, train_dataset = random_drop_data(
93
+ train_dataset,
94
+ int(train_ratio * len(train_dataset)),
95
+ self.SUB_SEED,
96
+ self.data_keys,
97
+ )
98
+
99
+ # build data loader
100
+ self.train = self.build_dataloader(
101
+ train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
102
+ )
103
+ self.valid = self.build_dataloader(
104
+ val_dataset, test_batch_size, n_worker, drop_last=False, train=False
105
+ )
106
+ self.test = self.build_dataloader(
107
+ test_dataset, test_batch_size, n_worker, drop_last=False, train=False
108
+ )
109
+ if self.valid is None:
110
+ self.valid = self.test
111
+ self.sub_train = None
112
+
113
+ @property
114
+ def data_shape(self) -> tuple[int, ...]:
115
+ return 3, self.active_image_size[0], self.active_image_size[1]
116
+
117
+ def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
118
+ raise NotImplementedError
119
+
120
+ def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
121
+ raise NotImplementedError
122
+
123
+ def build_datasets(self) -> tuple[any, any, any]:
124
+ raise NotImplementedError
125
+
126
+ def build_dataloader(
127
+ self,
128
+ dataset: any or None,
129
+ batch_size: int,
130
+ n_worker: int,
131
+ drop_last: bool,
132
+ train: bool,
133
+ ):
134
+ if dataset is None:
135
+ return None
136
+ if isinstance(self.image_size, list) and train:
137
+ from efficientvit.apps.data_provider.random_resolution._data_loader import \
138
+ RRSDataLoader
139
+
140
+ dataloader_class = RRSDataLoader
141
+ else:
142
+ dataloader_class = torch.utils.data.DataLoader
143
+ if self.num_replicas is None:
144
+ return dataloader_class(
145
+ dataset=dataset,
146
+ batch_size=batch_size,
147
+ shuffle=True,
148
+ num_workers=n_worker,
149
+ pin_memory=True,
150
+ drop_last=drop_last,
151
+ )
152
+ else:
153
+ sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
154
+ return dataloader_class(
155
+ dataset=dataset,
156
+ batch_size=batch_size,
157
+ sampler=sampler,
158
+ num_workers=n_worker,
159
+ pin_memory=True,
160
+ drop_last=drop_last,
161
+ )
162
+
163
+ def set_epoch(self, epoch: int) -> None:
164
+ RRSController.set_epoch(epoch, len(self.train))
165
+ if isinstance(self.train.sampler, DistributedSampler):
166
+ self.train.sampler.set_epoch(epoch)
167
+
168
+ def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
169
+ self.active_image_size = val2tuple(new_size, 2)
170
+ new_transform = self.build_valid_transform(self.active_image_size)
171
+ # change the transform of the valid and test set
172
+ self.valid.dataset.transform = self.test.dataset.transform = new_transform
173
+
174
+ def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
175
+ if self.valid_size is not None:
176
+ if 0 < self.valid_size < 1:
177
+ valid_size = int(self.valid_size * len(train_dataset))
178
+ else:
179
+ assert self.valid_size >= 1
180
+ valid_size = int(self.valid_size)
181
+ train_dataset, val_dataset = random_drop_data(
182
+ train_dataset,
183
+ valid_size,
184
+ self.VALID_SEED,
185
+ self.data_keys,
186
+ )
187
+ val_dataset.transform = valid_transform
188
+ else:
189
+ val_dataset = None
190
+ return train_dataset, val_dataset
191
+
192
+ def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
193
+ # used for resetting BN running statistics
194
+ if self.sub_train is None:
195
+ self.sub_train = {}
196
+ if self.active_image_size in self.sub_train:
197
+ return self.sub_train[self.active_image_size]
198
+
199
+ # construct dataset and dataloader
200
+ train_dataset = copy.deepcopy(self.train.dataset)
201
+ if n_samples < len(train_dataset):
202
+ _, train_dataset = random_drop_data(
203
+ train_dataset,
204
+ n_samples,
205
+ self.SUB_SEED,
206
+ self.data_keys,
207
+ )
208
+ RRSController.ACTIVE_SIZE = self.active_image_size
209
+ train_dataset.transform = self.build_train_transform(
210
+ image_size=self.active_image_size
211
+ )
212
+ data_loader = self.build_dataloader(
213
+ train_dataset, batch_size, self.train.num_workers, True, False
214
+ )
215
+
216
+ # pre-fetch data
217
+ self.sub_train[self.active_image_size] = [
218
+ data
219
+ for data in data_loader
220
+ for _ in range(max(1, n_samples // len(train_dataset)))
221
+ ]
222
+
223
+ return self.sub_train[self.active_image_size]
src/efficientvit/apps/data_provider/random_resolution/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Random resolution data loader compatible with multi-processing and distributed training.
2
+
3
+ Replace Pytorch's DataLoader with RRSDataLoader to support random resolution
4
+ at the training time, resolution sampling is controlled by RRSController
5
+ """
6
+
7
+ from .controller import *
src/efficientvit/apps/data_provider/random_resolution/_data_loader.py ADDED
@@ -0,0 +1,1598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""This file is based on torch/utils/data/data_loader.py
2
+
3
+ Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
4
+
5
+ To support these two classes, in `./_utils` we define many utility methods and
6
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
7
+ in `./_utils/worker.py`.
8
+ """
9
+
10
+ import functools
11
+ import itertools
12
+ import logging
13
+ import multiprocessing as python_multiprocessing
14
+ import os
15
+ import queue
16
+ import threading
17
+ import warnings
18
+ from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
19
+ TypeVar, Union)
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch.multiprocessing as multiprocessing
24
+ import torch.utils.data.graph_settings
25
+ from torch._utils import ExceptionWrapper
26
+ from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
27
+ IterDataPipe, MapDataPipe, RandomSampler,
28
+ Sampler, SequentialSampler, _utils)
29
+ from torch.utils.data.datapipes.datapipe import (
30
+ _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
31
+
32
+ from ._data_worker import _worker_loop
33
+
34
+ __all__ = ["RRSDataLoader"]
35
+
36
+ T_co = TypeVar("T_co", covariant=True)
37
+ T = TypeVar("T")
38
+ _worker_init_fn_t = Callable[[int], None]
39
+
40
+ # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
41
+ # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
42
+ # See https://github.com/python/mypy/issues/3737.
43
+ _collate_fn_t = Callable[[List[T]], Any]
44
+
45
+
46
+ # These functions used to be defined in this file. However, it was moved to
47
+ # _utils/collate.py. Although it is rather hard to access this from user land
48
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
49
+ # probably is user code out there using it. This aliasing maintains BC in this
50
+ # aspect.
51
+ default_collate: _collate_fn_t = _utils.collate.default_collate
52
+ default_convert = _utils.collate.default_convert
53
+
54
+ get_worker_info = _utils.worker.get_worker_info
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+
59
+ class _DatasetKind:
60
+ Map = 0
61
+ Iterable = 1
62
+
63
+ @staticmethod
64
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
65
+ if kind == _DatasetKind.Map:
66
+ return _utils.fetch._MapDatasetFetcher(
67
+ dataset, auto_collation, collate_fn, drop_last
68
+ )
69
+ else:
70
+ return _utils.fetch._IterableDatasetFetcher(
71
+ dataset, auto_collation, collate_fn, drop_last
72
+ )
73
+
74
+
75
+ class _InfiniteConstantSampler(Sampler):
76
+ r"""Analogous to ``itertools.repeat(None, None)``.
77
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
78
+
79
+ Args:
80
+ data_source (Dataset): dataset to sample from
81
+ """
82
+
83
+ def __init__(self):
84
+ super().__init__(None)
85
+
86
+ def __iter__(self):
87
+ while True:
88
+ yield None
89
+
90
+
91
+ def _get_distributed_settings():
92
+ if dist.is_available() and dist.is_initialized():
93
+ return dist.get_world_size(), dist.get_rank()
94
+ else:
95
+ return 1, 0
96
+
97
+
98
+ def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
99
+ global_worker_id = worker_id
100
+ info = torch.utils.data.get_worker_info()
101
+ assert info is not None
102
+ total_workers = info.num_workers
103
+ datapipe = info.dataset
104
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
105
+ # To distribute elements across distributed process evenly, we should shard data on distributed
106
+ # processes first then shard on worker processes
107
+ total_workers *= world_size
108
+ global_worker_id = global_worker_id * world_size + rank_id
109
+ # For BC, use default SHARDING_PRIORITIES
110
+ torch.utils.data.graph_settings.apply_sharding(
111
+ datapipe, total_workers, global_worker_id
112
+ )
113
+ if worker_init_fn is not None:
114
+ worker_init_fn(worker_id)
115
+
116
+
117
+ def _share_dist_seed(generator, pg):
118
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
119
+ if isinstance(pg, dist.ProcessGroup):
120
+ dist.broadcast(_shared_seed, src=0, group=pg)
121
+ return _shared_seed.item()
122
+
123
+
124
+ class RRSDataLoader(Generic[T_co]):
125
+ r"""
126
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
127
+ the given dataset.
128
+
129
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
130
+ iterable-style datasets with single- or multi-process loading, customizing
131
+ loading order and optional automatic batching (collation) and memory pinning.
132
+
133
+ See :py:mod:`torch.utils.data` documentation page for more details.
134
+
135
+ Args:
136
+ dataset (Dataset): dataset from which to load the data.
137
+ batch_size (int, optional): how many samples per batch to load
138
+ (default: ``1``).
139
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
140
+ at every epoch (default: ``False``).
141
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
142
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
143
+ implemented. If specified, :attr:`shuffle` must not be specified.
144
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
145
+ returns a batch of indices at a time. Mutually exclusive with
146
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
147
+ and :attr:`drop_last`.
148
+ num_workers (int, optional): how many subprocesses to use for data
149
+ loading. ``0`` means that the data will be loaded in the main process.
150
+ (default: ``0``)
151
+ collate_fn (Callable, optional): merges a list of samples to form a
152
+ mini-batch of Tensor(s). Used when using batched loading from a
153
+ map-style dataset.
154
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
155
+ into device/CUDA pinned memory before returning them. If your data elements
156
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
157
+ see the example below.
158
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
159
+ if the dataset size is not divisible by the batch size. If ``False`` and
160
+ the size of dataset is not divisible by the batch size, then the last batch
161
+ will be smaller. (default: ``False``)
162
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
163
+ from workers. Should always be non-negative. (default: ``0``)
164
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
165
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
166
+ input, after seeding and before data loading. (default: ``None``)
167
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
168
+ by RandomSampler to generate random indexes and multiprocessing to generate
169
+ `base_seed` for workers. (default: ``None``)
170
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
171
+ in advance by each worker. ``2`` means there will be a total of
172
+ 2 * num_workers batches prefetched across all workers. (default value depends
173
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
174
+ Otherwise if value of num_workers>0 default is ``2``).
175
+ persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
176
+ the worker processes after a dataset has been consumed once. This allows to
177
+ maintain the workers `Dataset` instances alive. (default: ``False``)
178
+ pin_memory_device (str, optional): the data loader will copy Tensors
179
+ into device pinned memory before returning them if pin_memory is set to true.
180
+
181
+
182
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
183
+ cannot be an unpicklable object, e.g., a lambda function. See
184
+ :ref:`multiprocessing-best-practices` on more details related
185
+ to multiprocessing in PyTorch.
186
+
187
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
188
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
189
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
190
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
191
+ configurations. This represents the best guess PyTorch can make because PyTorch
192
+ trusts user :attr:`dataset` code in correctly handling multi-process
193
+ loading to avoid duplicate data.
194
+
195
+ However, if sharding results in multiple workers having incomplete last batches,
196
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
197
+ be broken into multiple ones and (2) more than one batch worth of samples can be
198
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
199
+ cases in general.
200
+
201
+ See `Dataset Types`_ for more details on these two types of datasets and how
202
+ :class:`~torch.utils.data.IterableDataset` interacts with
203
+ `Multi-process data loading`_.
204
+
205
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
206
+ :ref:`data-loading-randomness` notes for random seed related questions.
207
+ """
208
+
209
+ dataset: Dataset[T_co]
210
+ batch_size: Optional[int]
211
+ num_workers: int
212
+ pin_memory: bool
213
+ drop_last: bool
214
+ timeout: float
215
+ sampler: Union[Sampler, Iterable]
216
+ pin_memory_device: str
217
+ prefetch_factor: Optional[int]
218
+ _iterator: Optional["_BaseDataLoaderIter"]
219
+ __initialized = False
220
+
221
+ def __init__(
222
+ self,
223
+ dataset: Dataset[T_co],
224
+ batch_size: Optional[int] = 1,
225
+ shuffle: Optional[bool] = None,
226
+ sampler: Union[Sampler, Iterable, None] = None,
227
+ batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
228
+ num_workers: int = 0,
229
+ collate_fn: Optional[_collate_fn_t] = None,
230
+ pin_memory: bool = False,
231
+ drop_last: bool = False,
232
+ timeout: float = 0,
233
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
234
+ multiprocessing_context=None,
235
+ generator=None,
236
+ *,
237
+ prefetch_factor: Optional[int] = None,
238
+ persistent_workers: bool = False,
239
+ pin_memory_device: str = ""
240
+ ):
241
+ torch._C._log_api_usage_once("python.data_loader")
242
+
243
+ if num_workers < 0:
244
+ raise ValueError(
245
+ "num_workers option should be non-negative; "
246
+ "use num_workers=0 to disable multiprocessing."
247
+ )
248
+
249
+ if timeout < 0:
250
+ raise ValueError("timeout option should be non-negative")
251
+
252
+ if num_workers == 0 and prefetch_factor is not None:
253
+ raise ValueError(
254
+ "prefetch_factor option could only be specified in multiprocessing."
255
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
256
+ )
257
+ elif num_workers > 0 and prefetch_factor is None:
258
+ prefetch_factor = 2
259
+ elif prefetch_factor is not None and prefetch_factor < 0:
260
+ raise ValueError("prefetch_factor option should be non-negative")
261
+
262
+ if persistent_workers and num_workers == 0:
263
+ raise ValueError("persistent_workers option needs num_workers > 0")
264
+
265
+ self.dataset = dataset
266
+ self.num_workers = num_workers
267
+ self.prefetch_factor = prefetch_factor
268
+ self.pin_memory = pin_memory
269
+ self.pin_memory_device = pin_memory_device
270
+ self.timeout = timeout
271
+ self.worker_init_fn = worker_init_fn
272
+ self.multiprocessing_context = multiprocessing_context
273
+
274
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
275
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
276
+ if isinstance(self.dataset, IterDataPipe):
277
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
278
+ elif isinstance(self.dataset, MapDataPipe):
279
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
280
+
281
+ # Arg-check dataset related before checking samplers because we want to
282
+ # tell users that iterable-style datasets are incompatible with custom
283
+ # samplers first, so that they don't learn that this combo doesn't work
284
+ # after spending time fixing the custom sampler errors.
285
+ if isinstance(dataset, IterableDataset):
286
+ self._dataset_kind = _DatasetKind.Iterable
287
+ # NOTE [ Custom Samplers and IterableDataset ]
288
+ #
289
+ # `IterableDataset` does not support custom `batch_sampler` or
290
+ # `sampler` since the key is irrelevant (unless we support
291
+ # generator-style dataset one day...).
292
+ #
293
+ # For `sampler`, we always create a dummy sampler. This is an
294
+ # infinite sampler even when the dataset may have an implemented
295
+ # finite `__len__` because in multi-process data loading, naive
296
+ # settings will return duplicated data (which may be desired), and
297
+ # thus using a sampler with length matching that of dataset will
298
+ # cause data lost (you may have duplicates of the first couple
299
+ # batches, but never see anything afterwards). Therefore,
300
+ # `Iterabledataset` always uses an infinite sampler, an instance of
301
+ # `_InfiniteConstantSampler` defined above.
302
+ #
303
+ # A custom `batch_sampler` essentially only controls the batch size.
304
+ # However, it is unclear how useful it would be since an iterable-style
305
+ # dataset can handle that within itself. Moreover, it is pointless
306
+ # in multi-process data loading as the assignment order of batches
307
+ # to workers is an implementation detail so users can not control
308
+ # how to batchify each worker's iterable. Thus, we disable this
309
+ # option. If this turns out to be useful in future, we can re-enable
310
+ # this, and support custom samplers that specify the assignments to
311
+ # specific workers.
312
+ if isinstance(dataset, IterDataPipe):
313
+ if shuffle is not None:
314
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
315
+ dataset, shuffle=shuffle
316
+ )
317
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
318
+ elif shuffle not in {False, None}:
319
+ raise ValueError(
320
+ "DataLoader with IterableDataset: expected unspecified "
321
+ "shuffle option, but got shuffle={}".format(shuffle)
322
+ )
323
+
324
+ if sampler is not None:
325
+ # See NOTE [ Custom Samplers and IterableDataset ]
326
+ raise ValueError(
327
+ "DataLoader with IterableDataset: expected unspecified "
328
+ "sampler option, but got sampler={}".format(sampler)
329
+ )
330
+ elif batch_sampler is not None:
331
+ # See NOTE [ Custom Samplers and IterableDataset ]
332
+ raise ValueError(
333
+ "DataLoader with IterableDataset: expected unspecified "
334
+ "batch_sampler option, but got batch_sampler={}".format(
335
+ batch_sampler
336
+ )
337
+ )
338
+ else:
339
+ shuffle = bool(shuffle)
340
+ self._dataset_kind = _DatasetKind.Map
341
+
342
+ if sampler is not None and shuffle:
343
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
344
+
345
+ if batch_sampler is not None:
346
+ # auto_collation with custom batch_sampler
347
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
348
+ raise ValueError(
349
+ "batch_sampler option is mutually exclusive "
350
+ "with batch_size, shuffle, sampler, and "
351
+ "drop_last"
352
+ )
353
+ batch_size = None
354
+ drop_last = False
355
+ elif batch_size is None:
356
+ # no auto_collation
357
+ if drop_last:
358
+ raise ValueError(
359
+ "batch_size=None option disables auto-batching "
360
+ "and is mutually exclusive with drop_last"
361
+ )
362
+
363
+ if sampler is None: # give default samplers
364
+ if self._dataset_kind == _DatasetKind.Iterable:
365
+ # See NOTE [ Custom Samplers and IterableDataset ]
366
+ sampler = _InfiniteConstantSampler()
367
+ else: # map-style
368
+ if shuffle:
369
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
370
+ else:
371
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
372
+
373
+ if batch_size is not None and batch_sampler is None:
374
+ # auto_collation without custom batch_sampler
375
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
376
+
377
+ self.batch_size = batch_size
378
+ self.drop_last = drop_last
379
+ self.sampler = sampler
380
+ self.batch_sampler = batch_sampler
381
+ self.generator = generator
382
+
383
+ if collate_fn is None:
384
+ if self._auto_collation:
385
+ collate_fn = _utils.collate.default_collate
386
+ else:
387
+ collate_fn = _utils.collate.default_convert
388
+
389
+ self.collate_fn = collate_fn
390
+ self.persistent_workers = persistent_workers
391
+
392
+ self.__initialized = True
393
+ self._IterableDataset_len_called = (
394
+ None # See NOTE [ IterableDataset and __len__ ]
395
+ )
396
+
397
+ self._iterator = None
398
+
399
+ self.check_worker_number_rationality()
400
+
401
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
402
+
403
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
404
+ if self.num_workers == 0:
405
+ return _SingleProcessDataLoaderIter(self)
406
+ else:
407
+ self.check_worker_number_rationality()
408
+ return _MultiProcessingDataLoaderIter(self)
409
+
410
+ @property
411
+ def multiprocessing_context(self):
412
+ return self.__multiprocessing_context
413
+
414
+ @multiprocessing_context.setter
415
+ def multiprocessing_context(self, multiprocessing_context):
416
+ if multiprocessing_context is not None:
417
+ if self.num_workers > 0:
418
+ if isinstance(multiprocessing_context, str):
419
+ valid_start_methods = multiprocessing.get_all_start_methods()
420
+ if multiprocessing_context not in valid_start_methods:
421
+ raise ValueError(
422
+ (
423
+ "multiprocessing_context option "
424
+ "should specify a valid start method in {!r}, but got "
425
+ "multiprocessing_context={!r}"
426
+ ).format(valid_start_methods, multiprocessing_context)
427
+ )
428
+ multiprocessing_context = multiprocessing.get_context(
429
+ multiprocessing_context
430
+ )
431
+
432
+ if not isinstance(
433
+ multiprocessing_context, python_multiprocessing.context.BaseContext
434
+ ):
435
+ raise TypeError(
436
+ (
437
+ "multiprocessing_context option should be a valid context "
438
+ "object or a string specifying the start method, but got "
439
+ "multiprocessing_context={}"
440
+ ).format(multiprocessing_context)
441
+ )
442
+ else:
443
+ raise ValueError(
444
+ (
445
+ "multiprocessing_context can only be used with "
446
+ "multi-process loading (num_workers > 0), but got "
447
+ "num_workers={}"
448
+ ).format(self.num_workers)
449
+ )
450
+
451
+ self.__multiprocessing_context = multiprocessing_context
452
+
453
+ def __setattr__(self, attr, val):
454
+ if self.__initialized and attr in (
455
+ "batch_size",
456
+ "batch_sampler",
457
+ "sampler",
458
+ "drop_last",
459
+ "dataset",
460
+ "persistent_workers",
461
+ ):
462
+ raise ValueError(
463
+ "{} attribute should not be set after {} is "
464
+ "initialized".format(attr, self.__class__.__name__)
465
+ )
466
+
467
+ super().__setattr__(attr, val)
468
+
469
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
470
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
471
+ def __iter__(self) -> "_BaseDataLoaderIter":
472
+ # When using a single worker the returned iterator should be
473
+ # created everytime to avoid reseting its state
474
+ # However, in the case of a multiple workers iterator
475
+ # the iterator is only created once in the lifetime of the
476
+ # DataLoader object so that workers can be reused
477
+ if self.persistent_workers and self.num_workers > 0:
478
+ if self._iterator is None:
479
+ self._iterator = self._get_iterator()
480
+ else:
481
+ self._iterator._reset(self)
482
+ return self._iterator
483
+ else:
484
+ return self._get_iterator()
485
+
486
+ @property
487
+ def _auto_collation(self):
488
+ return self.batch_sampler is not None
489
+
490
+ @property
491
+ def _index_sampler(self):
492
+ # The actual sampler used for generating indices for `_DatasetFetcher`
493
+ # (see _utils/fetch.py) to read data at each time. This would be
494
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
495
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
496
+ # reasons.
497
+ if self._auto_collation:
498
+ return self.batch_sampler
499
+ else:
500
+ return self.sampler
501
+
502
+ def __len__(self) -> int:
503
+ if self._dataset_kind == _DatasetKind.Iterable:
504
+ # NOTE [ IterableDataset and __len__ ]
505
+ #
506
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
507
+ # does multi-processing data loading, since the samples will be duplicated.
508
+ # However, no real use case should be actually using that behavior, so
509
+ # it should count as a user error. We should generally trust user
510
+ # code to do the proper thing (e.g., configure each replica differently
511
+ # in `__iter__`), and give us the correct `__len__` if they choose to
512
+ # implement it (this will still throw if the dataset does not implement
513
+ # a `__len__`).
514
+ #
515
+ # To provide a further warning, we track if `__len__` was called on the
516
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
517
+ # if the iterator ends up yielding more than this number of samples.
518
+
519
+ # Cannot statically verify that dataset is Sized
520
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
521
+ if (
522
+ self.batch_size is not None
523
+ ): # IterableDataset doesn't allow custom sampler or batch_sampler
524
+ from math import ceil
525
+
526
+ if self.drop_last:
527
+ length = length // self.batch_size
528
+ else:
529
+ length = ceil(length / self.batch_size)
530
+ return length
531
+ else:
532
+ return len(self._index_sampler)
533
+
534
+ def check_worker_number_rationality(self):
535
+ # This function check whether the dataloader's worker number is rational based on
536
+ # current system's resource. Current rule is that if the number of workers this
537
+ # Dataloader will create is bigger than the number of logical cpus that is allowed to
538
+ # use, than we will pop up a warning to let user pay attention.
539
+ #
540
+ # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
541
+ # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
542
+ # DataLoader process can use half of them which is 32, then the rational max number of
543
+ # worker that initiated from this process is 32.
544
+ # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
545
+ # So the warning message is triggered to notify the user to lower the worker number if
546
+ # necessary.
547
+ #
548
+ #
549
+ # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
550
+ # available (available in most of Linux system, but not OSX and Windows).
551
+ # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
552
+ # it doesn't repect cpuset.
553
+ # We don't take threading into account since each worker process is single threaded
554
+ # at this time.
555
+ #
556
+ # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
557
+ # other than `torch.set_num_threads` to 1 in the worker process, if the passing
558
+ # in functions use 3rd party modules that rely on those threading flags to determine
559
+ # how many thread to create (eg. numpy, etc), then it is caller's responsibility to
560
+ # set those flags correctly.
561
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
562
+
563
+ suggested_max_worker_msg = (
564
+ (
565
+ (
566
+ "Our suggested max number of worker in current system is {}{}, which is smaller "
567
+ "than what this DataLoader is going to create."
568
+ ).format(
569
+ num_worker_suggest,
570
+ (
571
+ ""
572
+ if cpuset_checked
573
+ else " (`cpuset` is not taken into account)"
574
+ ),
575
+ )
576
+ )
577
+ if num_worker_suggest is not None
578
+ else (
579
+ "DataLoader is not able to compute a suggested max number of worker in current system."
580
+ )
581
+ )
582
+
583
+ warn_msg = (
584
+ "This DataLoader will create {} worker processes in total. {} "
585
+ "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
586
+ "lower the worker number to avoid potential slowness/freeze if necessary."
587
+ ).format(num_worker_created, suggested_max_worker_msg)
588
+ return warn_msg
589
+
590
+ if not self.num_workers or self.num_workers == 0:
591
+ return
592
+
593
+ # try to compute a suggested max number of worker based on system's resource
594
+ max_num_worker_suggest = None
595
+ cpuset_checked = False
596
+ if hasattr(os, "sched_getaffinity"):
597
+ try:
598
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
599
+ cpuset_checked = True
600
+ except Exception:
601
+ pass
602
+ if max_num_worker_suggest is None:
603
+ # os.cpu_count() could return Optional[int]
604
+ # get cpu count first and check None in order to satify mypy check
605
+ cpu_count = os.cpu_count()
606
+ if cpu_count is not None:
607
+ max_num_worker_suggest = cpu_count
608
+
609
+ if max_num_worker_suggest is None:
610
+ warnings.warn(
611
+ _create_warning_msg(
612
+ max_num_worker_suggest, self.num_workers, cpuset_checked
613
+ )
614
+ )
615
+ return
616
+
617
+ if self.num_workers > max_num_worker_suggest:
618
+ warnings.warn(
619
+ _create_warning_msg(
620
+ max_num_worker_suggest, self.num_workers, cpuset_checked
621
+ )
622
+ )
623
+
624
+
625
+ class _BaseDataLoaderIter:
626
+ def __init__(self, loader: RRSDataLoader) -> None:
627
+ self._dataset = loader.dataset
628
+ self._shared_seed = None
629
+ self._pg = None
630
+ if isinstance(self._dataset, IterDataPipe):
631
+ if dist.is_available() and dist.is_initialized():
632
+ self._pg = dist.new_group(backend="gloo")
633
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
634
+ shared_rng = torch.Generator()
635
+ shared_rng.manual_seed(self._shared_seed)
636
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
637
+ self._dataset, shared_rng
638
+ )
639
+ self._dataset_kind = loader._dataset_kind
640
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
641
+ self._auto_collation = loader._auto_collation
642
+ self._drop_last = loader.drop_last
643
+ self._index_sampler = loader._index_sampler
644
+ self._num_workers = loader.num_workers
645
+ ws, rank = _get_distributed_settings()
646
+ self._world_size = ws
647
+ self._rank = rank
648
+ # for other backends, pin_memory_device need to set. if not set
649
+ # default behaviour is CUDA device. if pin_memory_device is selected
650
+ # and pin_memory is not set, the default behaviour false.
651
+ if len(loader.pin_memory_device) == 0:
652
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
653
+ self._pin_memory_device = None
654
+ else:
655
+ if not loader.pin_memory:
656
+ warn_msg = (
657
+ "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
658
+ "please set pin_memory to true, if you need to use the device pin memory"
659
+ )
660
+ warnings.warn(warn_msg)
661
+
662
+ self._pin_memory = loader.pin_memory
663
+ self._pin_memory_device = loader.pin_memory_device
664
+ self._timeout = loader.timeout
665
+ self._collate_fn = loader.collate_fn
666
+ self._sampler_iter = iter(self._index_sampler)
667
+ self._base_seed = (
668
+ torch.empty((), dtype=torch.int64)
669
+ .random_(generator=loader.generator)
670
+ .item()
671
+ )
672
+ self._persistent_workers = loader.persistent_workers
673
+ self._num_yielded = 0
674
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
675
+ self.__class__.__name__
676
+ )
677
+
678
+ def __iter__(self) -> "_BaseDataLoaderIter":
679
+ return self
680
+
681
+ def _reset(self, loader, first_iter=False):
682
+ self._sampler_iter = iter(self._index_sampler)
683
+ self._num_yielded = 0
684
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
685
+ if isinstance(self._dataset, IterDataPipe):
686
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
687
+ shared_rng = torch.Generator()
688
+ shared_rng.manual_seed(self._shared_seed)
689
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
690
+ self._dataset, shared_rng
691
+ )
692
+
693
+ def _next_index(self):
694
+ return next(self._sampler_iter) # may raise StopIteration
695
+
696
+ def _next_data(self):
697
+ raise NotImplementedError
698
+
699
+ def __next__(self) -> Any:
700
+ with torch.autograd.profiler.record_function(self._profile_name):
701
+ if self._sampler_iter is None:
702
+ self._reset() # type: ignore[call-arg]
703
+ data = self._next_data()
704
+ self._num_yielded += 1
705
+ if (
706
+ self._dataset_kind == _DatasetKind.Iterable
707
+ and self._IterableDataset_len_called is not None
708
+ and self._num_yielded > self._IterableDataset_len_called
709
+ ):
710
+ warn_msg = (
711
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
712
+ "samples have been fetched. "
713
+ ).format(
714
+ self._dataset, self._IterableDataset_len_called, self._num_yielded
715
+ )
716
+ if self._num_workers > 0:
717
+ warn_msg += (
718
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
719
+ "IterableDataset replica at each worker. Please see "
720
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
721
+ )
722
+ warnings.warn(warn_msg)
723
+ return data
724
+
725
+ def __len__(self) -> int:
726
+ return len(self._index_sampler)
727
+
728
+ def __getstate__(self):
729
+ # across multiple threads for HOGWILD.
730
+ # Probably the best way to do this is by moving the sample pushing
731
+ # to a separate thread and then just sharing the data queue
732
+ # but signalling the end is tricky without a non-blocking API
733
+ raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
734
+
735
+
736
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
737
+ def __init__(self, loader):
738
+ super().__init__(loader)
739
+ assert self._timeout == 0
740
+ assert self._num_workers == 0
741
+
742
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
743
+ # Taking care of distributed sharding
744
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
745
+ # For BC, use default SHARDING_PRIORITIES
746
+ torch.utils.data.graph_settings.apply_sharding(
747
+ self._dataset, self._world_size, self._rank
748
+ )
749
+
750
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
751
+ self._dataset_kind,
752
+ self._dataset,
753
+ self._auto_collation,
754
+ self._collate_fn,
755
+ self._drop_last,
756
+ )
757
+
758
+ def _next_data(self):
759
+ index = self._next_index() # may raise StopIteration
760
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
761
+ if self._pin_memory:
762
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
763
+ return data
764
+
765
+
766
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
767
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
768
+
769
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
770
+ #
771
+ # Preliminary:
772
+ #
773
+ # Our data model looks like this (queues are indicated with curly brackets):
774
+ #
775
+ # main process ||
776
+ # | ||
777
+ # {index_queue} ||
778
+ # | ||
779
+ # worker processes || DATA
780
+ # | ||
781
+ # {worker_result_queue} || FLOW
782
+ # | ||
783
+ # pin_memory_thread of main process || DIRECTION
784
+ # | ||
785
+ # {data_queue} ||
786
+ # | ||
787
+ # data output \/
788
+ #
789
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
790
+ # `pin_memory=False`.
791
+ #
792
+ #
793
+ # Terminating multiprocessing logic requires very careful design. In
794
+ # particular, we need to make sure that
795
+ #
796
+ # 1. The iterator gracefully exits the workers when its last reference is
797
+ # gone or it is depleted.
798
+ #
799
+ # In this case, the workers should be gracefully exited because the
800
+ # main process may still need to continue to run, and we want cleaning
801
+ # up code in the workers to be executed (e.g., releasing GPU memory).
802
+ # Naturally, we implement the shutdown logic in `__del__` of
803
+ # DataLoaderIterator.
804
+ #
805
+ # We delay the discussion on the logic in this case until later.
806
+ #
807
+ # 2. The iterator exits the workers when the loader process and/or worker
808
+ # processes exits normally or with error.
809
+ #
810
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
811
+ #
812
+ # You may ask, why can't we make the workers non-daemonic, and
813
+ # gracefully exit using the same logic as we have in `__del__` when the
814
+ # iterator gets deleted (see 1 above)?
815
+ #
816
+ # First of all, `__del__` is **not** guaranteed to be called when
817
+ # interpreter exits. Even if it is called, by the time it executes,
818
+ # many Python core library resources may alreay be freed, and even
819
+ # simple things like acquiring an internal lock of a queue may hang.
820
+ # Therefore, in this case, we actually need to prevent `__del__` from
821
+ # being executed, and rely on the automatic termination of daemonic
822
+ # children.
823
+ #
824
+ # Thus, we register an `atexit` hook that sets a global flag
825
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
826
+ # reverse order of registration, we are guaranteed that this flag is
827
+ # set before library resources we use are freed (which, at least in
828
+ # CPython, is done via an `atexit` handler defined in
829
+ # `multiprocessing/util.py`
830
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
831
+ # registered when an object requiring this mechanism is first
832
+ # created, e.g., `mp.Queue`
833
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
834
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
835
+ # )
836
+ #
837
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
838
+ # `None` (freed), and perform no-op if so.
839
+ #
840
+ # However, simply letting library clean-up codes run can also be bad,
841
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
842
+ # include join putting threads for `mp.Queue`, which can be blocking.
843
+ # Hence, the main process putting threads are called with
844
+ # `cancel_join_thread` at creation. See later section
845
+ # [ 3b. A process won't hang when putting into a queue; ]
846
+ # for more details.
847
+ #
848
+ # Here are two example cases where library clean-up codes can run
849
+ # before `__del__` is called:
850
+ #
851
+ # 1. If we hold onto a reference to the iterator, it more often
852
+ # than not tries to do `multiprocessing` library cleaning before
853
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
854
+ # and thus prevents our cleaning-up code to run first.
855
+ #
856
+ # 2. A similar issue araises when a `DataLoader` is used in a subprocess.
857
+ # When a process ends, it shuts the all its daemonic children
858
+ # down with a SIGTERM (instead of joining them without a timeout).
859
+ # Simiarly for threads, but by a different mechanism. This fact,
860
+ # together with a few implementation details of multiprocessing, forces
861
+ # us to make workers daemonic. All of our problems arise when a
862
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
863
+ # code which looks more or less like this:
864
+ #
865
+ # try:
866
+ # your_function_using_a_dataloader()
867
+ # finally:
868
+ # multiprocessing.util._exit_function()
869
+ #
870
+ # The joining/termination mentioned above happens inside
871
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
872
+ # throws, the stack trace stored in the exception will prevent the
873
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
874
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
875
+ # its `__del__`, which starts the shutdown procedure, will not be
876
+ # called. That, in turn, means that workers aren't notified. Attempting
877
+ # to join in `_exit_function` will then result in a hang.
878
+ #
879
+ # For context, `_exit_function` is also registered as an `atexit` call.
880
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
881
+ # The code dates back to 2008 and there is no comment on the original
882
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
883
+ # the finally block and the `atexit` registration) that explains this.
884
+ #
885
+ #
886
+ # Finally, another choice is to just shutdown workers with logic in 1
887
+ # above whenever we see an error in `next`. This isn't ideal because
888
+ # a. It prevents users from using try-catch to resume data loading.
889
+ # b. It doesn't prevent hanging if users have references to the
890
+ # iterator.
891
+ #
892
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
893
+ #
894
+ # As shown above, the workers are set as daemonic children of the main
895
+ # process. However, automatic cleaning-up of such child processes only
896
+ # happens if the parent process exits gracefully (e.g., not via fatal
897
+ # signals like SIGKILL). So we must ensure that each process will exit
898
+ # even the process that should send/receive data to/from it were
899
+ # killed, i.e.,
900
+ #
901
+ # a. A process won't hang when getting from a queue.
902
+ #
903
+ # Even with carefully designed data dependencies (i.e., a `put()`
904
+ # always corresponding to a `get()`), hanging on `get()` can still
905
+ # happen when data in queue is corrupted (e.g., due to
906
+ # `cancel_join_thread` or unexpected exit).
907
+ #
908
+ # For child exit, we set a timeout whenever we try to get data
909
+ # from `data_queue`, and check the workers' status on each timeout
910
+ # and error.
911
+ # See `_DataLoaderiter._get_batch()` and
912
+ # `_DataLoaderiter._try_get_data()` for details.
913
+ #
914
+ # Additionally, for child exit on non-Windows platforms, we also
915
+ # register a SIGCHLD handler (which is supported on Windows) on
916
+ # the main process, which checks if any of the workers fail in the
917
+ # (Python) handler. This is more efficient and faster in detecting
918
+ # worker failures, compared to only using the above mechanism.
919
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
920
+ #
921
+ # For `.get()` calls where the sender(s) is not the workers, we
922
+ # guard them with timeouts, and check the status of the sender
923
+ # when timeout happens:
924
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
925
+ # checks the status of the main process.
926
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
927
+ # check `pin_memory_thread` status periodically until `.get()`
928
+ # returns or see that `pin_memory_thread` died.
929
+ #
930
+ # b. A process won't hang when putting into a queue;
931
+ #
932
+ # We use `mp.Queue` which has a separate background thread to put
933
+ # objects from an unbounded buffer array. The background thread is
934
+ # daemonic and usually automatically joined when the process
935
+ # *exits*.
936
+ #
937
+ # In case that the receiver has ended abruptly while
938
+ # reading from the pipe, the join will hang forever. The usual
939
+ # solution for this in Python is calling `q.cancel_join_thread`,
940
+ # which prevents automatically joining it when finalizing
941
+ # (exiting).
942
+ #
943
+ # Nonetheless, `cancel_join_thread` must only be called when the
944
+ # queue is **not** going to be read from or write into by another
945
+ # process, because it may hold onto a lock or leave corrupted data
946
+ # in the queue, leading other readers/writers to hang.
947
+ #
948
+ # Hence,
949
+ # + For worker processes, we only do so (for their output
950
+ # queues, i.e., `worker_result_queue`) before exiting.
951
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
952
+ # `queue.Queue` that does blocking `put` if the queue is full.
953
+ # So there is no above problem, but as a result, in
954
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
955
+ # that breaks not only upon success, but also when the main
956
+ # process stops reading, i.e., is shutting down.
957
+ # + For loader process, we `cancel_join_thread()` for all
958
+ # `_index_queues` because the whole purpose of workers and
959
+ # `pin_memory_thread` is to serve the loader process. If
960
+ # loader process is already exiting, we don't really care if
961
+ # the queues are corrupted.
962
+ #
963
+ #
964
+ # Now let's get back to 1:
965
+ # how we gracefully exit the workers when the last reference to the
966
+ # iterator is gone.
967
+ #
968
+ # To achieve this, we implement the following logic along with the design
969
+ # choices mentioned above:
970
+ #
971
+ # `workers_done_event`:
972
+ # A `multiprocessing.Event` shared among the main process and all worker
973
+ # processes. This is used to signal the workers that the iterator is
974
+ # shutting down. After it is set, they will not send processed data to
975
+ # queues anymore, and only wait for the final `None` before exiting.
976
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
977
+ # from the input queue, but it allows us to skip wasting resources
978
+ # processing data if we are already shutting down.
979
+ #
980
+ # `pin_memory_thread_done_event`:
981
+ # A `threading.Event` for a similar purpose to that of
982
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
983
+ # that separate events are needed is that `pin_memory_thread` reads from
984
+ # the output queue of the workers. But the workers, upon seeing that
985
+ # `workers_done_event` is set, only wants to see the final `None`, and is
986
+ # not required to flush all data in the output queue (e.g., it may call
987
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
988
+ # happens to exhaust coincidentally, which is out of the control of the
989
+ # main process). Thus, since we will exit `pin_memory_thread` before the
990
+ # workers (see below), two separete events are used.
991
+ #
992
+ # NOTE: In short, the protocol is that the main process will set these
993
+ # `done_event`s and then the corresponding processes/threads a `None`,
994
+ # and that they may exit at any time after receiving the `None`.
995
+ #
996
+ # NOTE: Using `None` as the final signal is valid, since normal data will
997
+ # always be a 2-tuple with the 1st element being the index of the data
998
+ # transferred (different from dataset index/key), and the 2nd being
999
+ # either the dataset key or the data sample (depending on which part
1000
+ # of the data model the queue is at).
1001
+ #
1002
+ # [ worker processes ]
1003
+ # While loader process is alive:
1004
+ # Get from `index_queue`.
1005
+ # If get anything else,
1006
+ # Check `workers_done_event`.
1007
+ # If set, continue to next iteration
1008
+ # i.e., keep getting until see the `None`, then exit.
1009
+ # Otherwise, process data:
1010
+ # If is fetching from an `IterableDataset` and the iterator
1011
+ # is exhausted, send an `_IterableDatasetStopIteration`
1012
+ # object to signal iteration end. The main process, upon
1013
+ # receiving such an object, will send `None` to this
1014
+ # worker and not use the corresponding `index_queue`
1015
+ # anymore.
1016
+ # If timed out,
1017
+ # No matter `workers_done_event` is set (still need to see `None`)
1018
+ # or not, must continue to next iteration.
1019
+ # (outside loop)
1020
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
1021
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
1022
+ # main process won't read from it;
1023
+ # other workers will also call
1024
+ # `cancel_join_thread`.)
1025
+ #
1026
+ # [ pin_memory_thread ]
1027
+ # # No need to check main thread. If this thread is alive, the main loader
1028
+ # # thread must be alive, because this thread is set as daemonic.
1029
+ # While `pin_memory_thread_done_event` is not set:
1030
+ # Get from `index_queue`.
1031
+ # If timed out, continue to get in the next iteration.
1032
+ # Otherwise, process data.
1033
+ # While `pin_memory_thread_done_event` is not set:
1034
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
1035
+ # If timed out, continue to put in the next iteration.
1036
+ # Otherwise, break, i.e., continuing to the out loop.
1037
+ #
1038
+ # NOTE: we don't check the status of the main thread because
1039
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
1040
+ # ends.
1041
+ # 2. in other cases, either the cleaning-up in __del__ or the
1042
+ # automatic exit of daemonic thread will take care of it.
1043
+ # This won't busy-wait either because `.get(timeout)` does not
1044
+ # busy-wait.
1045
+ #
1046
+ # [ main process ]
1047
+ # In the DataLoader Iter's `__del__`
1048
+ # b. Exit `pin_memory_thread`
1049
+ # i. Set `pin_memory_thread_done_event`.
1050
+ # ii Put `None` in `worker_result_queue`.
1051
+ # iii. Join the `pin_memory_thread`.
1052
+ # iv. `worker_result_queue.cancel_join_thread()`.
1053
+ #
1054
+ # c. Exit the workers.
1055
+ # i. Set `workers_done_event`.
1056
+ # ii. Put `None` in each worker's `index_queue`.
1057
+ # iii. Join the workers.
1058
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1059
+ #
1060
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
1061
+ # data in `worker_result_queue`, which `pin_memory_thread`
1062
+ # reads from, in which case the `pin_memory_thread` can only
1063
+ # happen at timeing out, which is slow. Nonetheless, same thing
1064
+ # happens if a worker is killed by signal at unfortunate times,
1065
+ # but in other cases, we are better off having a non-corrupted
1066
+ # `worker_result_queue` for `pin_memory_thread`.
1067
+ #
1068
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1069
+ # can be omitted
1070
+ #
1071
+ # NB: `done_event`s isn't strictly needed. E.g., we can just check for
1072
+ # `None` from `index_queue`, but it allows us to skip wasting resources
1073
+ # processing indices already in `index_queue` if we are already shutting
1074
+ # down.
1075
+
1076
+ def __init__(self, loader):
1077
+ super().__init__(loader)
1078
+
1079
+ self._prefetch_factor = loader.prefetch_factor
1080
+
1081
+ assert self._num_workers > 0
1082
+ assert self._prefetch_factor > 0
1083
+
1084
+ if loader.multiprocessing_context is None:
1085
+ multiprocessing_context = multiprocessing
1086
+ else:
1087
+ multiprocessing_context = loader.multiprocessing_context
1088
+
1089
+ self._worker_init_fn = loader.worker_init_fn
1090
+
1091
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1092
+ # Additional worker init function will take care of sharding in MP and Distributed
1093
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1094
+ self._worker_init_fn = functools.partial(
1095
+ _sharding_worker_init_fn,
1096
+ self._worker_init_fn,
1097
+ self._world_size,
1098
+ self._rank,
1099
+ )
1100
+
1101
+ # No certainty which module multiprocessing_context is
1102
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1103
+ self._worker_pids_set = False
1104
+ self._shutdown = False
1105
+ self._workers_done_event = multiprocessing_context.Event()
1106
+
1107
+ self._index_queues = []
1108
+ self._workers = []
1109
+ for i in range(self._num_workers):
1110
+ # No certainty which module multiprocessing_context is
1111
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1112
+ # Need to `cancel_join_thread` here!
1113
+ # See sections (2) and (3b) above.
1114
+ index_queue.cancel_join_thread()
1115
+ w = multiprocessing_context.Process(
1116
+ target=_worker_loop,
1117
+ args=(
1118
+ self._dataset_kind,
1119
+ self._dataset,
1120
+ index_queue,
1121
+ self._worker_result_queue,
1122
+ self._workers_done_event,
1123
+ self._auto_collation,
1124
+ self._collate_fn,
1125
+ self._drop_last,
1126
+ self._base_seed,
1127
+ self._worker_init_fn,
1128
+ i,
1129
+ self._num_workers,
1130
+ self._persistent_workers,
1131
+ self._shared_seed,
1132
+ ),
1133
+ )
1134
+ w.daemon = True
1135
+ # NB: Process.start() actually take some time as it needs to
1136
+ # start a process and pass the arguments over via a pipe.
1137
+ # Therefore, we only add a worker to self._workers list after
1138
+ # it started, so that we do not call .join() if program dies
1139
+ # before it starts, and __del__ tries to join but will get:
1140
+ # AssertionError: can only join a started process.
1141
+ w.start()
1142
+ self._index_queues.append(index_queue)
1143
+ self._workers.append(w)
1144
+
1145
+ if self._pin_memory:
1146
+ self._pin_memory_thread_done_event = threading.Event()
1147
+
1148
+ # Queue is not type-annotated
1149
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
1150
+ if self._pin_memory_device == "xpu":
1151
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1152
+ else:
1153
+ current_device = torch.cuda.current_device() # choose cuda for default
1154
+ pin_memory_thread = threading.Thread(
1155
+ target=_utils.pin_memory._pin_memory_loop,
1156
+ args=(
1157
+ self._worker_result_queue,
1158
+ self._data_queue,
1159
+ current_device,
1160
+ self._pin_memory_thread_done_event,
1161
+ self._pin_memory_device,
1162
+ ),
1163
+ )
1164
+ pin_memory_thread.daemon = True
1165
+ pin_memory_thread.start()
1166
+ # Similar to workers (see comment above), we only register
1167
+ # pin_memory_thread once it is started.
1168
+ self._pin_memory_thread = pin_memory_thread
1169
+ else:
1170
+ self._data_queue = self._worker_result_queue
1171
+
1172
+ # In some rare cases, persistent workers (daemonic processes)
1173
+ # would be terminated before `__del__` of iterator is invoked
1174
+ # when main process exits
1175
+ # It would cause failure when pin_memory_thread tries to read
1176
+ # corrupted data from worker_result_queue
1177
+ # atexit is used to shutdown thread and child processes in the
1178
+ # right sequence before main process exits
1179
+ if self._persistent_workers and self._pin_memory:
1180
+ import atexit
1181
+
1182
+ for w in self._workers:
1183
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1184
+
1185
+ # .pid can be None only before process is spawned (not the case, so ignore)
1186
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1187
+ _utils.signal_handling._set_SIGCHLD_handler()
1188
+ self._worker_pids_set = True
1189
+ self._reset(loader, first_iter=True)
1190
+
1191
+ def _reset(self, loader, first_iter=False):
1192
+ super()._reset(loader, first_iter)
1193
+ self._send_idx = 0 # idx of the next task to be sent to workers
1194
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1195
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1196
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1197
+ # \ (worker_id, data) if data is already fetched (out-of-order)
1198
+ self._task_info = {}
1199
+ self._tasks_outstanding = (
1200
+ 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1201
+ )
1202
+ # A list of booleans representing whether each worker still has work to
1203
+ # do, i.e., not having exhausted its iterable dataset object. It always
1204
+ # contains all `True`s if not using an iterable-style dataset
1205
+ # (i.e., if kind != Iterable).
1206
+ # Not that this indicates that a worker still has work to do *for this epoch*.
1207
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
1208
+ # the worker will be reset to available in the next epoch.
1209
+ self._workers_status = [True for i in range(self._num_workers)]
1210
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
1211
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1212
+ # We resume the prefetching in case it was enabled
1213
+ if not first_iter:
1214
+ for idx in range(self._num_workers):
1215
+ self._index_queues[idx].put(
1216
+ _utils.worker._ResumeIteration(self._shared_seed)
1217
+ )
1218
+ resume_iteration_cnt = self._num_workers
1219
+ while resume_iteration_cnt > 0:
1220
+ return_idx, return_data = self._get_data()
1221
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
1222
+ assert return_data is None
1223
+ resume_iteration_cnt -= 1
1224
+ # prime the prefetch loop
1225
+ for _ in range(self._prefetch_factor * self._num_workers):
1226
+ self._try_put_index()
1227
+
1228
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1229
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
1230
+ # This can also be used as inner loop of fetching without timeout, with
1231
+ # the sender status as the loop condition.
1232
+ #
1233
+ # This raises a `RuntimeError` if any worker died expectedly. This error
1234
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1235
+ # (only for non-Windows platforms), or the manual check below on errors
1236
+ # and timeouts.
1237
+ #
1238
+ # Returns a 2-tuple:
1239
+ # (bool: whether successfully get data, any: data if successful else None)
1240
+ try:
1241
+ data = self._data_queue.get(timeout=timeout)
1242
+ return (True, data)
1243
+ except Exception as e:
1244
+ # At timeout and error, we manually check whether any worker has
1245
+ # failed. Note that this is the only mechanism for Windows to detect
1246
+ # worker failures.
1247
+ failed_workers = []
1248
+ for worker_id, w in enumerate(self._workers):
1249
+ if self._workers_status[worker_id] and not w.is_alive():
1250
+ failed_workers.append(w)
1251
+ self._mark_worker_as_unavailable(worker_id)
1252
+ if len(failed_workers) > 0:
1253
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
1254
+ raise RuntimeError(
1255
+ "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
1256
+ ) from e
1257
+ if isinstance(e, queue.Empty):
1258
+ return (False, None)
1259
+ import errno
1260
+ import tempfile
1261
+
1262
+ try:
1263
+ # Raise an exception if we are this close to the FDs limit.
1264
+ # Apparently, trying to open only one file is not a sufficient
1265
+ # test.
1266
+ # See NOTE [ DataLoader on Linux and open files limit ]
1267
+ fds_limit_margin = 10
1268
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1269
+ except OSError as e:
1270
+ if e.errno == errno.EMFILE:
1271
+ raise RuntimeError(
1272
+ "Too many open files. Communication with the"
1273
+ " workers is no longer possible. Please increase the"
1274
+ " limit using `ulimit -n` in the shell or change the"
1275
+ " sharing strategy by calling"
1276
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1277
+ " at the beginning of your code"
1278
+ ) from None
1279
+ raise
1280
+
1281
+ # NOTE [ DataLoader on Linux and open files limit ]
1282
+ #
1283
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
1284
+ # the root process and the workers through SHM files. We remove those files from
1285
+ # the filesystem as soon as they are created and keep them alive by
1286
+ # passing around their file descriptors through AF_UNIX sockets. (See
1287
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
1288
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
1289
+ #
1290
+ # This sometimes leads us to exceeding the open files limit. When that happens,
1291
+ # and the offending file descriptor is coming over a socket, the `socket` Python
1292
+ # package silently strips the file descriptor from the message, setting only the
1293
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1294
+ # it _indicates that some control data were discarded due to lack of space in
1295
+ # the buffer for ancillary data_). This might reflect the C implementation of
1296
+ # AF_UNIX sockets.
1297
+ #
1298
+ # This behaviour can be reproduced with the script and instructions at the
1299
+ # bottom of this note.
1300
+ #
1301
+ # When that happens, the standard Python `multiprocessing` (and not
1302
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1303
+ #
1304
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
1305
+ # Too many open files`, both in the script below and in DataLoader. However,
1306
+ # this is rare and seems to be nondeterministic.
1307
+ #
1308
+ #
1309
+ # #!/usr/bin/env python3
1310
+ # import sys
1311
+ # import socket
1312
+ # import os
1313
+ # import array
1314
+ # import shutil
1315
+ # import socket
1316
+ #
1317
+ #
1318
+ # if len(sys.argv) != 4:
1319
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1320
+ # sys.exit(1)
1321
+ #
1322
+ # if __name__ == '__main__':
1323
+ # dirname = sys.argv[1]
1324
+ # sock_path = dirname + "/sock"
1325
+ # iterations = int(sys.argv[2])
1326
+ # def dummy_path(i):
1327
+ # return dirname + "/" + str(i) + ".dummy"
1328
+ #
1329
+ #
1330
+ # if sys.argv[3] == 'send':
1331
+ # while not os.path.exists(sock_path):
1332
+ # pass
1333
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1334
+ # client.connect(sock_path)
1335
+ # for i in range(iterations):
1336
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1337
+ # ancdata = array.array('i', [fd])
1338
+ # msg = bytes([i % 256])
1339
+ # print("Sending fd ", fd, " (iteration #", i, ")")
1340
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1341
+ #
1342
+ #
1343
+ # else:
1344
+ # assert sys.argv[3] == 'recv'
1345
+ #
1346
+ # if os.path.exists(dirname):
1347
+ # raise Exception("Directory exists")
1348
+ #
1349
+ # os.mkdir(dirname)
1350
+ #
1351
+ # print("Opening socket...")
1352
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1353
+ # server.bind(sock_path)
1354
+ #
1355
+ # print("Listening...")
1356
+ # for i in range(iterations):
1357
+ # a = array.array('i')
1358
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1359
+ # assert(len(ancdata) == 1)
1360
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1361
+ # a.frombytes(cmsg_data)
1362
+ # print("Received fd ", a[0], " (iteration #", i, ")")
1363
+ #
1364
+ # shutil.rmtree(dirname)
1365
+ #
1366
+ # Steps to reproduce:
1367
+ #
1368
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
1369
+ # (shell1) ulimit -n 1020
1370
+ # (shell2) ulimit -n 1022
1371
+ #
1372
+ # 2. Run the script above with the `recv` option in the first shell
1373
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
1374
+ #
1375
+ # 3. Run the script with the `send` option in the second shell:
1376
+ # (shell2) ./test_socket.py sock_tmp 1017 send
1377
+
1378
+ def _get_data(self):
1379
+ # Fetches data from `self._data_queue`.
1380
+ #
1381
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1382
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1383
+ # in a loop. This is the only mechanism to detect worker failures for
1384
+ # Windows. For other platforms, a SIGCHLD handler is also used for
1385
+ # worker failure detection.
1386
+ #
1387
+ # If `pin_memory=True`, we also need check if `pin_memory_thread` had
1388
+ # died at timeouts.
1389
+ if self._timeout > 0:
1390
+ success, data = self._try_get_data(self._timeout)
1391
+ if success:
1392
+ return data
1393
+ else:
1394
+ raise RuntimeError(
1395
+ "DataLoader timed out after {} seconds".format(self._timeout)
1396
+ )
1397
+ elif self._pin_memory:
1398
+ while self._pin_memory_thread.is_alive():
1399
+ success, data = self._try_get_data()
1400
+ if success:
1401
+ return data
1402
+ else:
1403
+ # while condition is false, i.e., pin_memory_thread died.
1404
+ raise RuntimeError("Pin memory thread exited unexpectedly")
1405
+ # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
1406
+ # need to call `.task_done()` because we don't use `.join()`.
1407
+ else:
1408
+ while True:
1409
+ success, data = self._try_get_data()
1410
+ if success:
1411
+ return data
1412
+
1413
+ def _next_data(self):
1414
+ while True:
1415
+ # If the worker responsible for `self._rcvd_idx` has already ended
1416
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1417
+ # we try to advance `self._rcvd_idx` to find the next valid index.
1418
+ #
1419
+ # This part needs to run in the loop because both the `self._get_data()`
1420
+ # call and `_IterableDatasetStopIteration` check below can mark
1421
+ # extra worker(s) as dead.
1422
+ while self._rcvd_idx < self._send_idx:
1423
+ info = self._task_info[self._rcvd_idx]
1424
+ worker_id = info[0]
1425
+ if (
1426
+ len(info) == 2 or self._workers_status[worker_id]
1427
+ ): # has data or is still active
1428
+ break
1429
+ del self._task_info[self._rcvd_idx]
1430
+ self._rcvd_idx += 1
1431
+ else:
1432
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
1433
+ if not self._persistent_workers:
1434
+ self._shutdown_workers()
1435
+ raise StopIteration
1436
+
1437
+ # Now `self._rcvd_idx` is the batch index we want to fetch
1438
+
1439
+ # Check if the next sample has already been generated
1440
+ if len(self._task_info[self._rcvd_idx]) == 2:
1441
+ data = self._task_info.pop(self._rcvd_idx)[1]
1442
+ return self._process_data(data)
1443
+
1444
+ assert not self._shutdown and self._tasks_outstanding > 0
1445
+ idx, data = self._get_data()
1446
+ self._tasks_outstanding -= 1
1447
+ if self._dataset_kind == _DatasetKind.Iterable:
1448
+ # Check for _IterableDatasetStopIteration
1449
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1450
+ if self._persistent_workers:
1451
+ self._workers_status[data.worker_id] = False
1452
+ else:
1453
+ self._mark_worker_as_unavailable(data.worker_id)
1454
+ self._try_put_index()
1455
+ continue
1456
+
1457
+ if idx != self._rcvd_idx:
1458
+ # store out-of-order samples
1459
+ self._task_info[idx] += (data,)
1460
+ else:
1461
+ del self._task_info[idx]
1462
+ return self._process_data(data)
1463
+
1464
+ def _try_put_index(self):
1465
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1466
+
1467
+ try:
1468
+ index = self._next_index()
1469
+ except StopIteration:
1470
+ return
1471
+ for _ in range(self._num_workers): # find the next active worker, if any
1472
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
1473
+ if self._workers_status[worker_queue_idx]:
1474
+ break
1475
+ else:
1476
+ # not found (i.e., didn't break)
1477
+ return
1478
+
1479
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
1480
+ self._task_info[self._send_idx] = (worker_queue_idx,)
1481
+ self._tasks_outstanding += 1
1482
+ self._send_idx += 1
1483
+
1484
+ def _process_data(self, data):
1485
+ self._rcvd_idx += 1
1486
+ self._try_put_index()
1487
+ if isinstance(data, ExceptionWrapper):
1488
+ data.reraise()
1489
+ return data
1490
+
1491
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1492
+ # Mark a worker as having finished its work e.g., due to
1493
+ # exhausting an `IterableDataset`. This should be used only when this
1494
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
1495
+
1496
+ assert self._workers_status[worker_id] or (
1497
+ self._persistent_workers and shutdown
1498
+ )
1499
+
1500
+ # Signal termination to that specific worker.
1501
+ q = self._index_queues[worker_id]
1502
+ # Indicate that no more data will be put on this queue by the current
1503
+ # process.
1504
+ q.put(None)
1505
+
1506
+ # Note that we don't actually join the worker here, nor do we remove the
1507
+ # worker's pid from C side struct because (1) joining may be slow, and
1508
+ # (2) since we don't join, the worker may still raise error, and we
1509
+ # prefer capturing those, rather than ignoring them, even though they
1510
+ # are raised after the worker has finished its job.
1511
+ # Joinning is deferred to `_shutdown_workers`, which it is called when
1512
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1513
+ # when this iterator is garbage collected.
1514
+
1515
+ self._workers_status[worker_id] = False
1516
+
1517
+ assert self._workers_done_event.is_set() == shutdown
1518
+
1519
+ def _shutdown_workers(self):
1520
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1521
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1522
+ # the logic of this function.
1523
+ if (
1524
+ _utils is None
1525
+ or _utils.python_exit_status is True
1526
+ or _utils.python_exit_status is None
1527
+ ):
1528
+ # See (2) of the note. If Python is shutting down, do no-op.
1529
+ return
1530
+ # Normal exit when last reference is gone / iterator is depleted.
1531
+ # See (1) and the second half of the note.
1532
+ if not self._shutdown:
1533
+ self._shutdown = True
1534
+ try:
1535
+ # Normal exit when last reference is gone / iterator is depleted.
1536
+ # See (1) and the second half of the note.
1537
+
1538
+ # Exit `pin_memory_thread` first because exiting workers may leave
1539
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1540
+ # reads from.
1541
+ if hasattr(self, "_pin_memory_thread"):
1542
+ # Use hasattr in case error happens before we set the attribute.
1543
+ self._pin_memory_thread_done_event.set()
1544
+ # Send something to pin_memory_thread in case it is waiting
1545
+ # so that it can wake up and check `pin_memory_thread_done_event`
1546
+ self._worker_result_queue.put((None, None))
1547
+ self._pin_memory_thread.join()
1548
+ self._worker_result_queue.cancel_join_thread()
1549
+ self._worker_result_queue.close()
1550
+
1551
+ # Exit workers now.
1552
+ self._workers_done_event.set()
1553
+ for worker_id in range(len(self._workers)):
1554
+ # Get number of workers from `len(self._workers)` instead of
1555
+ # `self._num_workers` in case we error before starting all
1556
+ # workers.
1557
+ # If we are using workers_status with persistent_workers
1558
+ # we have to shut it down because the worker is paused
1559
+ if self._persistent_workers or self._workers_status[worker_id]:
1560
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
1561
+ for w in self._workers:
1562
+ # We should be able to join here, but in case anything went
1563
+ # wrong, we set a timeout and if the workers fail to join,
1564
+ # they are killed in the `finally` block.
1565
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1566
+ for q in self._index_queues:
1567
+ q.cancel_join_thread()
1568
+ q.close()
1569
+ finally:
1570
+ # Even though all this function does is putting into queues that
1571
+ # we have called `cancel_join_thread` on, weird things can
1572
+ # happen when a worker is killed by a signal, e.g., hanging in
1573
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1574
+ # and remove pids from the C side data structure only at the
1575
+ # end.
1576
+ #
1577
+ if self._worker_pids_set:
1578
+ _utils.signal_handling._remove_worker_pids(id(self))
1579
+ self._worker_pids_set = False
1580
+ for w in self._workers:
1581
+ if w.is_alive():
1582
+ # Existing mechanisms try to make the workers exit
1583
+ # peacefully, but in case that we unfortunately reach
1584
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1585
+ # we kill the worker.
1586
+ w.terminate()
1587
+
1588
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1589
+ @staticmethod
1590
+ def _clean_up_worker(w):
1591
+ try:
1592
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1593
+ finally:
1594
+ if w.is_alive():
1595
+ w.terminate()
1596
+
1597
+ def __del__(self):
1598
+ self._shutdown_workers()
src/efficientvit/apps/data_provider/random_resolution/_data_worker.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""""This file is based on torch/utils/data/_utils/worker.py
2
+
3
+ Contains definitions of the methods used by the _BaseDataLoaderIter workers.
4
+ These **needs** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import os
9
+ import queue
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING, Optional, Union
13
+
14
+ import torch
15
+ from torch._utils import ExceptionWrapper
16
+ from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
17
+ MP_STATUS_CHECK_INTERVAL, signal_handling)
18
+
19
+ if TYPE_CHECKING:
20
+ from torch.utils.data import Dataset
21
+
22
+ from .controller import RRSController
23
+
24
+ if IS_WINDOWS:
25
+ import ctypes
26
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
27
+
28
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
29
+ # is gone, and the only way to check it through OS is to let the worker have a process handle
30
+ # of the manager and ask if the process status has changed.
31
+ class ManagerWatchdog:
32
+ def __init__(self):
33
+ self.manager_pid = os.getppid()
34
+
35
+ # mypy cannot detect this code is windows only
36
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
37
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
38
+ self.kernel32.OpenProcess.restype = HANDLE
39
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
40
+ self.kernel32.WaitForSingleObject.restype = DWORD
41
+
42
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
43
+ SYNCHRONIZE = 0x00100000
44
+ self.manager_handle = self.kernel32.OpenProcess(
45
+ SYNCHRONIZE, 0, self.manager_pid
46
+ )
47
+
48
+ if not self.manager_handle:
49
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
50
+
51
+ self.manager_dead = False
52
+
53
+ def is_alive(self):
54
+ if not self.manager_dead:
55
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
56
+ self.manager_dead = (
57
+ self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
58
+ )
59
+ return not self.manager_dead
60
+
61
+ else:
62
+
63
+ class ManagerWatchdog: # type: ignore[no-redef]
64
+ def __init__(self):
65
+ self.manager_pid = os.getppid()
66
+ self.manager_dead = False
67
+
68
+ def is_alive(self):
69
+ if not self.manager_dead:
70
+ self.manager_dead = os.getppid() != self.manager_pid
71
+ return not self.manager_dead
72
+
73
+
74
+ _worker_info = None
75
+
76
+
77
+ class WorkerInfo:
78
+ id: int
79
+ num_workers: int
80
+ seed: int
81
+ dataset: "Dataset"
82
+ __initialized = False
83
+
84
+ def __init__(self, **kwargs):
85
+ for k, v in kwargs.items():
86
+ setattr(self, k, v)
87
+ self.__keys = tuple(kwargs.keys())
88
+ self.__initialized = True
89
+
90
+ def __setattr__(self, key, val):
91
+ if self.__initialized:
92
+ raise RuntimeError(
93
+ "Cannot assign attributes to {} objects".format(self.__class__.__name__)
94
+ )
95
+ return super().__setattr__(key, val)
96
+
97
+ def __repr__(self):
98
+ items = []
99
+ for k in self.__keys:
100
+ items.append("{}={}".format(k, getattr(self, k)))
101
+ return "{}({})".format(self.__class__.__name__, ", ".join(items))
102
+
103
+
104
+ def get_worker_info() -> Optional[WorkerInfo]:
105
+ r"""Returns the information about the current
106
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
107
+
108
+ When called in a worker, this returns an object guaranteed to have the
109
+ following attributes:
110
+
111
+ * :attr:`id`: the current worker id.
112
+ * :attr:`num_workers`: the total number of workers.
113
+ * :attr:`seed`: the random seed set for the current worker. This value is
114
+ determined by main process RNG and the worker id. See
115
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
116
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
117
+ that this will be a different object in a different process than the one
118
+ in the main process.
119
+
120
+ When called in the main process, this returns ``None``.
121
+
122
+ .. note::
123
+ When used in a :attr:`worker_init_fn` passed over to
124
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
125
+ set up each worker process differently, for instance, using ``worker_id``
126
+ to configure the ``dataset`` object to only read a specific fraction of a
127
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
128
+ code.
129
+ """
130
+ return _worker_info
131
+
132
+
133
+ r"""Dummy class used to signal the end of an IterableDataset"""
134
+
135
+
136
+ @dataclass(frozen=True)
137
+ class _IterableDatasetStopIteration:
138
+ worker_id: int
139
+
140
+
141
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
142
+
143
+
144
+ @dataclass(frozen=True)
145
+ class _ResumeIteration:
146
+ seed: Optional[int] = None
147
+
148
+
149
+ # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
150
+ # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
151
+ # It's MIT licensed, here is the copyright:
152
+
153
+ # Copyright (c) 2015 Melissa E. O'Neill
154
+ # Copyright (c) 2019 NumPy Developers
155
+ #
156
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
157
+ # of this software and associated documentation files (the "Software"), to deal
158
+ # in the Software without restriction, including without limitation the rights
159
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
160
+ # copies of the Software, and to permit persons to whom the Software is
161
+ # furnished to do so, subject to the following conditions:
162
+ #
163
+ # The above copyright notice and this permission notice shall be included in
164
+ # all copies or substantial portions of the Software.
165
+ #
166
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
167
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
168
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
169
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
170
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
171
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
172
+ # SOFTWARE.
173
+
174
+
175
+ # This function generates an array of int32 as the seed for
176
+ # `numpy.random`, in order to prevent state collision due to same
177
+ # seed and algorithm for `numpy.random` and `random` modules.
178
+ def _generate_state(base_seed, worker_id):
179
+ INIT_A = 0x43B0D7E5
180
+ MULT_A = 0x931E8875
181
+ INIT_B = 0x8B51F9DD
182
+ MULT_B = 0x58F38DED
183
+ MIX_MULT_L = 0xCA01F9DD
184
+ MIX_MULT_R = 0x4973F715
185
+ XSHIFT = 4 * 8 // 2
186
+ MASK32 = 0xFFFFFFFF
187
+
188
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
189
+ pool = [0] * 4
190
+
191
+ hash_const_A = INIT_A
192
+
193
+ def hash(value):
194
+ nonlocal hash_const_A
195
+ value = (value ^ hash_const_A) & MASK32
196
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
197
+ value = (value * hash_const_A) & MASK32
198
+ value = (value ^ (value >> XSHIFT)) & MASK32
199
+ return value
200
+
201
+ def mix(x, y):
202
+ result_x = (MIX_MULT_L * x) & MASK32
203
+ result_y = (MIX_MULT_R * y) & MASK32
204
+ result = (result_x - result_y) & MASK32
205
+ result = (result ^ (result >> XSHIFT)) & MASK32
206
+ return result
207
+
208
+ # Add in the entropy to the pool.
209
+ for i in range(len(pool)):
210
+ pool[i] = hash(entropy[i])
211
+
212
+ # Mix all bits together so late bits can affect earlier bits.
213
+ for i_src in range(len(pool)):
214
+ for i_dst in range(len(pool)):
215
+ if i_src != i_dst:
216
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
217
+
218
+ hash_const_B = INIT_B
219
+ state = []
220
+ for i_dst in range(4):
221
+ data_val = pool[i_dst]
222
+ data_val = (data_val ^ hash_const_B) & MASK32
223
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
224
+ data_val = (data_val * hash_const_B) & MASK32
225
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
226
+ state.append(data_val)
227
+ return state
228
+
229
+
230
+ def _worker_loop(
231
+ dataset_kind,
232
+ dataset,
233
+ index_queue,
234
+ data_queue,
235
+ done_event,
236
+ auto_collation,
237
+ collate_fn,
238
+ drop_last,
239
+ base_seed,
240
+ init_fn,
241
+ worker_id,
242
+ num_workers,
243
+ persistent_workers,
244
+ shared_seed,
245
+ ):
246
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
247
+ # logic of this function.
248
+
249
+ try:
250
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
251
+ # module's handlers are executed after Python returns from C low-level
252
+ # handlers, likely when the same fatal signal had already happened
253
+ # again.
254
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
255
+ signal_handling._set_worker_signal_handlers()
256
+
257
+ torch.set_num_threads(1)
258
+ seed = base_seed + worker_id
259
+ random.seed(seed)
260
+ torch.manual_seed(seed)
261
+ if HAS_NUMPY:
262
+ np_seed = _generate_state(base_seed, worker_id)
263
+ import numpy as np
264
+
265
+ np.random.seed(np_seed)
266
+
267
+ from torch.utils.data import IterDataPipe
268
+ from torch.utils.data.graph_settings import apply_random_seed
269
+
270
+ shared_rng = torch.Generator()
271
+ if isinstance(dataset, IterDataPipe):
272
+ assert shared_seed is not None
273
+ shared_rng.manual_seed(shared_seed)
274
+ dataset = apply_random_seed(dataset, shared_rng)
275
+
276
+ global _worker_info
277
+ _worker_info = WorkerInfo(
278
+ id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
279
+ )
280
+
281
+ from torch.utils.data import _DatasetKind
282
+
283
+ init_exception = None
284
+
285
+ try:
286
+ if init_fn is not None:
287
+ init_fn(worker_id)
288
+
289
+ fetcher = _DatasetKind.create_fetcher(
290
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
291
+ )
292
+ except Exception:
293
+ init_exception = ExceptionWrapper(
294
+ where="in DataLoader worker process {}".format(worker_id)
295
+ )
296
+
297
+ # When using Iterable mode, some worker can exit earlier than others due
298
+ # to the IterableDataset behaving differently for different workers.
299
+ # When such things happen, an `_IterableDatasetStopIteration` object is
300
+ # sent over to the main process with the ID of this worker, so that the
301
+ # main process won't send more tasks to this worker, and will send
302
+ # `None` to this worker to properly exit it.
303
+ #
304
+ # Note that we cannot set `done_event` from a worker as it is shared
305
+ # among all processes. Instead, we set the `iteration_end` flag to
306
+ # signify that the iterator is exhausted. When either `done_event` or
307
+ # `iteration_end` is set, we skip all processing step and just wait for
308
+ # `None`.
309
+ iteration_end = False
310
+
311
+ watchdog = ManagerWatchdog()
312
+
313
+ while watchdog.is_alive():
314
+ try:
315
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
316
+ except queue.Empty:
317
+ continue
318
+ if isinstance(r, _ResumeIteration):
319
+ # Acknowledge the main process
320
+ data_queue.put((r, None))
321
+ iteration_end = False
322
+
323
+ if isinstance(dataset, IterDataPipe):
324
+ assert r.seed is not None
325
+ shared_rng.manual_seed(r.seed)
326
+ dataset = apply_random_seed(dataset, shared_rng)
327
+
328
+ # Recreate the fetcher for worker-reuse policy
329
+ fetcher = _DatasetKind.create_fetcher(
330
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
331
+ )
332
+ continue
333
+ elif r is None:
334
+ # Received the final signal
335
+ assert done_event.is_set() or iteration_end
336
+ break
337
+ elif done_event.is_set() or iteration_end:
338
+ # `done_event` is set. But I haven't received the final signal
339
+ # (None) yet. I will keep continuing until get it, and skip the
340
+ # processing steps.
341
+ continue
342
+ idx, index = r
343
+ """ Added """
344
+ RRSController.sample_resolution(batch_id=idx)
345
+ """ Added """
346
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
347
+ if init_exception is not None:
348
+ data = init_exception
349
+ init_exception = None
350
+ else:
351
+ try:
352
+ data = fetcher.fetch(index)
353
+ except Exception as e:
354
+ if (
355
+ isinstance(e, StopIteration)
356
+ and dataset_kind == _DatasetKind.Iterable
357
+ ):
358
+ data = _IterableDatasetStopIteration(worker_id)
359
+ # Set `iteration_end`
360
+ # (1) to save future `next(...)` calls, and
361
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
362
+ iteration_end = True
363
+ else:
364
+ # It is important that we don't store exc_info in a variable.
365
+ # `ExceptionWrapper` does the correct thing.
366
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
367
+ data = ExceptionWrapper(
368
+ where="in DataLoader worker process {}".format(worker_id)
369
+ )
370
+ data_queue.put((idx, data))
371
+ del data, idx, index, r # save memory
372
+ except KeyboardInterrupt:
373
+ # Main process will raise KeyboardInterrupt anyways.
374
+ pass
375
+ if done_event.is_set():
376
+ data_queue.cancel_join_thread()
377
+ data_queue.close()
src/efficientvit/apps/data_provider/random_resolution/controller.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as F
10
+
11
+ from src.efficientvit.models.utils import torch_random_choices
12
+
13
+ __all__ = [
14
+ "RRSController",
15
+ "get_interpolate",
16
+ "MyRandomResizedCrop",
17
+ ]
18
+
19
+
20
+ class RRSController:
21
+ ACTIVE_SIZE = (224, 224)
22
+ IMAGE_SIZE_LIST = [(224, 224)]
23
+
24
+ CHOICE_LIST = None
25
+
26
+ @staticmethod
27
+ def get_candidates() -> list[tuple[int, int]]:
28
+ return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
29
+
30
+ @staticmethod
31
+ def sample_resolution(batch_id: int) -> None:
32
+ RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
33
+
34
+ @staticmethod
35
+ def set_epoch(epoch: int, batch_per_epoch: int) -> None:
36
+ g = torch.Generator()
37
+ g.manual_seed(epoch)
38
+ RRSController.CHOICE_LIST = torch_random_choices(
39
+ RRSController.get_candidates(),
40
+ g,
41
+ batch_per_epoch,
42
+ )
43
+
44
+
45
+ def get_interpolate(name: str) -> F.InterpolationMode:
46
+ mapping = {
47
+ "nearest": F.InterpolationMode.NEAREST,
48
+ "bilinear": F.InterpolationMode.BILINEAR,
49
+ "bicubic": F.InterpolationMode.BICUBIC,
50
+ "box": F.InterpolationMode.BOX,
51
+ "hamming": F.InterpolationMode.HAMMING,
52
+ "lanczos": F.InterpolationMode.LANCZOS,
53
+ }
54
+ if name in mapping:
55
+ return mapping[name]
56
+ elif name == "random":
57
+ return torch_random_choices(
58
+ [
59
+ F.InterpolationMode.NEAREST,
60
+ F.InterpolationMode.BILINEAR,
61
+ F.InterpolationMode.BICUBIC,
62
+ F.InterpolationMode.BOX,
63
+ F.InterpolationMode.HAMMING,
64
+ F.InterpolationMode.LANCZOS,
65
+ ],
66
+ )
67
+ else:
68
+ raise NotImplementedError
69
+
70
+
71
+ class MyRandomResizedCrop(transforms.RandomResizedCrop):
72
+ def __init__(
73
+ self,
74
+ scale=(0.08, 1.0),
75
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
76
+ interpolation: str = "random",
77
+ ):
78
+ super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
79
+ self.interpolation = interpolation
80
+
81
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
82
+ i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
83
+ target_size = RRSController.ACTIVE_SIZE
84
+ return F.resized_crop(
85
+ img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
86
+ )
87
+
88
+ def __repr__(self) -> str:
89
+ format_string = self.__class__.__name__
90
+ format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
91
+ format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
92
+ format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
93
+ format_string += f"\tinterpolation={self.interpolation})"
94
+ return format_string
src/efficientvit/apps/setup.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+ import time
7
+ from copy import deepcopy
8
+
9
+ import torch.backends.cudnn
10
+ import torch.distributed
11
+ import torch.nn as nn
12
+
13
+ from src.efficientvit.apps.data_provider import DataProvider
14
+ from src.efficientvit.apps.trainer.run_config import RunConfig
15
+ from src.efficientvit.apps.utils import (dist_init, dump_config,
16
+ get_dist_local_rank, get_dist_rank,
17
+ get_dist_size, init_modules, is_master,
18
+ load_config, partial_update_config,
19
+ zero_last_gamma)
20
+ from src.efficientvit.models.utils import (build_kwargs_from_config,
21
+ load_state_dict_from_file)
22
+
23
+ __all__ = [
24
+ "save_exp_config",
25
+ "setup_dist_env",
26
+ "setup_seed",
27
+ "setup_exp_config",
28
+ "setup_data_provider",
29
+ "setup_run_config",
30
+ "init_model",
31
+ ]
32
+
33
+
34
+ def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
35
+ if not is_master():
36
+ return
37
+ dump_config(exp_config, os.path.join(path, name))
38
+
39
+
40
+ def setup_dist_env(gpu: str or None = None) -> None:
41
+ if gpu is not None:
42
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
43
+ if not torch.distributed.is_initialized():
44
+ dist_init()
45
+ torch.backends.cudnn.benchmark = True
46
+ torch.cuda.set_device(get_dist_local_rank())
47
+
48
+
49
+ def setup_seed(manual_seed: int, resume: bool) -> None:
50
+ if resume:
51
+ manual_seed = int(time.time())
52
+ manual_seed = get_dist_rank() + manual_seed
53
+ torch.manual_seed(manual_seed)
54
+ torch.cuda.manual_seed_all(manual_seed)
55
+
56
+
57
+ def setup_exp_config(
58
+ config_path: str, recursive=True, opt_args: dict or None = None
59
+ ) -> dict:
60
+ # load config
61
+ if not os.path.isfile(config_path):
62
+ raise ValueError(config_path)
63
+
64
+ fpaths = [config_path]
65
+ if recursive:
66
+ extension = os.path.splitext(config_path)[1]
67
+ while os.path.dirname(config_path) != config_path:
68
+ config_path = os.path.dirname(config_path)
69
+ fpath = os.path.join(config_path, "default" + extension)
70
+ if os.path.isfile(fpath):
71
+ fpaths.append(fpath)
72
+ fpaths = fpaths[::-1]
73
+
74
+ default_config = load_config(fpaths[0])
75
+ exp_config = deepcopy(default_config)
76
+ for fpath in fpaths[1:]:
77
+ partial_update_config(exp_config, load_config(fpath))
78
+ # update config via args
79
+ if opt_args is not None:
80
+ partial_update_config(exp_config, opt_args)
81
+
82
+ return exp_config
83
+
84
+
85
+ def setup_data_provider(
86
+ exp_config: dict,
87
+ data_provider_classes: list[type[DataProvider]],
88
+ is_distributed: bool = True,
89
+ ) -> DataProvider:
90
+ dp_config = exp_config["data_provider"]
91
+ dp_config["num_replicas"] = get_dist_size() if is_distributed else None
92
+ dp_config["rank"] = get_dist_rank() if is_distributed else None
93
+ dp_config["test_batch_size"] = (
94
+ dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
95
+ )
96
+ dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
97
+ "base_batch_size"
98
+ ]
99
+
100
+ data_provider_lookup = {
101
+ provider.name: provider for provider in data_provider_classes
102
+ }
103
+ data_provider_class = data_provider_lookup[dp_config["dataset"]]
104
+
105
+ data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
106
+ data_provider = data_provider_class(**data_provider_kwargs)
107
+ return data_provider
108
+
109
+
110
+ def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
111
+ exp_config["run_config"]["init_lr"] = (
112
+ exp_config["run_config"]["base_lr"] * get_dist_size()
113
+ )
114
+
115
+ run_config = run_config_cls(**exp_config["run_config"])
116
+
117
+ return run_config
118
+
119
+
120
+ def init_model(
121
+ network: nn.Module,
122
+ init_from: str or None = None,
123
+ backbone_init_from: str or None = None,
124
+ rand_init="trunc_normal",
125
+ last_gamma=None,
126
+ ) -> None:
127
+ # initialization
128
+ init_modules(network, init_type=rand_init)
129
+ # zero gamma of last bn in each block
130
+ if last_gamma is not None:
131
+ zero_last_gamma(network, last_gamma)
132
+
133
+ # load weight
134
+ if init_from is not None and os.path.isfile(init_from):
135
+ network.load_state_dict(load_state_dict_from_file(init_from))
136
+ print(f"Loaded init from {init_from}")
137
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
138
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
139
+ print(f"Loaded backbone init from {backbone_init_from}")
140
+ else:
141
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
src/efficientvit/apps/trainer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .base import *
6
+ from .run_config import *
src/efficientvit/apps/trainer/base.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from src.efficientvit.apps.data_provider import DataProvider, parse_image_size
11
+ from src.efficientvit.apps.trainer.run_config import RunConfig
12
+ from src.efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
13
+ is_master)
14
+ from src.efficientvit.models.nn.norm import reset_bn
15
+ from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file
16
+
17
+ __all__ = ["Trainer"]
18
+
19
+
20
+ class Trainer:
21
+ def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
22
+ self.path = os.path.realpath(os.path.expanduser(path))
23
+ self.model = model.cuda()
24
+ self.data_provider = data_provider
25
+
26
+ self.ema = None
27
+
28
+ self.checkpoint_path = os.path.join(self.path, "checkpoint")
29
+ self.logs_path = os.path.join(self.path, "logs")
30
+ for path in [self.path, self.checkpoint_path, self.logs_path]:
31
+ os.makedirs(path, exist_ok=True)
32
+
33
+ self.best_val = 0.0
34
+ self.start_epoch = 0
35
+
36
+ @property
37
+ def network(self) -> nn.Module:
38
+ return self.model.module if is_parallel(self.model) else self.model
39
+
40
+ @property
41
+ def eval_network(self) -> nn.Module:
42
+ if self.ema is None:
43
+ model = self.model
44
+ else:
45
+ model = self.ema.shadows
46
+ model = model.module if is_parallel(model) else model
47
+ return model
48
+
49
+ def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
50
+ if is_master():
51
+ fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
52
+ fout.write(log_str + "\n")
53
+ fout.flush()
54
+ fout.close()
55
+ if print_log:
56
+ print(log_str)
57
+
58
+ def save_model(
59
+ self,
60
+ checkpoint=None,
61
+ only_state_dict=True,
62
+ epoch=0,
63
+ model_name=None,
64
+ ) -> None:
65
+ if is_master():
66
+ if checkpoint is None:
67
+ if only_state_dict:
68
+ checkpoint = {"state_dict": self.network.state_dict()}
69
+ else:
70
+ checkpoint = {
71
+ "state_dict": self.network.state_dict(),
72
+ "epoch": epoch,
73
+ "best_val": self.best_val,
74
+ "optimizer": self.optimizer.state_dict(),
75
+ "lr_scheduler": self.lr_scheduler.state_dict(),
76
+ "ema": self.ema.state_dict() if self.ema is not None else None,
77
+ "scaler": self.scaler.state_dict() if self.fp16 else None,
78
+ }
79
+
80
+ model_name = model_name or "checkpoint.pt"
81
+
82
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
83
+ model_path = os.path.join(self.checkpoint_path, model_name)
84
+ with open(latest_fname, "w") as _fout:
85
+ _fout.write(model_path + "\n")
86
+ torch.save(checkpoint, model_path)
87
+
88
+ def load_model(self, model_fname=None) -> None:
89
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
90
+ if model_fname is None and os.path.exists(latest_fname):
91
+ with open(latest_fname, "r") as fin:
92
+ model_fname = fin.readline()
93
+ if len(model_fname) > 0 and model_fname[-1] == "\n":
94
+ model_fname = model_fname[:-1]
95
+ try:
96
+ if model_fname is None:
97
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
98
+ elif not os.path.exists(model_fname):
99
+ model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
100
+ if not os.path.exists(model_fname):
101
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
102
+ print(f"=> loading checkpoint {model_fname}")
103
+ checkpoint = load_state_dict_from_file(model_fname, False)
104
+ except Exception:
105
+ self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
106
+ return
107
+
108
+ # load checkpoint
109
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
110
+ log = []
111
+ if "epoch" in checkpoint:
112
+ self.start_epoch = checkpoint["epoch"] + 1
113
+ self.run_config.update_global_step(self.start_epoch)
114
+ log.append(f"epoch={self.start_epoch - 1}")
115
+ if "best_val" in checkpoint:
116
+ self.best_val = checkpoint["best_val"]
117
+ log.append(f"best_val={self.best_val:.2f}")
118
+ if "optimizer" in checkpoint:
119
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
120
+ log.append("optimizer")
121
+ if "lr_scheduler" in checkpoint:
122
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
123
+ log.append("lr_scheduler")
124
+ if "ema" in checkpoint and self.ema is not None:
125
+ self.ema.load_state_dict(checkpoint["ema"])
126
+ log.append("ema")
127
+ if "scaler" in checkpoint and self.fp16:
128
+ self.scaler.load_state_dict(checkpoint["scaler"])
129
+ log.append("scaler")
130
+ self.write_log("Loaded: " + ", ".join(log))
131
+
132
+ """ validate """
133
+
134
+ def reset_bn(
135
+ self,
136
+ network: nn.Module or None = None,
137
+ subset_size: int = 16000,
138
+ subset_batch_size: int = 100,
139
+ data_loader=None,
140
+ progress_bar=False,
141
+ ) -> None:
142
+ network = network or self.network
143
+ if data_loader is None:
144
+ data_loader = []
145
+ for data in self.data_provider.build_sub_train_loader(
146
+ subset_size, subset_batch_size
147
+ ):
148
+ if isinstance(data, list):
149
+ data_loader.append(data[0])
150
+ elif isinstance(data, dict):
151
+ data_loader.append(data["data"])
152
+ elif isinstance(data, torch.Tensor):
153
+ data_loader.append(data)
154
+ else:
155
+ raise NotImplementedError
156
+
157
+ network.eval()
158
+ reset_bn(
159
+ network,
160
+ data_loader,
161
+ sync=True,
162
+ progress_bar=progress_bar,
163
+ )
164
+
165
+ def _validate(self, model, data_loader, epoch) -> dict[str, any]:
166
+ raise NotImplementedError
167
+
168
+ def validate(
169
+ self, model=None, data_loader=None, is_test=True, epoch=0
170
+ ) -> dict[str, any]:
171
+ model = model or self.eval_network
172
+ if data_loader is None:
173
+ if is_test:
174
+ data_loader = self.data_provider.test
175
+ else:
176
+ data_loader = self.data_provider.valid
177
+
178
+ model.eval()
179
+ return self._validate(model, data_loader, epoch)
180
+
181
+ def multires_validate(
182
+ self,
183
+ model=None,
184
+ data_loader=None,
185
+ is_test=True,
186
+ epoch=0,
187
+ eval_image_size=None,
188
+ ) -> dict[str, dict[str, any]]:
189
+ eval_image_size = eval_image_size or self.run_config.eval_image_size
190
+ eval_image_size = eval_image_size or self.data_provider.image_size
191
+ model = model or self.eval_network
192
+
193
+ if not isinstance(eval_image_size, list):
194
+ eval_image_size = [eval_image_size]
195
+
196
+ output_dict = {}
197
+ for r in eval_image_size:
198
+ self.data_provider.assign_active_image_size(parse_image_size(r))
199
+ if self.run_config.reset_bn:
200
+ self.reset_bn(
201
+ network=model,
202
+ subset_size=self.run_config.reset_bn_size,
203
+ subset_batch_size=self.run_config.reset_bn_batch_size,
204
+ progress_bar=True,
205
+ )
206
+ output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
207
+ return output_dict
208
+
209
+ """ training """
210
+
211
+ def prep_for_training(
212
+ self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
213
+ ) -> None:
214
+ self.run_config = run_config
215
+ self.model = nn.parallel.DistributedDataParallel(
216
+ self.model.cuda(),
217
+ device_ids=[get_dist_local_rank()],
218
+ static_graph=True,
219
+ )
220
+
221
+ self.run_config.global_step = 0
222
+ self.run_config.batch_per_epoch = len(self.data_provider.train)
223
+ assert self.run_config.batch_per_epoch > 0, "Training set is empty"
224
+
225
+ # build optimizer
226
+ self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
227
+
228
+ if ema_decay is not None:
229
+ self.ema = EMA(self.network, ema_decay)
230
+
231
+ # fp16
232
+ self.fp16 = fp16
233
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
234
+
235
+ def sync_model(self):
236
+ print("Sync model")
237
+ self.save_model(model_name="sync.pt")
238
+ dist_barrier()
239
+ checkpoint = torch.load(
240
+ os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
241
+ )
242
+ dist_barrier()
243
+ if is_master():
244
+ os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
245
+ dist_barrier()
246
+
247
+ # load checkpoint
248
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
249
+ if "optimizer" in checkpoint:
250
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
251
+ if "lr_scheduler" in checkpoint:
252
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
253
+ if "ema" in checkpoint and self.ema is not None:
254
+ self.ema.load_state_dict(checkpoint["ema"])
255
+ if "scaler" in checkpoint and self.fp16:
256
+ self.scaler.load_state_dict(checkpoint["scaler"])
257
+
258
+ def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
259
+ for key in feed_dict:
260
+ if isinstance(feed_dict[key], torch.Tensor):
261
+ feed_dict[key] = feed_dict[key].cuda()
262
+ return feed_dict
263
+
264
+ def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
265
+ raise NotImplementedError
266
+
267
+ def after_step(self) -> None:
268
+ self.scaler.unscale_(self.optimizer)
269
+ # gradient clip
270
+ if self.run_config.grad_clip is not None:
271
+ torch.nn.utils.clip_grad_value_(
272
+ self.model.parameters(), self.run_config.grad_clip
273
+ )
274
+ # update
275
+ self.scaler.step(self.optimizer)
276
+ self.scaler.update()
277
+
278
+ self.lr_scheduler.step()
279
+ self.run_config.step()
280
+ # update ema
281
+ if self.ema is not None:
282
+ self.ema.step(self.network, self.run_config.global_step)
283
+
284
+ def _train_one_epoch(self, epoch: int) -> dict[str, any]:
285
+ raise NotImplementedError
286
+
287
+ def train_one_epoch(self, epoch: int) -> dict[str, any]:
288
+ self.model.train()
289
+
290
+ self.data_provider.set_epoch(epoch)
291
+
292
+ train_info_dict = self._train_one_epoch(epoch)
293
+
294
+ return train_info_dict
295
+
296
+ def train(self) -> None:
297
+ raise NotImplementedError
src/efficientvit/apps/trainer/run_config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import json
6
+
7
+ import numpy as np
8
+ import torch.nn as nn
9
+
10
+ from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
11
+
12
+ __all__ = ["Scheduler", "RunConfig"]
13
+
14
+
15
+ class Scheduler:
16
+ PROGRESS = 0
17
+
18
+
19
+ class RunConfig:
20
+ n_epochs: int
21
+ init_lr: float
22
+ warmup_epochs: int
23
+ warmup_lr: float
24
+ lr_schedule_name: str
25
+ lr_schedule_param: dict
26
+ optimizer_name: str
27
+ optimizer_params: dict
28
+ weight_decay: float
29
+ no_wd_keys: list
30
+ grad_clip: float # allow none to turn off grad clipping
31
+ reset_bn: bool
32
+ reset_bn_size: int
33
+ reset_bn_batch_size: int
34
+ eval_image_size: list # allow none to use image_size in data_provider
35
+
36
+ @property
37
+ def none_allowed(self):
38
+ return ["grad_clip", "eval_image_size"]
39
+
40
+ def __init__(self, **kwargs): # arguments must be passed as kwargs
41
+ for k, val in kwargs.items():
42
+ setattr(self, k, val)
43
+
44
+ # check that all relevant configs are there
45
+ annotations = {}
46
+ for clas in type(self).mro():
47
+ if hasattr(clas, "__annotations__"):
48
+ annotations.update(clas.__annotations__)
49
+ for k, k_type in annotations.items():
50
+ assert hasattr(
51
+ self, k
52
+ ), f"Key {k} with type {k_type} required for initialization."
53
+ attr = getattr(self, k)
54
+ if k in self.none_allowed:
55
+ k_type = (k_type, type(None))
56
+ assert isinstance(
57
+ attr, k_type
58
+ ), f"Key {k} must be type {k_type}, provided={attr}."
59
+
60
+ self.global_step = 0
61
+ self.batch_per_epoch = 1
62
+
63
+ def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
64
+ r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
65
+ param_dict = {}
66
+ for name, param in network.named_parameters():
67
+ if param.requires_grad:
68
+ opt_config = [self.weight_decay, self.init_lr]
69
+ if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
70
+ if np.any([key in name for key in self.no_wd_keys]):
71
+ opt_config[0] = 0
72
+ opt_key = json.dumps(opt_config)
73
+ param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
74
+
75
+ net_params = []
76
+ for opt_key, param_list in param_dict.items():
77
+ wd, lr = json.loads(opt_key)
78
+ net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
79
+
80
+ optimizer = build_optimizer(
81
+ net_params, self.optimizer_name, self.optimizer_params, self.init_lr
82
+ )
83
+ # build lr scheduler
84
+ if self.lr_schedule_name == "cosine":
85
+ decay_steps = []
86
+ for epoch in self.lr_schedule_param.get("step", []):
87
+ decay_steps.append(epoch * self.batch_per_epoch)
88
+ decay_steps.append(self.n_epochs * self.batch_per_epoch)
89
+ decay_steps.sort()
90
+ lr_scheduler = CosineLRwithWarmup(
91
+ optimizer,
92
+ self.warmup_epochs * self.batch_per_epoch,
93
+ self.warmup_lr,
94
+ decay_steps,
95
+ )
96
+ else:
97
+ raise NotImplementedError
98
+ return optimizer, lr_scheduler
99
+
100
+ def update_global_step(self, epoch, batch_id=0) -> None:
101
+ self.global_step = epoch * self.batch_per_epoch + batch_id
102
+ Scheduler.PROGRESS = self.progress
103
+
104
+ @property
105
+ def progress(self) -> float:
106
+ warmup_steps = self.warmup_epochs * self.batch_per_epoch
107
+ steps = max(0, self.global_step - warmup_steps)
108
+ return steps / (self.n_epochs * self.batch_per_epoch)
109
+
110
+ def step(self) -> None:
111
+ self.global_step += 1
112
+ Scheduler.PROGRESS = self.progress
113
+
114
+ def get_remaining_epoch(self, epoch, post=True) -> int:
115
+ return self.n_epochs + self.warmup_epochs - epoch - int(post)
116
+
117
+ def epoch_format(self, epoch: int) -> str:
118
+ epoch_format = f"%.{len(str(self.n_epochs))}d"
119
+ epoch_format = f"[{epoch_format}/{epoch_format}]"
120
+ epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
121
+ return epoch_format
src/efficientvit/apps/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .dist import *
6
+ from .ema import *
7
+ from .export import *
8
+ from .init import *
9
+ from .lr import *
10
+ from .metric import *
11
+ from .misc import *
12
+ from .opt import *
src/efficientvit/apps/utils/dist.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.distributed
9
+
10
+ from src.efficientvit.models.utils.list import list_mean, list_sum
11
+
12
+ __all__ = [
13
+ "dist_init",
14
+ "get_dist_rank",
15
+ "get_dist_size",
16
+ "is_master",
17
+ "dist_barrier",
18
+ "get_dist_local_rank",
19
+ "sync_tensor",
20
+ ]
21
+
22
+
23
+ def dist_init() -> None:
24
+ try:
25
+ torch.distributed.init_process_group(backend="nccl")
26
+ assert torch.distributed.is_initialized()
27
+ except Exception:
28
+ # use torchpack
29
+ from torchpack import distributed as dist
30
+
31
+ dist.init()
32
+ os.environ["RANK"] = f"{dist.rank()}"
33
+ os.environ["WORLD_SIZE"] = f"{dist.size()}"
34
+ os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
35
+
36
+
37
+ def get_dist_rank() -> int:
38
+ return int(os.environ["RANK"])
39
+
40
+
41
+ def get_dist_size() -> int:
42
+ return int(os.environ["WORLD_SIZE"])
43
+
44
+
45
+ def is_master() -> bool:
46
+ return get_dist_rank() == 0
47
+
48
+
49
+ def dist_barrier() -> None:
50
+ torch.distributed.barrier()
51
+
52
+
53
+ def get_dist_local_rank() -> int:
54
+ return int(os.environ["LOCAL_RANK"])
55
+
56
+
57
+ def sync_tensor(
58
+ tensor: torch.Tensor or float, reduce="mean"
59
+ ) -> torch.Tensor or list[torch.Tensor]:
60
+ if not isinstance(tensor, torch.Tensor):
61
+ tensor = torch.Tensor(1).fill_(tensor).cuda()
62
+ tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
63
+ torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
64
+ if reduce == "mean":
65
+ return list_mean(tensor_list)
66
+ elif reduce == "sum":
67
+ return list_sum(tensor_list)
68
+ elif reduce == "cat":
69
+ return torch.cat(tensor_list, dim=0)
70
+ elif reduce == "root":
71
+ return tensor_list[0]
72
+ else:
73
+ return tensor_list
src/efficientvit/apps/utils/ema.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from src.efficientvit.models.utils import is_parallel
12
+
13
+ __all__ = ["EMA"]
14
+
15
+
16
+ def update_ema(
17
+ ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
18
+ ) -> None:
19
+ for k, v in ema.state_dict().items():
20
+ if v.dtype.is_floating_point:
21
+ v -= (1.0 - decay) * (v - new_state_dict[k].detach())
22
+
23
+
24
+ class EMA:
25
+ def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
26
+ self.shadows = copy.deepcopy(
27
+ model.module if is_parallel(model) else model
28
+ ).eval()
29
+ self.decay = decay
30
+ self.warmup_steps = warmup_steps
31
+
32
+ for p in self.shadows.parameters():
33
+ p.requires_grad = False
34
+
35
+ def step(self, model: nn.Module, global_step: int) -> None:
36
+ with torch.no_grad():
37
+ msd = (model.module if is_parallel(model) else model).state_dict()
38
+ update_ema(
39
+ self.shadows,
40
+ msd,
41
+ self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
42
+ )
43
+
44
+ def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
45
+ return {self.decay: self.shadows.state_dict()}
46
+
47
+ def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
48
+ for decay in state_dict:
49
+ if decay == self.decay:
50
+ self.shadows.load_state_dict(state_dict[decay])
src/efficientvit/apps/utils/export.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import io
6
+ import os
7
+
8
+ import onnx
9
+ import torch
10
+ import torch.nn as nn
11
+ from onnxsim import simplify as simplify_func
12
+
13
+ __all__ = ["export_onnx"]
14
+
15
+
16
+ def export_onnx(
17
+ model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
18
+ ) -> None:
19
+ """Export a model to a platform-specific onnx format.
20
+
21
+ Args:
22
+ model: a torch.nn.Module object.
23
+ export_path: export location.
24
+ sample_inputs: Any.
25
+ simplify: a flag to turn on onnx-simplifier
26
+ opset: int
27
+ """
28
+ model.eval()
29
+
30
+ buffer = io.BytesIO()
31
+ with torch.no_grad():
32
+ torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
33
+ buffer.seek(0, 0)
34
+ if simplify:
35
+ onnx_model = onnx.load_model(buffer)
36
+ onnx_model, success = simplify_func(onnx_model)
37
+ assert success
38
+ new_buffer = io.BytesIO()
39
+ onnx.save(onnx_model, new_buffer)
40
+ buffer = new_buffer
41
+ buffer.seek(0, 0)
42
+
43
+ if buffer.getbuffer().nbytes > 0:
44
+ save_dir = os.path.dirname(export_path)
45
+ os.makedirs(save_dir, exist_ok=True)
46
+ with open(export_path, "wb") as f:
47
+ f.write(buffer.read())
src/efficientvit/apps/utils/init.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.modules.batchnorm import _BatchNorm
8
+
9
+ __all__ = ["init_modules", "zero_last_gamma"]
10
+
11
+
12
+ def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
13
+ _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
14
+
15
+ if isinstance(model, list):
16
+ for sub_module in model:
17
+ init_modules(sub_module, init_type)
18
+ else:
19
+ init_params = init_type.split("@")
20
+ init_params = float(init_params[1]) if len(init_params) > 1 else None
21
+
22
+ if init_type.startswith("trunc_normal"):
23
+ init_func = lambda param: nn.init.trunc_normal_(
24
+ param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
25
+ )
26
+ else:
27
+ raise NotImplementedError
28
+
29
+ for m in model.modules():
30
+ if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
31
+ init_func(m.weight)
32
+ if m.bias is not None:
33
+ m.bias.data.zero_()
34
+ elif isinstance(m, nn.Embedding):
35
+ init_func(m.weight)
36
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
37
+ m.weight.data.fill_(1)
38
+ m.bias.data.zero_()
39
+ else:
40
+ weight = getattr(m, "weight", None)
41
+ bias = getattr(m, "bias", None)
42
+ if isinstance(weight, torch.nn.Parameter):
43
+ init_func(weight)
44
+ if isinstance(bias, torch.nn.Parameter):
45
+ bias.data.zero_()
46
+
47
+
48
+ def zero_last_gamma(model: nn.Module, init_val=0) -> None:
49
+ import efficientvit.models.nn.ops as ops
50
+
51
+ for m in model.modules():
52
+ if isinstance(m, ops.ResidualBlock) and isinstance(
53
+ m.shortcut, ops.IdentityLayer
54
+ ):
55
+ if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
56
+ parent_module = m.main.point_conv
57
+ elif isinstance(m.main, ops.ResBlock):
58
+ parent_module = m.main.conv2
59
+ elif isinstance(m.main, ops.ConvLayer):
60
+ parent_module = m.main
61
+ elif isinstance(m.main, (ops.LiteMLA)):
62
+ parent_module = m.main.proj
63
+ else:
64
+ parent_module = None
65
+ if parent_module is not None:
66
+ norm = getattr(parent_module, "norm", None)
67
+ if norm is not None:
68
+ nn.init.constant_(norm.weight, init_val)
src/efficientvit/apps/utils/lr.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+ from src.efficientvit.models.utils.list import val2list
10
+
11
+ __all__ = ["CosineLRwithWarmup"]
12
+
13
+
14
+ class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
15
+ def __init__(
16
+ self,
17
+ optimizer: torch.optim.Optimizer,
18
+ warmup_steps: int,
19
+ warmup_lr: float,
20
+ decay_steps: int or list[int],
21
+ last_epoch: int = -1,
22
+ ) -> None:
23
+ self.warmup_steps = warmup_steps
24
+ self.warmup_lr = warmup_lr
25
+ self.decay_steps = val2list(decay_steps)
26
+ super().__init__(optimizer, last_epoch)
27
+
28
+ def get_lr(self) -> list[float]:
29
+ if self.last_epoch < self.warmup_steps:
30
+ return [
31
+ (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
32
+ + self.warmup_lr
33
+ for base_lr in self.base_lrs
34
+ ]
35
+ else:
36
+ current_steps = self.last_epoch - self.warmup_steps
37
+ decay_steps = [0] + self.decay_steps
38
+ idx = len(decay_steps) - 2
39
+ for i, decay_step in enumerate(decay_steps[:-1]):
40
+ if decay_step <= current_steps < decay_steps[i + 1]:
41
+ idx = i
42
+ break
43
+ current_steps -= decay_steps[idx]
44
+ decay_step = decay_steps[idx + 1] - decay_steps[idx]
45
+ return [
46
+ 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
47
+ for base_lr in self.base_lrs
48
+ ]
src/efficientvit/apps/utils/metric.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+
7
+ from src.efficientvit.apps.utils.dist import sync_tensor
8
+
9
+ __all__ = ["AverageMeter"]
10
+
11
+
12
+ class AverageMeter:
13
+ """Computes and stores the average and current value."""
14
+
15
+ def __init__(self, is_distributed=True):
16
+ self.is_distributed = is_distributed
17
+ self.sum = 0
18
+ self.count = 0
19
+
20
+ def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
21
+ return sync_tensor(val, reduce="sum") if self.is_distributed else val
22
+
23
+ def update(self, val: torch.Tensor or int or float, delta_n=1):
24
+ self.count += self._sync(delta_n)
25
+ self.sum += self._sync(val * delta_n)
26
+
27
+ def get_count(self) -> torch.Tensor or int or float:
28
+ return (
29
+ self.count.item()
30
+ if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
31
+ else self.count
32
+ )
33
+
34
+ @property
35
+ def avg(self):
36
+ avg = -1 if self.count == 0 else self.sum / self.count
37
+ return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
src/efficientvit/apps/utils/misc.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import yaml
8
+
9
+ __all__ = [
10
+ "parse_with_yaml",
11
+ "parse_unknown_args",
12
+ "partial_update_config",
13
+ "resolve_and_load_config",
14
+ "load_config",
15
+ "dump_config",
16
+ ]
17
+
18
+
19
+ def parse_with_yaml(config_str: str) -> str or dict:
20
+ try:
21
+ # add space manually for dict
22
+ if "{" in config_str and "}" in config_str and ":" in config_str:
23
+ out_str = config_str.replace(":", ": ")
24
+ else:
25
+ out_str = config_str
26
+ return yaml.safe_load(out_str)
27
+ except ValueError:
28
+ # return raw string if parsing fails
29
+ return config_str
30
+
31
+
32
+ def parse_unknown_args(unknown: list) -> dict:
33
+ """Parse unknown args."""
34
+ index = 0
35
+ parsed_dict = {}
36
+ while index < len(unknown):
37
+ key, val = unknown[index], unknown[index + 1]
38
+ index += 2
39
+ if not key.startswith("--"):
40
+ continue
41
+ key = key[2:]
42
+
43
+ # try parsing with either dot notation or full yaml notation
44
+ # Note that the vanilla case "--key value" will be parsed the same
45
+ if "." in key:
46
+ # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
47
+ keys = key.split(".")
48
+ dict_to_update = parsed_dict
49
+ for key in keys[:-1]:
50
+ if not (
51
+ key in dict_to_update and isinstance(dict_to_update[key], dict)
52
+ ):
53
+ dict_to_update[key] = {}
54
+ dict_to_update = dict_to_update[key]
55
+ dict_to_update[keys[-1]] = parse_with_yaml(
56
+ val
57
+ ) # so we can parse lists, bools, etc...
58
+ else:
59
+ parsed_dict[key] = parse_with_yaml(val)
60
+ return parsed_dict
61
+
62
+
63
+ def partial_update_config(config: dict, partial_config: dict) -> dict:
64
+ for key in partial_config:
65
+ if (
66
+ key in config
67
+ and isinstance(partial_config[key], dict)
68
+ and isinstance(config[key], dict)
69
+ ):
70
+ partial_update_config(config[key], partial_config[key])
71
+ else:
72
+ config[key] = partial_config[key]
73
+ return config
74
+
75
+
76
+ def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
77
+ path = os.path.realpath(os.path.expanduser(path))
78
+ if os.path.isdir(path):
79
+ config_path = os.path.join(path, config_name)
80
+ else:
81
+ config_path = path
82
+ if os.path.isfile(config_path):
83
+ pass
84
+ else:
85
+ raise Exception(f"Cannot find a valid config at {path}")
86
+ config = load_config(config_path)
87
+ return config
88
+
89
+
90
+ class SafeLoaderWithTuple(yaml.SafeLoader):
91
+ """A yaml safe loader with python tuple loading capabilities."""
92
+
93
+ def construct_python_tuple(self, node):
94
+ return tuple(self.construct_sequence(node))
95
+
96
+
97
+ SafeLoaderWithTuple.add_constructor(
98
+ "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
99
+ )
100
+
101
+
102
+ def load_config(filename: str) -> dict:
103
+ """Load a yaml file."""
104
+ filename = os.path.realpath(os.path.expanduser(filename))
105
+ return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
106
+
107
+
108
+ def dump_config(config: dict, filename: str) -> None:
109
+ """Dump a config file"""
110
+ filename = os.path.realpath(os.path.expanduser(filename))
111
+ yaml.dump(config, open(filename, "w"), sort_keys=False)
src/efficientvit/apps/utils/opt.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+
7
+ __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
8
+
9
+ # register optimizer here
10
+ # name: optimizer, kwargs with default values
11
+ REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
12
+ "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
13
+ "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
14
+ "adamw": (
15
+ torch.optim.AdamW,
16
+ {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
17
+ ),
18
+ }
19
+
20
+
21
+ def build_optimizer(
22
+ net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
23
+ ) -> torch.optim.Optimizer:
24
+ optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
25
+ optimizer_params = optimizer_params or {}
26
+
27
+ for key in default_params:
28
+ if key in optimizer_params:
29
+ default_params[key] = optimizer_params[key]
30
+ optimizer = optimizer_class(net_params, init_lr, **default_params)
31
+ return optimizer
src/efficientvit/models/__init__.py ADDED
File without changes
src/efficientvit/models/efficientvit/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .backbone import *
6
+ from .cls import *
7
+ from .sam import *
8
+ from .seg import *
src/efficientvit/models/efficientvit/backbone.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock,
9
+ FusedMBConv, IdentityLayer, MBConv,
10
+ OpSequential, ResBlock, ResidualBlock)
11
+ from src.efficientvit.models.utils import build_kwargs_from_config
12
+
13
+ __all__ = [
14
+ "EfficientViTBackbone",
15
+ "efficientvit_backbone_b0",
16
+ "efficientvit_backbone_b1",
17
+ "efficientvit_backbone_b2",
18
+ "efficientvit_backbone_b3",
19
+ "EfficientViTLargeBackbone",
20
+ "efficientvit_backbone_l0",
21
+ "efficientvit_backbone_l1",
22
+ "efficientvit_backbone_l2",
23
+ "efficientvit_backbone_l3",
24
+ ]
25
+
26
+
27
+ class EfficientViTBackbone(nn.Module):
28
+ def __init__(
29
+ self,
30
+ width_list: list[int],
31
+ depth_list: list[int],
32
+ in_channels=3,
33
+ dim=32,
34
+ expand_ratio=4,
35
+ norm="bn2d",
36
+ act_func="hswish",
37
+ ) -> None:
38
+ super().__init__()
39
+
40
+ self.width_list = []
41
+ # input stem
42
+ self.input_stem = [
43
+ ConvLayer(
44
+ in_channels=3,
45
+ out_channels=width_list[0],
46
+ stride=2,
47
+ norm=norm,
48
+ act_func=act_func,
49
+ )
50
+ ]
51
+ for _ in range(depth_list[0]):
52
+ block = self.build_local_block(
53
+ in_channels=width_list[0],
54
+ out_channels=width_list[0],
55
+ stride=1,
56
+ expand_ratio=1,
57
+ norm=norm,
58
+ act_func=act_func,
59
+ )
60
+ self.input_stem.append(ResidualBlock(block, IdentityLayer()))
61
+ in_channels = width_list[0]
62
+ self.input_stem = OpSequential(self.input_stem)
63
+ self.width_list.append(in_channels)
64
+
65
+ # stages
66
+ self.stages = []
67
+ for w, d in zip(width_list[1:3], depth_list[1:3]):
68
+ stage = []
69
+ for i in range(d):
70
+ stride = 2 if i == 0 else 1
71
+ block = self.build_local_block(
72
+ in_channels=in_channels,
73
+ out_channels=w,
74
+ stride=stride,
75
+ expand_ratio=expand_ratio,
76
+ norm=norm,
77
+ act_func=act_func,
78
+ )
79
+ block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
80
+ stage.append(block)
81
+ in_channels = w
82
+ self.stages.append(OpSequential(stage))
83
+ self.width_list.append(in_channels)
84
+
85
+ for w, d in zip(width_list[3:], depth_list[3:]):
86
+ stage = []
87
+ block = self.build_local_block(
88
+ in_channels=in_channels,
89
+ out_channels=w,
90
+ stride=2,
91
+ expand_ratio=expand_ratio,
92
+ norm=norm,
93
+ act_func=act_func,
94
+ fewer_norm=True,
95
+ )
96
+ stage.append(ResidualBlock(block, None))
97
+ in_channels = w
98
+
99
+ for _ in range(d):
100
+ stage.append(
101
+ EfficientViTBlock(
102
+ in_channels=in_channels,
103
+ dim=dim,
104
+ expand_ratio=expand_ratio,
105
+ norm=norm,
106
+ act_func=act_func,
107
+ )
108
+ )
109
+ self.stages.append(OpSequential(stage))
110
+ self.width_list.append(in_channels)
111
+ self.stages = nn.ModuleList(self.stages)
112
+
113
+ @staticmethod
114
+ def build_local_block(
115
+ in_channels: int,
116
+ out_channels: int,
117
+ stride: int,
118
+ expand_ratio: float,
119
+ norm: str,
120
+ act_func: str,
121
+ fewer_norm: bool = False,
122
+ ) -> nn.Module:
123
+ if expand_ratio == 1:
124
+ block = DSConv(
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ stride=stride,
128
+ use_bias=(True, False) if fewer_norm else False,
129
+ norm=(None, norm) if fewer_norm else norm,
130
+ act_func=(act_func, None),
131
+ )
132
+ else:
133
+ block = MBConv(
134
+ in_channels=in_channels,
135
+ out_channels=out_channels,
136
+ stride=stride,
137
+ expand_ratio=expand_ratio,
138
+ use_bias=(True, True, False) if fewer_norm else False,
139
+ norm=(None, None, norm) if fewer_norm else norm,
140
+ act_func=(act_func, act_func, None),
141
+ )
142
+ return block
143
+
144
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
145
+ output_dict = {"input": x}
146
+ output_dict["stage0"] = x = self.input_stem(x)
147
+ for stage_id, stage in enumerate(self.stages, 1):
148
+ output_dict["stage%d" % stage_id] = x = stage(x)
149
+ output_dict["stage_final"] = x
150
+ return output_dict
151
+
152
+
153
+ def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
154
+ backbone = EfficientViTBackbone(
155
+ width_list=[8, 16, 32, 64, 128],
156
+ depth_list=[1, 2, 2, 2, 2],
157
+ dim=16,
158
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
159
+ )
160
+ return backbone
161
+
162
+
163
+ def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
164
+ backbone = EfficientViTBackbone(
165
+ width_list=[16, 32, 64, 128, 256],
166
+ depth_list=[1, 2, 3, 3, 4],
167
+ dim=16,
168
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
169
+ )
170
+ return backbone
171
+
172
+
173
+ def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
174
+ backbone = EfficientViTBackbone(
175
+ width_list=[24, 48, 96, 192, 384],
176
+ depth_list=[1, 3, 4, 4, 6],
177
+ dim=32,
178
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
179
+ )
180
+ return backbone
181
+
182
+
183
+ def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
184
+ backbone = EfficientViTBackbone(
185
+ width_list=[32, 64, 128, 256, 512],
186
+ depth_list=[1, 4, 6, 6, 9],
187
+ dim=32,
188
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
189
+ )
190
+ return backbone
191
+
192
+
193
+ class EfficientViTLargeBackbone(nn.Module):
194
+ def __init__(
195
+ self,
196
+ width_list: list[int],
197
+ depth_list: list[int],
198
+ block_list: list[str] or None = None,
199
+ expand_list: list[float] or None = None,
200
+ fewer_norm_list: list[bool] or None = None,
201
+ in_channels=3,
202
+ qkv_dim=32,
203
+ norm="bn2d",
204
+ act_func="gelu",
205
+ ) -> None:
206
+ super().__init__()
207
+ block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
208
+ expand_list = expand_list or [1, 4, 4, 4, 6]
209
+ fewer_norm_list = fewer_norm_list or [False, False, False, True, True]
210
+
211
+ self.width_list = []
212
+ self.stages = []
213
+ # stage 0
214
+ stage0 = [
215
+ ConvLayer(
216
+ in_channels=3,
217
+ out_channels=width_list[0],
218
+ stride=2,
219
+ norm=norm,
220
+ act_func=act_func,
221
+ )
222
+ ]
223
+ for _ in range(depth_list[0]):
224
+ block = self.build_local_block(
225
+ block=block_list[0],
226
+ in_channels=width_list[0],
227
+ out_channels=width_list[0],
228
+ stride=1,
229
+ expand_ratio=expand_list[0],
230
+ norm=norm,
231
+ act_func=act_func,
232
+ fewer_norm=fewer_norm_list[0],
233
+ )
234
+ stage0.append(ResidualBlock(block, IdentityLayer()))
235
+ in_channels = width_list[0]
236
+ self.stages.append(OpSequential(stage0))
237
+ self.width_list.append(in_channels)
238
+
239
+ for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
240
+ stage = []
241
+ block = self.build_local_block(
242
+ block=(
243
+ "mb"
244
+ if block_list[stage_id] not in ["mb", "fmb"]
245
+ else block_list[stage_id]
246
+ ),
247
+ in_channels=in_channels,
248
+ out_channels=w,
249
+ stride=2,
250
+ expand_ratio=expand_list[stage_id] * 4,
251
+ norm=norm,
252
+ act_func=act_func,
253
+ fewer_norm=fewer_norm_list[stage_id],
254
+ )
255
+ stage.append(ResidualBlock(block, None))
256
+ in_channels = w
257
+
258
+ for _ in range(d):
259
+ if block_list[stage_id].startswith("att"):
260
+ stage.append(
261
+ EfficientViTBlock(
262
+ in_channels=in_channels,
263
+ dim=qkv_dim,
264
+ expand_ratio=expand_list[stage_id],
265
+ scales=(3,) if block_list[stage_id] == "att@3" else (5,),
266
+ norm=norm,
267
+ act_func=act_func,
268
+ )
269
+ )
270
+ else:
271
+ block = self.build_local_block(
272
+ block=block_list[stage_id],
273
+ in_channels=in_channels,
274
+ out_channels=in_channels,
275
+ stride=1,
276
+ expand_ratio=expand_list[stage_id],
277
+ norm=norm,
278
+ act_func=act_func,
279
+ fewer_norm=fewer_norm_list[stage_id],
280
+ )
281
+ block = ResidualBlock(block, IdentityLayer())
282
+ stage.append(block)
283
+ self.stages.append(OpSequential(stage))
284
+ self.width_list.append(in_channels)
285
+ self.stages = nn.ModuleList(self.stages)
286
+
287
+ @staticmethod
288
+ def build_local_block(
289
+ block: str,
290
+ in_channels: int,
291
+ out_channels: int,
292
+ stride: int,
293
+ expand_ratio: float,
294
+ norm: str,
295
+ act_func: str,
296
+ fewer_norm: bool = False,
297
+ ) -> nn.Module:
298
+ if block == "res":
299
+ block = ResBlock(
300
+ in_channels=in_channels,
301
+ out_channels=out_channels,
302
+ stride=stride,
303
+ use_bias=(True, False) if fewer_norm else False,
304
+ norm=(None, norm) if fewer_norm else norm,
305
+ act_func=(act_func, None),
306
+ )
307
+ elif block == "fmb":
308
+ block = FusedMBConv(
309
+ in_channels=in_channels,
310
+ out_channels=out_channels,
311
+ stride=stride,
312
+ expand_ratio=expand_ratio,
313
+ use_bias=(True, False) if fewer_norm else False,
314
+ norm=(None, norm) if fewer_norm else norm,
315
+ act_func=(act_func, None),
316
+ )
317
+ elif block == "mb":
318
+ block = MBConv(
319
+ in_channels=in_channels,
320
+ out_channels=out_channels,
321
+ stride=stride,
322
+ expand_ratio=expand_ratio,
323
+ use_bias=(True, True, False) if fewer_norm else False,
324
+ norm=(None, None, norm) if fewer_norm else norm,
325
+ act_func=(act_func, act_func, None),
326
+ )
327
+ else:
328
+ raise ValueError(block)
329
+ return block
330
+
331
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
332
+ output_dict = {"input": x}
333
+ for stage_id, stage in enumerate(self.stages):
334
+ output_dict["stage%d" % stage_id] = x = stage(x)
335
+ output_dict["stage_final"] = x
336
+ return output_dict
337
+
338
+
339
+ def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
340
+ backbone = EfficientViTLargeBackbone(
341
+ width_list=[32, 64, 128, 256, 512],
342
+ depth_list=[1, 1, 1, 4, 4],
343
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
344
+ )
345
+ return backbone
346
+
347
+
348
+ def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
349
+ backbone = EfficientViTLargeBackbone(
350
+ width_list=[32, 64, 128, 256, 512],
351
+ depth_list=[1, 1, 1, 6, 6],
352
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
353
+ )
354
+ return backbone
355
+
356
+
357
+ def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
358
+ backbone = EfficientViTLargeBackbone(
359
+ width_list=[32, 64, 128, 256, 512],
360
+ depth_list=[1, 2, 2, 8, 8],
361
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
362
+ )
363
+ return backbone
364
+
365
+
366
+ def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
367
+ backbone = EfficientViTLargeBackbone(
368
+ width_list=[64, 128, 256, 512, 1024],
369
+ depth_list=[1, 2, 2, 8, 8],
370
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
371
+ )
372
+ return backbone
src/efficientvit/models/efficientvit/cls.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from src.efficientvit.models.efficientvit.backbone import (
9
+ EfficientViTBackbone, EfficientViTLargeBackbone)
10
+ from src.efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential
11
+ from src.efficientvit.models.utils import build_kwargs_from_config
12
+
13
+ __all__ = [
14
+ "EfficientViTCls",
15
+ ######################
16
+ "efficientvit_cls_b0",
17
+ "efficientvit_cls_b1",
18
+ "efficientvit_cls_b2",
19
+ "efficientvit_cls_b3",
20
+ ######################
21
+ "efficientvit_cls_l1",
22
+ "efficientvit_cls_l2",
23
+ "efficientvit_cls_l3",
24
+ ]
25
+
26
+
27
+ class ClsHead(OpSequential):
28
+ def __init__(
29
+ self,
30
+ in_channels: int,
31
+ width_list: list[int],
32
+ n_classes=1000,
33
+ dropout=0.0,
34
+ norm="bn2d",
35
+ act_func="hswish",
36
+ fid="stage_final",
37
+ ):
38
+ ops = [
39
+ ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func),
40
+ nn.AdaptiveAvgPool2d(output_size=1),
41
+ LinearLayer(
42
+ width_list[0], width_list[1], False, norm="ln", act_func=act_func
43
+ ),
44
+ LinearLayer(width_list[1], n_classes, True, dropout, None, None),
45
+ ]
46
+ super().__init__(ops)
47
+
48
+ self.fid = fid
49
+
50
+ def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor:
51
+ x = feed_dict[self.fid]
52
+ return OpSequential.forward(self, x)
53
+
54
+
55
+ class EfficientViTCls(nn.Module):
56
+ def __init__(
57
+ self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: ClsHead
58
+ ) -> None:
59
+ super().__init__()
60
+ self.backbone = backbone
61
+ self.head = head
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ feed_dict = self.backbone(x)
65
+ output = self.head(feed_dict)
66
+ return output
67
+
68
+
69
+ def efficientvit_cls_b0(**kwargs) -> EfficientViTCls:
70
+ from efficientvit.models.efficientvit.backbone import \
71
+ efficientvit_backbone_b0
72
+
73
+ backbone = efficientvit_backbone_b0(**kwargs)
74
+
75
+ head = ClsHead(
76
+ in_channels=128,
77
+ width_list=[1024, 1280],
78
+ **build_kwargs_from_config(kwargs, ClsHead),
79
+ )
80
+ model = EfficientViTCls(backbone, head)
81
+ return model
82
+
83
+
84
+ def efficientvit_cls_b1(**kwargs) -> EfficientViTCls:
85
+ from efficientvit.models.efficientvit.backbone import \
86
+ efficientvit_backbone_b1
87
+
88
+ backbone = efficientvit_backbone_b1(**kwargs)
89
+
90
+ head = ClsHead(
91
+ in_channels=256,
92
+ width_list=[1536, 1600],
93
+ **build_kwargs_from_config(kwargs, ClsHead),
94
+ )
95
+ model = EfficientViTCls(backbone, head)
96
+ return model
97
+
98
+
99
+ def efficientvit_cls_b2(**kwargs) -> EfficientViTCls:
100
+ from efficientvit.models.efficientvit.backbone import \
101
+ efficientvit_backbone_b2
102
+
103
+ backbone = efficientvit_backbone_b2(**kwargs)
104
+
105
+ head = ClsHead(
106
+ in_channels=384,
107
+ width_list=[2304, 2560],
108
+ **build_kwargs_from_config(kwargs, ClsHead),
109
+ )
110
+ model = EfficientViTCls(backbone, head)
111
+ return model
112
+
113
+
114
+ def efficientvit_cls_b3(**kwargs) -> EfficientViTCls:
115
+ from efficientvit.models.efficientvit.backbone import \
116
+ efficientvit_backbone_b3
117
+
118
+ backbone = efficientvit_backbone_b3(**kwargs)
119
+
120
+ head = ClsHead(
121
+ in_channels=512,
122
+ width_list=[2304, 2560],
123
+ **build_kwargs_from_config(kwargs, ClsHead),
124
+ )
125
+ model = EfficientViTCls(backbone, head)
126
+ return model
127
+
128
+
129
+ def efficientvit_cls_l1(**kwargs) -> EfficientViTCls:
130
+ from efficientvit.models.efficientvit.backbone import \
131
+ efficientvit_backbone_l1
132
+
133
+ backbone = efficientvit_backbone_l1(**kwargs)
134
+
135
+ head = ClsHead(
136
+ in_channels=512,
137
+ width_list=[3072, 3200],
138
+ act_func="gelu",
139
+ **build_kwargs_from_config(kwargs, ClsHead),
140
+ )
141
+ model = EfficientViTCls(backbone, head)
142
+ return model
143
+
144
+
145
+ def efficientvit_cls_l2(**kwargs) -> EfficientViTCls:
146
+ from efficientvit.models.efficientvit.backbone import \
147
+ efficientvit_backbone_l2
148
+
149
+ backbone = efficientvit_backbone_l2(**kwargs)
150
+
151
+ head = ClsHead(
152
+ in_channels=512,
153
+ width_list=[3072, 3200],
154
+ act_func="gelu",
155
+ **build_kwargs_from_config(kwargs, ClsHead),
156
+ )
157
+ model = EfficientViTCls(backbone, head)
158
+ return model
159
+
160
+
161
+ def efficientvit_cls_l3(**kwargs) -> EfficientViTCls:
162
+ from efficientvit.models.efficientvit.backbone import \
163
+ efficientvit_backbone_l3
164
+
165
+ backbone = efficientvit_backbone_l3(**kwargs)
166
+
167
+ head = ClsHead(
168
+ in_channels=1024,
169
+ width_list=[6144, 6400],
170
+ act_func="gelu",
171
+ **build_kwargs_from_config(kwargs, ClsHead),
172
+ )
173
+ model = EfficientViTCls(backbone, head)
174
+ return model
src/efficientvit/models/efficientvit/sam.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as transforms
12
+ from segment_anything import SamAutomaticMaskGenerator
13
+ from segment_anything.modeling import (MaskDecoder, PromptEncoder,
14
+ TwoWayTransformer)
15
+ from segment_anything.modeling.mask_decoder import MaskDecoder
16
+ from segment_anything.modeling.prompt_encoder import PromptEncoder
17
+ from segment_anything.utils.amg import build_all_layer_point_grids
18
+ from segment_anything.utils.transforms import ResizeLongestSide
19
+ from torchvision.transforms.functional import resize, to_pil_image
20
+
21
+ from src.efficientvit.models.efficientvit.backbone import (
22
+ EfficientViTBackbone, EfficientViTLargeBackbone)
23
+ from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
24
+ IdentityLayer, MBConv, OpSequential,
25
+ ResBlock, ResidualBlock, UpSampleLayer,
26
+ build_norm)
27
+ from src.efficientvit.models.utils import build_kwargs_from_config, get_device
28
+
29
+ __all__ = [
30
+ "SamPad",
31
+ "SamResize",
32
+ "SamNeck",
33
+ "EfficientViTSamImageEncoder",
34
+ "EfficientViTSam",
35
+ "EfficientViTSamPredictor",
36
+ "EfficientViTSamAutomaticMaskGenerator",
37
+ "efficientvit_sam_l0",
38
+ "efficientvit_sam_l1",
39
+ "efficientvit_sam_l2",
40
+ "efficientvit_sam_xl0",
41
+ "efficientvit_sam_xl1",
42
+ ]
43
+
44
+
45
+ class SamPad:
46
+ def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None:
47
+ self.size = size
48
+ self.fill = fill
49
+ self.pad_mode = pad_mode
50
+
51
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
52
+ h, w = image.shape[-2:]
53
+ th, tw = self.size, self.size
54
+ assert th >= h and tw >= w
55
+ if self.pad_mode == "corner":
56
+ image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill)
57
+ else:
58
+ raise NotImplementedError
59
+ return image
60
+
61
+ def __repr__(self) -> str:
62
+ return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})"
63
+
64
+
65
+ class SamResize:
66
+ def __init__(self, size: int) -> None:
67
+ self.size = size
68
+
69
+ def __call__(self, image: np.ndarray) -> np.ndarray:
70
+ h, w, _ = image.shape
71
+ long_side = max(h, w)
72
+ if long_side != self.size:
73
+ return self.apply_image(image)
74
+ else:
75
+ return image
76
+
77
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
78
+ """
79
+ Expects a numpy array with shape HxWxC in uint8 format.
80
+ """
81
+ target_size = self.get_preprocess_shape(
82
+ image.shape[0], image.shape[1], self.size
83
+ )
84
+ return np.array(resize(to_pil_image(image), target_size))
85
+
86
+ @staticmethod
87
+ def get_preprocess_shape(
88
+ oldh: int, oldw: int, long_side_length: int
89
+ ) -> tuple[int, int]:
90
+ """
91
+ Compute the output size given input size and target long side length.
92
+ """
93
+ scale = long_side_length * 1.0 / max(oldh, oldw)
94
+ newh, neww = oldh * scale, oldw * scale
95
+ neww = int(neww + 0.5)
96
+ newh = int(newh + 0.5)
97
+ return (newh, neww)
98
+
99
+ def __repr__(self) -> str:
100
+ return f"{type(self).__name__}(size={self.size})"
101
+
102
+
103
+ class SamNeck(DAGBlock):
104
+ def __init__(
105
+ self,
106
+ fid_list: list[str],
107
+ in_channel_list: list[int],
108
+ head_width: int,
109
+ head_depth: int,
110
+ expand_ratio: float,
111
+ middle_op: str,
112
+ out_dim: int = 256,
113
+ norm="bn2d",
114
+ act_func="gelu",
115
+ ):
116
+ inputs = {}
117
+ for fid, in_channel in zip(fid_list, in_channel_list):
118
+ inputs[fid] = OpSequential(
119
+ [
120
+ ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
121
+ UpSampleLayer(size=(64, 64)),
122
+ ]
123
+ )
124
+
125
+ middle = []
126
+ for _ in range(head_depth):
127
+ if middle_op == "mb":
128
+ block = MBConv(
129
+ head_width,
130
+ head_width,
131
+ expand_ratio=expand_ratio,
132
+ norm=norm,
133
+ act_func=(act_func, act_func, None),
134
+ )
135
+ elif middle_op == "fmb":
136
+ block = FusedMBConv(
137
+ head_width,
138
+ head_width,
139
+ expand_ratio=expand_ratio,
140
+ norm=norm,
141
+ act_func=(act_func, None),
142
+ )
143
+ elif middle_op == "res":
144
+ block = ResBlock(
145
+ head_width,
146
+ head_width,
147
+ expand_ratio=expand_ratio,
148
+ norm=norm,
149
+ act_func=(act_func, None),
150
+ )
151
+ else:
152
+ raise NotImplementedError
153
+ middle.append(ResidualBlock(block, IdentityLayer()))
154
+ middle = OpSequential(middle)
155
+
156
+ outputs = {
157
+ "sam_encoder": OpSequential(
158
+ [
159
+ ConvLayer(
160
+ head_width,
161
+ out_dim,
162
+ 1,
163
+ use_bias=True,
164
+ norm=None,
165
+ act_func=None,
166
+ ),
167
+ ]
168
+ )
169
+ }
170
+
171
+ super(SamNeck, self).__init__(
172
+ inputs, "add", None, middle=middle, outputs=outputs
173
+ )
174
+
175
+
176
+ class EfficientViTSamImageEncoder(nn.Module):
177
+ def __init__(
178
+ self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, neck: SamNeck
179
+ ):
180
+ super().__init__()
181
+ self.backbone = backbone
182
+ self.neck = neck
183
+
184
+ self.norm = build_norm("ln2d", 256)
185
+
186
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
187
+ feed_dict = self.backbone(x)
188
+ feed_dict = self.neck(feed_dict)
189
+
190
+ output = feed_dict["sam_encoder"]
191
+ output = self.norm(output)
192
+ return output
193
+
194
+
195
+ class EfficientViTSam(nn.Module):
196
+ mask_threshold: float = 0.0
197
+ image_format: str = "RGB"
198
+
199
+ def __init__(
200
+ self,
201
+ image_encoder: EfficientViTSamImageEncoder,
202
+ prompt_encoder: PromptEncoder,
203
+ mask_decoder: MaskDecoder,
204
+ image_size: tuple[int, int] = (1024, 512),
205
+ ) -> None:
206
+ super().__init__()
207
+ self.image_encoder = image_encoder
208
+ self.prompt_encoder = prompt_encoder
209
+ self.mask_decoder = mask_decoder
210
+
211
+ self.image_size = image_size
212
+
213
+ self.transform = transforms.Compose(
214
+ [
215
+ SamResize(self.image_size[1]),
216
+ transforms.ToTensor(),
217
+ transforms.Normalize(
218
+ mean=[123.675 / 255, 116.28 / 255, 103.53 / 255],
219
+ std=[58.395 / 255, 57.12 / 255, 57.375 / 255],
220
+ ),
221
+ SamPad(self.image_size[1]),
222
+ ]
223
+ )
224
+
225
+ def postprocess_masks(
226
+ self,
227
+ masks: torch.Tensor,
228
+ input_size: tuple[int, ...],
229
+ original_size: tuple[int, ...],
230
+ ) -> torch.Tensor:
231
+ masks = F.interpolate(
232
+ masks,
233
+ (self.image_size[0], self.image_size[0]),
234
+ mode="bilinear",
235
+ align_corners=False,
236
+ )
237
+ masks = masks[..., : input_size[0], : input_size[1]]
238
+ masks = F.interpolate(
239
+ masks, original_size, mode="bilinear", align_corners=False
240
+ )
241
+ return masks
242
+
243
+
244
+ class EfficientViTSamPredictor:
245
+ def __init__(self, sam_model: EfficientViTSam) -> None:
246
+ self.model = sam_model
247
+ self.reset_image()
248
+
249
+ @property
250
+ def transform(self):
251
+ return self
252
+
253
+ @property
254
+ def device(self):
255
+ return get_device(self.model)
256
+
257
+ def reset_image(self) -> None:
258
+ self.is_image_set = False
259
+ self.features = None
260
+ self.original_size = None
261
+ self.input_size = None
262
+
263
+ def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray:
264
+ old_h, old_w = self.original_size
265
+ new_h, new_w = self.input_size
266
+ coords = copy.deepcopy(coords).astype(float)
267
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
268
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
269
+ return coords
270
+
271
+ def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray:
272
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2))
273
+ return boxes.reshape(-1, 4)
274
+
275
+ @torch.inference_mode()
276
+ def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None:
277
+ assert image_format in [
278
+ "RGB",
279
+ "BGR",
280
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
281
+ if image_format != self.model.image_format:
282
+ image = image[..., ::-1]
283
+
284
+ self.reset_image()
285
+
286
+ self.original_size = image.shape[:2]
287
+ self.input_size = ResizeLongestSide.get_preprocess_shape(
288
+ *self.original_size, long_side_length=self.model.image_size[0]
289
+ )
290
+
291
+ torch_data = (
292
+ self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model))
293
+ )
294
+ self.features = self.model.image_encoder(torch_data)
295
+ self.is_image_set = True
296
+
297
+ def predict(
298
+ self,
299
+ point_coords: np.ndarray or None = None,
300
+ point_labels: np.ndarray or None = None,
301
+ box: np.ndarray or None = None,
302
+ mask_input: np.ndarray or None = None,
303
+ multimask_output: bool = True,
304
+ return_logits: bool = False,
305
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
306
+ """
307
+ Predict masks for the given input prompts, using the currently set image.
308
+
309
+ Arguments:
310
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
311
+ model. Each point is in (X,Y) in pixels.
312
+ point_labels (np.ndarray or None): A length N array of labels for the
313
+ point prompts. 1 indicates a foreground point and 0 indicates a
314
+ background point.
315
+ box (np.ndarray or None): A length 4 array given a box prompt to the
316
+ model, in XYXY format.
317
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
318
+ coming from a previous prediction iteration. Has form 1xHxW, where
319
+ for SAM, H=W=256.
320
+ multimask_output (bool): If true, the model will return three masks.
321
+ For ambiguous input prompts (such as a single click), this will often
322
+ produce better masks than a single prediction. If only a single
323
+ mask is needed, the model's predicted quality score can be used
324
+ to select the best mask. For non-ambiguous prompts, such as multiple
325
+ input prompts, multimask_output=False can give better results.
326
+ return_logits (bool): If true, returns un-thresholded masks logits
327
+ instead of a binary mask.
328
+
329
+ Returns:
330
+ (np.ndarray): The output masks in CxHxW format, where C is the
331
+ number of masks, and (H, W) is the original image size.
332
+ (np.ndarray): An array of length C containing the model's
333
+ predictions for the quality of each mask.
334
+ (np.ndarray): An array of shape CxHxW, where C is the number
335
+ of masks and H=W=256. These low resolution logits can be passed to
336
+ a subsequent iteration as mask input.
337
+ """
338
+ if not self.is_image_set:
339
+ raise RuntimeError(
340
+ "An image must be set with .set_image(...) before mask prediction."
341
+ )
342
+
343
+ device = get_device(self.model)
344
+ # Transform input prompts
345
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
346
+ if point_coords is not None:
347
+ assert (
348
+ point_labels is not None
349
+ ), "point_labels must be supplied if point_coords is supplied."
350
+ point_coords = self.apply_coords(point_coords)
351
+ coords_torch = torch.as_tensor(
352
+ point_coords, dtype=torch.float, device=device
353
+ )
354
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
355
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
356
+ if box is not None:
357
+ box = self.apply_boxes(box)
358
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
359
+ box_torch = box_torch[None, :]
360
+ if mask_input is not None:
361
+ mask_input_torch = torch.as_tensor(
362
+ mask_input, dtype=torch.float, device=device
363
+ )
364
+ mask_input_torch = mask_input_torch[None, :, :, :]
365
+
366
+ masks, iou_predictions, low_res_masks = self.predict_torch(
367
+ coords_torch,
368
+ labels_torch,
369
+ box_torch,
370
+ mask_input_torch,
371
+ multimask_output,
372
+ return_logits=return_logits,
373
+ )
374
+
375
+ masks = masks[0].detach().cpu().numpy()
376
+ iou_predictions = iou_predictions[0].detach().cpu().numpy()
377
+ low_res_masks = low_res_masks[0].detach().cpu().numpy()
378
+ return masks, iou_predictions, low_res_masks
379
+
380
+ @torch.inference_mode()
381
+ def predict_torch(
382
+ self,
383
+ point_coords: torch.Tensor or None = None,
384
+ point_labels: torch.Tensor or None = None,
385
+ boxes: torch.Tensor or None = None,
386
+ mask_input: torch.Tensor or None = None,
387
+ multimask_output: bool = True,
388
+ return_logits: bool = False,
389
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
390
+ """
391
+ Predict masks for the given input prompts, using the currently set image.
392
+ Input prompts are batched torch tensors and are expected to already be
393
+ transformed to the input frame using ResizeLongestSide.
394
+
395
+ Arguments:
396
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
397
+ model. Each point is in (X,Y) in pixels.
398
+ point_labels (torch.Tensor or None): A BxN array of labels for the
399
+ point prompts. 1 indicates a foreground point and 0 indicates a
400
+ background point.
401
+ box (np.ndarray or None): A Bx4 array given a box prompt to the
402
+ model, in XYXY format.
403
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
404
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
405
+ for SAM, H=W=256. Masks returned by a previous iteration of the
406
+ predict method do not need further transformation.
407
+ multimask_output (bool): If true, the model will return three masks.
408
+ For ambiguous input prompts (such as a single click), this will often
409
+ produce better masks than a single prediction. If only a single
410
+ mask is needed, the model's predicted quality score can be used
411
+ to select the best mask. For non-ambiguous prompts, such as multiple
412
+ input prompts, multimask_output=False can give better results.
413
+ return_logits (bool): If true, returns un-thresholded masks logits
414
+ instead of a binary mask.
415
+
416
+ Returns:
417
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
418
+ number of masks, and (H, W) is the original image size.
419
+ (torch.Tensor): An array of shape BxC containing the model's
420
+ predictions for the quality of each mask.
421
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
422
+ of masks and H=W=256. These low res logits can be passed to
423
+ a subsequent iteration as mask input.
424
+ """
425
+ if not self.is_image_set:
426
+ raise RuntimeError(
427
+ "An image must be set with .set_image(...) before mask prediction."
428
+ )
429
+
430
+ if point_coords is not None:
431
+ points = (point_coords, point_labels)
432
+ else:
433
+ points = None
434
+
435
+ # Embed prompts
436
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
437
+ points=points,
438
+ boxes=boxes,
439
+ masks=mask_input,
440
+ )
441
+
442
+ # Predict masks
443
+ low_res_masks, iou_predictions = self.model.mask_decoder(
444
+ image_embeddings=self.features,
445
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
446
+ sparse_prompt_embeddings=sparse_embeddings,
447
+ dense_prompt_embeddings=dense_embeddings,
448
+ multimask_output=multimask_output,
449
+ )
450
+
451
+ # Upscale the masks to the original image resolution
452
+ masks = self.model.postprocess_masks(
453
+ low_res_masks, self.input_size, self.original_size
454
+ )
455
+
456
+ if not return_logits:
457
+ masks = masks > self.model.mask_threshold
458
+
459
+ return masks, iou_predictions, low_res_masks
460
+
461
+
462
+ class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator):
463
+ def __init__(
464
+ self,
465
+ model: EfficientViTSam,
466
+ points_per_side: int or None = 32,
467
+ points_per_batch: int = 64,
468
+ pred_iou_thresh: float = 0.88,
469
+ stability_score_thresh: float = 0.95,
470
+ stability_score_offset: float = 1.0,
471
+ box_nms_thresh: float = 0.7,
472
+ crop_n_layers: int = 0,
473
+ crop_nms_thresh: float = 0.7,
474
+ crop_overlap_ratio: float = 512 / 1500,
475
+ crop_n_points_downscale_factor: int = 1,
476
+ point_grids: list[np.ndarray] or None = None,
477
+ min_mask_region_area: int = 0,
478
+ output_mode: str = "binary_mask",
479
+ ) -> None:
480
+ assert (points_per_side is None) != (
481
+ point_grids is None
482
+ ), "Exactly one of points_per_side or point_grid must be provided."
483
+ if points_per_side is not None:
484
+ self.point_grids = build_all_layer_point_grids(
485
+ points_per_side,
486
+ crop_n_layers,
487
+ crop_n_points_downscale_factor,
488
+ )
489
+ elif point_grids is not None:
490
+ self.point_grids = point_grids
491
+ else:
492
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
493
+
494
+ assert output_mode in [
495
+ "binary_mask",
496
+ "uncompressed_rle",
497
+ "coco_rle",
498
+ ], f"Unknown output_mode {output_mode}."
499
+ if output_mode == "coco_rle":
500
+ from pycocotools import \
501
+ mask as mask_utils # type: ignore # noqa: F401
502
+
503
+ if min_mask_region_area > 0:
504
+ import cv2 # type: ignore # noqa: F401
505
+
506
+ self.predictor = EfficientViTSamPredictor(model)
507
+ self.points_per_batch = points_per_batch
508
+ self.pred_iou_thresh = pred_iou_thresh
509
+ self.stability_score_thresh = stability_score_thresh
510
+ self.stability_score_offset = stability_score_offset
511
+ self.box_nms_thresh = box_nms_thresh
512
+ self.crop_n_layers = crop_n_layers
513
+ self.crop_nms_thresh = crop_nms_thresh
514
+ self.crop_overlap_ratio = crop_overlap_ratio
515
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
516
+ self.min_mask_region_area = min_mask_region_area
517
+ self.output_mode = output_mode
518
+
519
+
520
+ def build_efficientvit_sam(
521
+ image_encoder: EfficientViTSamImageEncoder, image_size: int
522
+ ) -> EfficientViTSam:
523
+ return EfficientViTSam(
524
+ image_encoder=image_encoder,
525
+ prompt_encoder=PromptEncoder(
526
+ embed_dim=256,
527
+ image_embedding_size=(64, 64),
528
+ input_image_size=(1024, 1024),
529
+ mask_in_chans=16,
530
+ ),
531
+ mask_decoder=MaskDecoder(
532
+ num_multimask_outputs=3,
533
+ transformer=TwoWayTransformer(
534
+ depth=2,
535
+ embedding_dim=256,
536
+ mlp_dim=2048,
537
+ num_heads=8,
538
+ ),
539
+ transformer_dim=256,
540
+ iou_head_depth=3,
541
+ iou_head_hidden_dim=256,
542
+ ),
543
+ image_size=(1024, image_size),
544
+ )
545
+
546
+
547
+ def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam:
548
+ from efficientvit.models.efficientvit.backbone import \
549
+ efficientvit_backbone_l0
550
+
551
+ backbone = efficientvit_backbone_l0(**kwargs)
552
+
553
+ neck = SamNeck(
554
+ fid_list=["stage4", "stage3", "stage2"],
555
+ in_channel_list=[512, 256, 128],
556
+ head_width=256,
557
+ head_depth=4,
558
+ expand_ratio=1,
559
+ middle_op="fmb",
560
+ )
561
+
562
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
563
+ return build_efficientvit_sam(image_encoder, image_size)
564
+
565
+
566
+ def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam:
567
+ from efficientvit.models.efficientvit.backbone import \
568
+ efficientvit_backbone_l1
569
+
570
+ backbone = efficientvit_backbone_l1(**kwargs)
571
+
572
+ neck = SamNeck(
573
+ fid_list=["stage4", "stage3", "stage2"],
574
+ in_channel_list=[512, 256, 128],
575
+ head_width=256,
576
+ head_depth=8,
577
+ expand_ratio=1,
578
+ middle_op="fmb",
579
+ )
580
+
581
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
582
+ return build_efficientvit_sam(image_encoder, image_size)
583
+
584
+
585
+ def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam:
586
+ from efficientvit.models.efficientvit.backbone import \
587
+ efficientvit_backbone_l2
588
+
589
+ backbone = efficientvit_backbone_l2(**kwargs)
590
+
591
+ neck = SamNeck(
592
+ fid_list=["stage4", "stage3", "stage2"],
593
+ in_channel_list=[512, 256, 128],
594
+ head_width=256,
595
+ head_depth=12,
596
+ expand_ratio=1,
597
+ middle_op="fmb",
598
+ )
599
+
600
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
601
+ return build_efficientvit_sam(image_encoder, image_size)
602
+
603
+
604
+ def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam:
605
+ from efficientvit.models.efficientvit.backbone import \
606
+ EfficientViTLargeBackbone
607
+
608
+ backbone = EfficientViTLargeBackbone(
609
+ width_list=[32, 64, 128, 256, 512, 1024],
610
+ depth_list=[0, 1, 1, 2, 3, 3],
611
+ block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
612
+ expand_list=[1, 4, 4, 4, 4, 6],
613
+ fewer_norm_list=[False, False, False, False, True, True],
614
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
615
+ )
616
+
617
+ neck = SamNeck(
618
+ fid_list=["stage5", "stage4", "stage3"],
619
+ in_channel_list=[1024, 512, 256],
620
+ head_width=256,
621
+ head_depth=6,
622
+ expand_ratio=4,
623
+ middle_op="fmb",
624
+ )
625
+
626
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
627
+ return build_efficientvit_sam(image_encoder, image_size)
628
+
629
+
630
+ def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam:
631
+ from src.efficientvit.models.efficientvit.backbone import \
632
+ EfficientViTLargeBackbone
633
+
634
+ backbone = EfficientViTLargeBackbone(
635
+ width_list=[32, 64, 128, 256, 512, 1024],
636
+ depth_list=[1, 2, 2, 4, 6, 6],
637
+ block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
638
+ expand_list=[1, 4, 4, 4, 4, 6],
639
+ fewer_norm_list=[False, False, False, False, True, True],
640
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
641
+ )
642
+
643
+ neck = SamNeck(
644
+ fid_list=["stage5", "stage4", "stage3"],
645
+ in_channel_list=[1024, 512, 256],
646
+ head_width=256,
647
+ head_depth=12,
648
+ expand_ratio=4,
649
+ middle_op="fmb",
650
+ )
651
+
652
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
653
+ return build_efficientvit_sam(image_encoder, image_size)
src/efficientvit/models/efficientvit/seg.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from src.efficientvit.models.efficientvit.backbone import (
9
+ EfficientViTBackbone, EfficientViTLargeBackbone)
10
+ from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
11
+ IdentityLayer, MBConv, OpSequential,
12
+ ResidualBlock, UpSampleLayer)
13
+ from src.efficientvit.models.utils import build_kwargs_from_config
14
+
15
+ __all__ = [
16
+ "EfficientViTSeg",
17
+ "efficientvit_seg_b0",
18
+ "efficientvit_seg_b1",
19
+ "efficientvit_seg_b2",
20
+ "efficientvit_seg_b3",
21
+ "efficientvit_seg_l1",
22
+ "efficientvit_seg_l2",
23
+ ]
24
+
25
+
26
+ class SegHead(DAGBlock):
27
+ def __init__(
28
+ self,
29
+ fid_list: list[str],
30
+ in_channel_list: list[int],
31
+ stride_list: list[int],
32
+ head_stride: int,
33
+ head_width: int,
34
+ head_depth: int,
35
+ expand_ratio: float,
36
+ middle_op: str,
37
+ final_expand: float or None,
38
+ n_classes: int,
39
+ dropout=0,
40
+ norm="bn2d",
41
+ act_func="hswish",
42
+ ):
43
+ inputs = {}
44
+ for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
45
+ factor = stride // head_stride
46
+ if factor == 1:
47
+ inputs[fid] = ConvLayer(
48
+ in_channel, head_width, 1, norm=norm, act_func=None
49
+ )
50
+ else:
51
+ inputs[fid] = OpSequential(
52
+ [
53
+ ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
54
+ UpSampleLayer(factor=factor),
55
+ ]
56
+ )
57
+
58
+ middle = []
59
+ for _ in range(head_depth):
60
+ if middle_op == "mbconv":
61
+ block = MBConv(
62
+ head_width,
63
+ head_width,
64
+ expand_ratio=expand_ratio,
65
+ norm=norm,
66
+ act_func=(act_func, act_func, None),
67
+ )
68
+ elif middle_op == "fmbconv":
69
+ block = FusedMBConv(
70
+ head_width,
71
+ head_width,
72
+ expand_ratio=expand_ratio,
73
+ norm=norm,
74
+ act_func=(act_func, None),
75
+ )
76
+ else:
77
+ raise NotImplementedError
78
+ middle.append(ResidualBlock(block, IdentityLayer()))
79
+ middle = OpSequential(middle)
80
+
81
+ outputs = {
82
+ "segout": OpSequential(
83
+ [
84
+ (
85
+ None
86
+ if final_expand is None
87
+ else ConvLayer(
88
+ head_width,
89
+ head_width * final_expand,
90
+ 1,
91
+ norm=norm,
92
+ act_func=act_func,
93
+ )
94
+ ),
95
+ ConvLayer(
96
+ head_width * (final_expand or 1),
97
+ n_classes,
98
+ 1,
99
+ use_bias=True,
100
+ dropout=dropout,
101
+ norm=None,
102
+ act_func=None,
103
+ ),
104
+ ]
105
+ )
106
+ }
107
+
108
+ super(SegHead, self).__init__(
109
+ inputs, "add", None, middle=middle, outputs=outputs
110
+ )
111
+
112
+
113
+ class EfficientViTSeg(nn.Module):
114
+ def __init__(
115
+ self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: SegHead
116
+ ) -> None:
117
+ super().__init__()
118
+ self.backbone = backbone
119
+ self.head = head
120
+
121
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
122
+ feed_dict = self.backbone(x)
123
+ feed_dict = self.head(feed_dict)
124
+
125
+ return feed_dict["segout"]
126
+
127
+
128
+ def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg:
129
+ from efficientvit.models.efficientvit.backbone import \
130
+ efficientvit_backbone_b0
131
+
132
+ backbone = efficientvit_backbone_b0(**kwargs)
133
+
134
+ if dataset == "cityscapes":
135
+ head = SegHead(
136
+ fid_list=["stage4", "stage3", "stage2"],
137
+ in_channel_list=[128, 64, 32],
138
+ stride_list=[32, 16, 8],
139
+ head_stride=8,
140
+ head_width=32,
141
+ head_depth=1,
142
+ expand_ratio=4,
143
+ middle_op="mbconv",
144
+ final_expand=4,
145
+ n_classes=19,
146
+ **build_kwargs_from_config(kwargs, SegHead),
147
+ )
148
+ else:
149
+ raise NotImplementedError
150
+ model = EfficientViTSeg(backbone, head)
151
+ return model
152
+
153
+
154
+ def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg:
155
+ from efficientvit.models.efficientvit.backbone import \
156
+ efficientvit_backbone_b1
157
+
158
+ backbone = efficientvit_backbone_b1(**kwargs)
159
+
160
+ if dataset == "cityscapes":
161
+ head = SegHead(
162
+ fid_list=["stage4", "stage3", "stage2"],
163
+ in_channel_list=[256, 128, 64],
164
+ stride_list=[32, 16, 8],
165
+ head_stride=8,
166
+ head_width=64,
167
+ head_depth=3,
168
+ expand_ratio=4,
169
+ middle_op="mbconv",
170
+ final_expand=4,
171
+ n_classes=19,
172
+ **build_kwargs_from_config(kwargs, SegHead),
173
+ )
174
+ elif dataset == "ade20k":
175
+ head = SegHead(
176
+ fid_list=["stage4", "stage3", "stage2"],
177
+ in_channel_list=[256, 128, 64],
178
+ stride_list=[32, 16, 8],
179
+ head_stride=8,
180
+ head_width=64,
181
+ head_depth=3,
182
+ expand_ratio=4,
183
+ middle_op="mbconv",
184
+ final_expand=None,
185
+ n_classes=150,
186
+ **build_kwargs_from_config(kwargs, SegHead),
187
+ )
188
+ else:
189
+ raise NotImplementedError
190
+ model = EfficientViTSeg(backbone, head)
191
+ return model
192
+
193
+
194
+ def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg:
195
+ from efficientvit.models.efficientvit.backbone import \
196
+ efficientvit_backbone_b2
197
+
198
+ backbone = efficientvit_backbone_b2(**kwargs)
199
+
200
+ if dataset == "cityscapes":
201
+ head = SegHead(
202
+ fid_list=["stage4", "stage3", "stage2"],
203
+ in_channel_list=[384, 192, 96],
204
+ stride_list=[32, 16, 8],
205
+ head_stride=8,
206
+ head_width=96,
207
+ head_depth=3,
208
+ expand_ratio=4,
209
+ middle_op="mbconv",
210
+ final_expand=4,
211
+ n_classes=19,
212
+ **build_kwargs_from_config(kwargs, SegHead),
213
+ )
214
+ elif dataset == "ade20k":
215
+ head = SegHead(
216
+ fid_list=["stage4", "stage3", "stage2"],
217
+ in_channel_list=[384, 192, 96],
218
+ stride_list=[32, 16, 8],
219
+ head_stride=8,
220
+ head_width=96,
221
+ head_depth=3,
222
+ expand_ratio=4,
223
+ middle_op="mbconv",
224
+ final_expand=None,
225
+ n_classes=150,
226
+ **build_kwargs_from_config(kwargs, SegHead),
227
+ )
228
+ else:
229
+ raise NotImplementedError
230
+ model = EfficientViTSeg(backbone, head)
231
+ return model
232
+
233
+
234
+ def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg:
235
+ from efficientvit.models.efficientvit.backbone import \
236
+ efficientvit_backbone_b3
237
+
238
+ backbone = efficientvit_backbone_b3(**kwargs)
239
+
240
+ if dataset == "cityscapes":
241
+ head = SegHead(
242
+ fid_list=["stage4", "stage3", "stage2"],
243
+ in_channel_list=[512, 256, 128],
244
+ stride_list=[32, 16, 8],
245
+ head_stride=8,
246
+ head_width=128,
247
+ head_depth=3,
248
+ expand_ratio=4,
249
+ middle_op="mbconv",
250
+ final_expand=4,
251
+ n_classes=19,
252
+ **build_kwargs_from_config(kwargs, SegHead),
253
+ )
254
+ elif dataset == "ade20k":
255
+ head = SegHead(
256
+ fid_list=["stage4", "stage3", "stage2"],
257
+ in_channel_list=[512, 256, 128],
258
+ stride_list=[32, 16, 8],
259
+ head_stride=8,
260
+ head_width=128,
261
+ head_depth=3,
262
+ expand_ratio=4,
263
+ middle_op="mbconv",
264
+ final_expand=None,
265
+ n_classes=150,
266
+ **build_kwargs_from_config(kwargs, SegHead),
267
+ )
268
+ else:
269
+ raise NotImplementedError
270
+ model = EfficientViTSeg(backbone, head)
271
+ return model
272
+
273
+
274
+ def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg:
275
+ from efficientvit.models.efficientvit.backbone import \
276
+ efficientvit_backbone_l1
277
+
278
+ backbone = efficientvit_backbone_l1(**kwargs)
279
+
280
+ if dataset == "cityscapes":
281
+ head = SegHead(
282
+ fid_list=["stage4", "stage3", "stage2"],
283
+ in_channel_list=[512, 256, 128],
284
+ stride_list=[32, 16, 8],
285
+ head_stride=8,
286
+ head_width=256,
287
+ head_depth=3,
288
+ expand_ratio=1,
289
+ middle_op="fmbconv",
290
+ final_expand=None,
291
+ n_classes=19,
292
+ act_func="gelu",
293
+ **build_kwargs_from_config(kwargs, SegHead),
294
+ )
295
+ elif dataset == "ade20k":
296
+ head = SegHead(
297
+ fid_list=["stage4", "stage3", "stage2"],
298
+ in_channel_list=[512, 256, 128],
299
+ stride_list=[32, 16, 8],
300
+ head_stride=8,
301
+ head_width=128,
302
+ head_depth=3,
303
+ expand_ratio=4,
304
+ middle_op="fmbconv",
305
+ final_expand=8,
306
+ n_classes=150,
307
+ act_func="gelu",
308
+ **build_kwargs_from_config(kwargs, SegHead),
309
+ )
310
+ else:
311
+ raise NotImplementedError
312
+ model = EfficientViTSeg(backbone, head)
313
+ return model
314
+
315
+
316
+ def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg:
317
+ from efficientvit.models.efficientvit.backbone import \
318
+ efficientvit_backbone_l2
319
+
320
+ backbone = efficientvit_backbone_l2(**kwargs)
321
+
322
+ if dataset == "cityscapes":
323
+ head = SegHead(
324
+ fid_list=["stage4", "stage3", "stage2"],
325
+ in_channel_list=[512, 256, 128],
326
+ stride_list=[32, 16, 8],
327
+ head_stride=8,
328
+ head_width=256,
329
+ head_depth=5,
330
+ expand_ratio=1,
331
+ middle_op="fmbconv",
332
+ final_expand=None,
333
+ n_classes=19,
334
+ act_func="gelu",
335
+ **build_kwargs_from_config(kwargs, SegHead),
336
+ )
337
+ elif dataset == "ade20k":
338
+ head = SegHead(
339
+ fid_list=["stage4", "stage3", "stage2"],
340
+ in_channel_list=[512, 256, 128],
341
+ stride_list=[32, 16, 8],
342
+ head_stride=8,
343
+ head_width=128,
344
+ head_depth=3,
345
+ expand_ratio=4,
346
+ middle_op="fmbconv",
347
+ final_expand=8,
348
+ n_classes=150,
349
+ act_func="gelu",
350
+ **build_kwargs_from_config(kwargs, SegHead),
351
+ )
352
+ else:
353
+ raise NotImplementedError
354
+ model = EfficientViTSeg(backbone, head)
355
+ return model
src/efficientvit/models/nn/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .act import *
6
+ from .drop import *
7
+ from .norm import *
8
+ from .ops import *
src/efficientvit/models/nn/act.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from functools import partial
6
+
7
+ import torch.nn as nn
8
+
9
+ from src.efficientvit.models.utils import build_kwargs_from_config
10
+
11
+ __all__ = ["build_act"]
12
+
13
+
14
+ # register activation function here
15
+ REGISTERED_ACT_DICT: dict[str, type] = {
16
+ "relu": nn.ReLU,
17
+ "relu6": nn.ReLU6,
18
+ "hswish": nn.Hardswish,
19
+ "silu": nn.SiLU,
20
+ "gelu": partial(nn.GELU, approximate="tanh"),
21
+ }
22
+
23
+
24
+ def build_act(name: str, **kwargs) -> nn.Module or None:
25
+ if name in REGISTERED_ACT_DICT:
26
+ act_cls = REGISTERED_ACT_DICT[name]
27
+ args = build_kwargs_from_config(kwargs, act_cls)
28
+ return act_cls(**args)
29
+ else:
30
+ return None
src/efficientvit/models/nn/drop.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from src.efficientvit.apps.trainer.run_config import Scheduler
10
+ from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
11
+ from src.efficientvit.models.utils import build_kwargs_from_config
12
+
13
+ __all__ = ["apply_drop_func"]
14
+
15
+
16
+ def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None:
17
+ if drop_config is None:
18
+ return
19
+
20
+ drop_lookup_table = {
21
+ "droppath": apply_droppath,
22
+ }
23
+
24
+ drop_func = drop_lookup_table[drop_config["name"]]
25
+ drop_kwargs = build_kwargs_from_config(drop_config, drop_func)
26
+
27
+ drop_func(network, **drop_kwargs)
28
+
29
+
30
+ def apply_droppath(
31
+ network: nn.Module,
32
+ drop_prob: float,
33
+ linear_decay=True,
34
+ scheduled=True,
35
+ skip=0,
36
+ ) -> None:
37
+ all_valid_blocks = []
38
+ for m in network.modules():
39
+ for name, sub_module in m.named_children():
40
+ if isinstance(sub_module, ResidualBlock) and isinstance(
41
+ sub_module.shortcut, IdentityLayer
42
+ ):
43
+ all_valid_blocks.append((m, name, sub_module))
44
+ all_valid_blocks = all_valid_blocks[skip:]
45
+ for i, (m, name, sub_module) in enumerate(all_valid_blocks):
46
+ prob = (
47
+ drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
48
+ )
49
+ new_module = DropPathResidualBlock(
50
+ sub_module.main,
51
+ sub_module.shortcut,
52
+ sub_module.post_act,
53
+ sub_module.pre_norm,
54
+ prob,
55
+ scheduled,
56
+ )
57
+ m._modules[name] = new_module
58
+
59
+
60
+ class DropPathResidualBlock(ResidualBlock):
61
+ def __init__(
62
+ self,
63
+ main: nn.Module,
64
+ shortcut: nn.Module or None,
65
+ post_act=None,
66
+ pre_norm: nn.Module or None = None,
67
+ ######################################
68
+ drop_prob: float = 0,
69
+ scheduled=True,
70
+ ):
71
+ super().__init__(main, shortcut, post_act, pre_norm)
72
+
73
+ self.drop_prob = drop_prob
74
+ self.scheduled = scheduled
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ if (
78
+ not self.training
79
+ or self.drop_prob == 0
80
+ or not isinstance(self.shortcut, IdentityLayer)
81
+ ):
82
+ return ResidualBlock.forward(self, x)
83
+ else:
84
+ drop_prob = self.drop_prob
85
+ if self.scheduled:
86
+ drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
87
+ keep_prob = 1 - drop_prob
88
+
89
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
90
+ random_tensor = keep_prob + torch.rand(
91
+ shape, dtype=x.dtype, device=x.device
92
+ )
93
+ random_tensor.floor_() # binarize
94
+
95
+ res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
96
+ if self.post_act:
97
+ res = self.post_act(res)
98
+ return res
src/efficientvit/models/nn/norm.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.modules.batchnorm import _BatchNorm
8
+
9
+ from src.efficientvit.models.utils import build_kwargs_from_config
10
+
11
+ __all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]
12
+
13
+
14
+ class LayerNorm2d(nn.LayerNorm):
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ out = x - torch.mean(x, dim=1, keepdim=True)
17
+ out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
18
+ if self.elementwise_affine:
19
+ out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
20
+ return out
21
+
22
+
23
+ # register normalization function here
24
+ REGISTERED_NORM_DICT: dict[str, type] = {
25
+ "bn2d": nn.BatchNorm2d,
26
+ "ln": nn.LayerNorm,
27
+ "ln2d": LayerNorm2d,
28
+ }
29
+
30
+
31
+ def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module or None:
32
+ if name in ["ln", "ln2d"]:
33
+ kwargs["normalized_shape"] = num_features
34
+ else:
35
+ kwargs["num_features"] = num_features
36
+ if name in REGISTERED_NORM_DICT:
37
+ norm_cls = REGISTERED_NORM_DICT[name]
38
+ args = build_kwargs_from_config(kwargs, norm_cls)
39
+ return norm_cls(**args)
40
+ else:
41
+ return None
42
+
43
+
44
+ def reset_bn(
45
+ model: nn.Module,
46
+ data_loader: list,
47
+ sync=True,
48
+ progress_bar=False,
49
+ ) -> None:
50
+ import copy
51
+
52
+ import torch.nn.functional as F
53
+ from tqdm import tqdm
54
+
55
+ from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
56
+ from efficientvit.models.utils import get_device, list_join
57
+
58
+ bn_mean = {}
59
+ bn_var = {}
60
+
61
+ tmp_model = copy.deepcopy(model)
62
+ for name, m in tmp_model.named_modules():
63
+ if isinstance(m, _BatchNorm):
64
+ bn_mean[name] = AverageMeter(is_distributed=False)
65
+ bn_var[name] = AverageMeter(is_distributed=False)
66
+
67
+ def new_forward(bn, mean_est, var_est):
68
+ def lambda_forward(x):
69
+ x = x.contiguous()
70
+ if sync:
71
+ batch_mean = (
72
+ x.mean(0, keepdim=True)
73
+ .mean(2, keepdim=True)
74
+ .mean(3, keepdim=True)
75
+ ) # 1, C, 1, 1
76
+ batch_mean = sync_tensor(batch_mean, reduce="cat")
77
+ batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)
78
+
79
+ batch_var = (x - batch_mean) * (x - batch_mean)
80
+ batch_var = (
81
+ batch_var.mean(0, keepdim=True)
82
+ .mean(2, keepdim=True)
83
+ .mean(3, keepdim=True)
84
+ )
85
+ batch_var = sync_tensor(batch_var, reduce="cat")
86
+ batch_var = torch.mean(batch_var, dim=0, keepdim=True)
87
+ else:
88
+ batch_mean = (
89
+ x.mean(0, keepdim=True)
90
+ .mean(2, keepdim=True)
91
+ .mean(3, keepdim=True)
92
+ ) # 1, C, 1, 1
93
+ batch_var = (x - batch_mean) * (x - batch_mean)
94
+ batch_var = (
95
+ batch_var.mean(0, keepdim=True)
96
+ .mean(2, keepdim=True)
97
+ .mean(3, keepdim=True)
98
+ )
99
+
100
+ batch_mean = torch.squeeze(batch_mean)
101
+ batch_var = torch.squeeze(batch_var)
102
+
103
+ mean_est.update(batch_mean.data, x.size(0))
104
+ var_est.update(batch_var.data, x.size(0))
105
+
106
+ # bn forward using calculated mean & var
107
+ _feature_dim = batch_mean.shape[0]
108
+ return F.batch_norm(
109
+ x,
110
+ batch_mean,
111
+ batch_var,
112
+ bn.weight[:_feature_dim],
113
+ bn.bias[:_feature_dim],
114
+ False,
115
+ 0.0,
116
+ bn.eps,
117
+ )
118
+
119
+ return lambda_forward
120
+
121
+ m.forward = new_forward(m, bn_mean[name], bn_var[name])
122
+
123
+ # skip if there is no batch normalization layers in the network
124
+ if len(bn_mean) == 0:
125
+ return
126
+
127
+ tmp_model.eval()
128
+ with torch.no_grad():
129
+ with tqdm(
130
+ total=len(data_loader),
131
+ desc="reset bn",
132
+ disable=not progress_bar or not is_master(),
133
+ ) as t:
134
+ for images in data_loader:
135
+ images = images.to(get_device(tmp_model))
136
+ tmp_model(images)
137
+ t.set_postfix(
138
+ {
139
+ "bs": images.size(0),
140
+ "res": list_join(images.shape[-2:], "x"),
141
+ }
142
+ )
143
+ t.update()
144
+
145
+ for name, m in model.named_modules():
146
+ if name in bn_mean and bn_mean[name].count > 0:
147
+ feature_dim = bn_mean[name].avg.size(0)
148
+ assert isinstance(m, _BatchNorm)
149
+ m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
150
+ m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
151
+
152
+
153
+ def set_norm_eps(model: nn.Module, eps: float or None = None) -> None:
154
+ for m in model.modules():
155
+ if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
156
+ if eps is not None:
157
+ m.eps = eps
src/efficientvit/models/nn/ops.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.cuda.amp import autocast
9
+
10
+ from src.efficientvit.models.nn.act import build_act
11
+ from src.efficientvit.models.nn.norm import build_norm
12
+ from src.efficientvit.models.utils import (get_same_padding, list_sum, resize,
13
+ val2list, val2tuple)
14
+
15
+ __all__ = [
16
+ "ConvLayer",
17
+ "UpSampleLayer",
18
+ "LinearLayer",
19
+ "IdentityLayer",
20
+ "DSConv",
21
+ "MBConv",
22
+ "FusedMBConv",
23
+ "ResBlock",
24
+ "LiteMLA",
25
+ "EfficientViTBlock",
26
+ "ResidualBlock",
27
+ "DAGBlock",
28
+ "OpSequential",
29
+ ]
30
+
31
+
32
+ #################################################################################
33
+ # Basic Layers #
34
+ #################################################################################
35
+
36
+
37
+ class ConvLayer(nn.Module):
38
+ def __init__(
39
+ self,
40
+ in_channels: int,
41
+ out_channels: int,
42
+ kernel_size=3,
43
+ stride=1,
44
+ dilation=1,
45
+ groups=1,
46
+ use_bias=False,
47
+ dropout=0,
48
+ norm="bn2d",
49
+ act_func="relu",
50
+ ):
51
+ super(ConvLayer, self).__init__()
52
+
53
+ padding = get_same_padding(kernel_size)
54
+ padding *= dilation
55
+
56
+ self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
57
+ self.conv = nn.Conv2d(
58
+ in_channels,
59
+ out_channels,
60
+ kernel_size=(kernel_size, kernel_size),
61
+ stride=(stride, stride),
62
+ padding=padding,
63
+ dilation=(dilation, dilation),
64
+ groups=groups,
65
+ bias=use_bias,
66
+ )
67
+ self.norm = build_norm(norm, num_features=out_channels)
68
+ self.act = build_act(act_func)
69
+
70
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
71
+ if self.dropout is not None:
72
+ x = self.dropout(x)
73
+ x = self.conv(x)
74
+ if self.norm:
75
+ x = self.norm(x)
76
+ if self.act:
77
+ x = self.act(x)
78
+ return x
79
+
80
+
81
+ class UpSampleLayer(nn.Module):
82
+ def __init__(
83
+ self,
84
+ mode="bicubic",
85
+ size: int or tuple[int, int] or list[int] or None = None,
86
+ factor=2,
87
+ align_corners=False,
88
+ ):
89
+ super(UpSampleLayer, self).__init__()
90
+ self.mode = mode
91
+ self.size = val2list(size, 2) if size is not None else None
92
+ self.factor = None if self.size is not None else factor
93
+ self.align_corners = align_corners
94
+
95
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
96
+ if (
97
+ self.size is not None and tuple(x.shape[-2:]) == self.size
98
+ ) or self.factor == 1:
99
+ return x
100
+ return resize(x, self.size, self.factor, self.mode, self.align_corners)
101
+
102
+
103
+ class LinearLayer(nn.Module):
104
+ def __init__(
105
+ self,
106
+ in_features: int,
107
+ out_features: int,
108
+ use_bias=True,
109
+ dropout=0,
110
+ norm=None,
111
+ act_func=None,
112
+ ):
113
+ super(LinearLayer, self).__init__()
114
+
115
+ self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
116
+ self.linear = nn.Linear(in_features, out_features, use_bias)
117
+ self.norm = build_norm(norm, num_features=out_features)
118
+ self.act = build_act(act_func)
119
+
120
+ def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
121
+ if x.dim() > 2:
122
+ x = torch.flatten(x, start_dim=1)
123
+ return x
124
+
125
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
126
+ x = self._try_squeeze(x)
127
+ if self.dropout:
128
+ x = self.dropout(x)
129
+ x = self.linear(x)
130
+ if self.norm:
131
+ x = self.norm(x)
132
+ if self.act:
133
+ x = self.act(x)
134
+ return x
135
+
136
+
137
+ class IdentityLayer(nn.Module):
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ return x
140
+
141
+
142
+ #################################################################################
143
+ # Basic Blocks #
144
+ #################################################################################
145
+
146
+
147
+ class DSConv(nn.Module):
148
+ def __init__(
149
+ self,
150
+ in_channels: int,
151
+ out_channels: int,
152
+ kernel_size=3,
153
+ stride=1,
154
+ use_bias=False,
155
+ norm=("bn2d", "bn2d"),
156
+ act_func=("relu6", None),
157
+ ):
158
+ super(DSConv, self).__init__()
159
+
160
+ use_bias = val2tuple(use_bias, 2)
161
+ norm = val2tuple(norm, 2)
162
+ act_func = val2tuple(act_func, 2)
163
+
164
+ self.depth_conv = ConvLayer(
165
+ in_channels,
166
+ in_channels,
167
+ kernel_size,
168
+ stride,
169
+ groups=in_channels,
170
+ norm=norm[0],
171
+ act_func=act_func[0],
172
+ use_bias=use_bias[0],
173
+ )
174
+ self.point_conv = ConvLayer(
175
+ in_channels,
176
+ out_channels,
177
+ 1,
178
+ norm=norm[1],
179
+ act_func=act_func[1],
180
+ use_bias=use_bias[1],
181
+ )
182
+
183
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
184
+ x = self.depth_conv(x)
185
+ x = self.point_conv(x)
186
+ return x
187
+
188
+
189
+ class MBConv(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels: int,
193
+ out_channels: int,
194
+ kernel_size=3,
195
+ stride=1,
196
+ mid_channels=None,
197
+ expand_ratio=6,
198
+ use_bias=False,
199
+ norm=("bn2d", "bn2d", "bn2d"),
200
+ act_func=("relu6", "relu6", None),
201
+ ):
202
+ super(MBConv, self).__init__()
203
+
204
+ use_bias = val2tuple(use_bias, 3)
205
+ norm = val2tuple(norm, 3)
206
+ act_func = val2tuple(act_func, 3)
207
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
208
+
209
+ self.inverted_conv = ConvLayer(
210
+ in_channels,
211
+ mid_channels,
212
+ 1,
213
+ stride=1,
214
+ norm=norm[0],
215
+ act_func=act_func[0],
216
+ use_bias=use_bias[0],
217
+ )
218
+ self.depth_conv = ConvLayer(
219
+ mid_channels,
220
+ mid_channels,
221
+ kernel_size,
222
+ stride=stride,
223
+ groups=mid_channels,
224
+ norm=norm[1],
225
+ act_func=act_func[1],
226
+ use_bias=use_bias[1],
227
+ )
228
+ self.point_conv = ConvLayer(
229
+ mid_channels,
230
+ out_channels,
231
+ 1,
232
+ norm=norm[2],
233
+ act_func=act_func[2],
234
+ use_bias=use_bias[2],
235
+ )
236
+
237
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
238
+ x = self.inverted_conv(x)
239
+ x = self.depth_conv(x)
240
+ x = self.point_conv(x)
241
+ return x
242
+
243
+
244
+ class FusedMBConv(nn.Module):
245
+ def __init__(
246
+ self,
247
+ in_channels: int,
248
+ out_channels: int,
249
+ kernel_size=3,
250
+ stride=1,
251
+ mid_channels=None,
252
+ expand_ratio=6,
253
+ groups=1,
254
+ use_bias=False,
255
+ norm=("bn2d", "bn2d"),
256
+ act_func=("relu6", None),
257
+ ):
258
+ super().__init__()
259
+ use_bias = val2tuple(use_bias, 2)
260
+ norm = val2tuple(norm, 2)
261
+ act_func = val2tuple(act_func, 2)
262
+
263
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
264
+
265
+ self.spatial_conv = ConvLayer(
266
+ in_channels,
267
+ mid_channels,
268
+ kernel_size,
269
+ stride,
270
+ groups=groups,
271
+ use_bias=use_bias[0],
272
+ norm=norm[0],
273
+ act_func=act_func[0],
274
+ )
275
+ self.point_conv = ConvLayer(
276
+ mid_channels,
277
+ out_channels,
278
+ 1,
279
+ use_bias=use_bias[1],
280
+ norm=norm[1],
281
+ act_func=act_func[1],
282
+ )
283
+
284
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
285
+ x = self.spatial_conv(x)
286
+ x = self.point_conv(x)
287
+ return x
288
+
289
+
290
+ class ResBlock(nn.Module):
291
+ def __init__(
292
+ self,
293
+ in_channels: int,
294
+ out_channels: int,
295
+ kernel_size=3,
296
+ stride=1,
297
+ mid_channels=None,
298
+ expand_ratio=1,
299
+ use_bias=False,
300
+ norm=("bn2d", "bn2d"),
301
+ act_func=("relu6", None),
302
+ ):
303
+ super().__init__()
304
+ use_bias = val2tuple(use_bias, 2)
305
+ norm = val2tuple(norm, 2)
306
+ act_func = val2tuple(act_func, 2)
307
+
308
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
309
+
310
+ self.conv1 = ConvLayer(
311
+ in_channels,
312
+ mid_channels,
313
+ kernel_size,
314
+ stride,
315
+ use_bias=use_bias[0],
316
+ norm=norm[0],
317
+ act_func=act_func[0],
318
+ )
319
+ self.conv2 = ConvLayer(
320
+ mid_channels,
321
+ out_channels,
322
+ kernel_size,
323
+ 1,
324
+ use_bias=use_bias[1],
325
+ norm=norm[1],
326
+ act_func=act_func[1],
327
+ )
328
+
329
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
330
+ x = self.conv1(x)
331
+ x = self.conv2(x)
332
+ return x
333
+
334
+
335
+ class LiteMLA(nn.Module):
336
+ r"""Lightweight multi-scale linear attention"""
337
+
338
+ def __init__(
339
+ self,
340
+ in_channels: int,
341
+ out_channels: int,
342
+ heads: int or None = None,
343
+ heads_ratio: float = 1.0,
344
+ dim=8,
345
+ use_bias=False,
346
+ norm=(None, "bn2d"),
347
+ act_func=(None, None),
348
+ kernel_func="relu",
349
+ scales: tuple[int, ...] = (5,),
350
+ eps=1.0e-15,
351
+ ):
352
+ super(LiteMLA, self).__init__()
353
+ self.eps = eps
354
+ heads = heads or int(in_channels // dim * heads_ratio)
355
+
356
+ total_dim = heads * dim
357
+
358
+ use_bias = val2tuple(use_bias, 2)
359
+ norm = val2tuple(norm, 2)
360
+ act_func = val2tuple(act_func, 2)
361
+
362
+ self.dim = dim
363
+ self.qkv = ConvLayer(
364
+ in_channels,
365
+ 3 * total_dim,
366
+ 1,
367
+ use_bias=use_bias[0],
368
+ norm=norm[0],
369
+ act_func=act_func[0],
370
+ )
371
+ self.aggreg = nn.ModuleList(
372
+ [
373
+ nn.Sequential(
374
+ nn.Conv2d(
375
+ 3 * total_dim,
376
+ 3 * total_dim,
377
+ scale,
378
+ padding=get_same_padding(scale),
379
+ groups=3 * total_dim,
380
+ bias=use_bias[0],
381
+ ),
382
+ nn.Conv2d(
383
+ 3 * total_dim,
384
+ 3 * total_dim,
385
+ 1,
386
+ groups=3 * heads,
387
+ bias=use_bias[0],
388
+ ),
389
+ )
390
+ for scale in scales
391
+ ]
392
+ )
393
+ self.kernel_func = build_act(kernel_func, inplace=False)
394
+
395
+ self.proj = ConvLayer(
396
+ total_dim * (1 + len(scales)),
397
+ out_channels,
398
+ 1,
399
+ use_bias=use_bias[1],
400
+ norm=norm[1],
401
+ act_func=act_func[1],
402
+ )
403
+
404
+ @autocast(enabled=False)
405
+ def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
406
+ B, _, H, W = list(qkv.size())
407
+
408
+ if qkv.dtype == torch.float16:
409
+ qkv = qkv.float()
410
+
411
+ qkv = torch.reshape(
412
+ qkv,
413
+ (
414
+ B,
415
+ -1,
416
+ 3 * self.dim,
417
+ H * W,
418
+ ),
419
+ )
420
+ qkv = torch.transpose(qkv, -1, -2)
421
+ q, k, v = (
422
+ qkv[..., 0 : self.dim],
423
+ qkv[..., self.dim : 2 * self.dim],
424
+ qkv[..., 2 * self.dim :],
425
+ )
426
+
427
+ # lightweight linear attention
428
+ q = self.kernel_func(q)
429
+ k = self.kernel_func(k)
430
+
431
+ # linear matmul
432
+ trans_k = k.transpose(-1, -2)
433
+
434
+ v = F.pad(v, (0, 1), mode="constant", value=1)
435
+ kv = torch.matmul(trans_k, v)
436
+ out = torch.matmul(q, kv)
437
+ out = out[..., :-1] / (out[..., -1:] + self.eps)
438
+
439
+ out = torch.transpose(out, -1, -2)
440
+ out = torch.reshape(out, (B, -1, H, W))
441
+ return out
442
+
443
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
444
+ # generate multi-scale q, k, v
445
+ qkv = self.qkv(x)
446
+ multi_scale_qkv = [qkv]
447
+ for op in self.aggreg:
448
+ multi_scale_qkv.append(op(qkv))
449
+ multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
450
+
451
+ out = self.relu_linear_att(multi_scale_qkv)
452
+ out = self.proj(out)
453
+
454
+ return out
455
+
456
+
457
+ class EfficientViTBlock(nn.Module):
458
+ def __init__(
459
+ self,
460
+ in_channels: int,
461
+ heads_ratio: float = 1.0,
462
+ dim=32,
463
+ expand_ratio: float = 4,
464
+ scales=(5,),
465
+ norm="bn2d",
466
+ act_func="hswish",
467
+ ):
468
+ super(EfficientViTBlock, self).__init__()
469
+ self.context_module = ResidualBlock(
470
+ LiteMLA(
471
+ in_channels=in_channels,
472
+ out_channels=in_channels,
473
+ heads_ratio=heads_ratio,
474
+ dim=dim,
475
+ norm=(None, norm),
476
+ scales=scales,
477
+ ),
478
+ IdentityLayer(),
479
+ )
480
+ local_module = MBConv(
481
+ in_channels=in_channels,
482
+ out_channels=in_channels,
483
+ expand_ratio=expand_ratio,
484
+ use_bias=(True, True, False),
485
+ norm=(None, None, norm),
486
+ act_func=(act_func, act_func, None),
487
+ )
488
+ self.local_module = ResidualBlock(local_module, IdentityLayer())
489
+
490
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
491
+ x = self.context_module(x)
492
+ x = self.local_module(x)
493
+ return x
494
+
495
+
496
+ #################################################################################
497
+ # Functional Blocks #
498
+ #################################################################################
499
+
500
+
501
+ class ResidualBlock(nn.Module):
502
+ def __init__(
503
+ self,
504
+ main: nn.Module or None,
505
+ shortcut: nn.Module or None,
506
+ post_act=None,
507
+ pre_norm: nn.Module or None = None,
508
+ ):
509
+ super(ResidualBlock, self).__init__()
510
+
511
+ self.pre_norm = pre_norm
512
+ self.main = main
513
+ self.shortcut = shortcut
514
+ self.post_act = build_act(post_act)
515
+
516
+ def forward_main(self, x: torch.Tensor) -> torch.Tensor:
517
+ if self.pre_norm is None:
518
+ return self.main(x)
519
+ else:
520
+ return self.main(self.pre_norm(x))
521
+
522
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
523
+ if self.main is None:
524
+ res = x
525
+ elif self.shortcut is None:
526
+ res = self.forward_main(x)
527
+ else:
528
+ res = self.forward_main(x) + self.shortcut(x)
529
+ if self.post_act:
530
+ res = self.post_act(res)
531
+ return res
532
+
533
+
534
+ class DAGBlock(nn.Module):
535
+ def __init__(
536
+ self,
537
+ inputs: dict[str, nn.Module],
538
+ merge: str,
539
+ post_input: nn.Module or None,
540
+ middle: nn.Module,
541
+ outputs: dict[str, nn.Module],
542
+ ):
543
+ super(DAGBlock, self).__init__()
544
+
545
+ self.input_keys = list(inputs.keys())
546
+ self.input_ops = nn.ModuleList(list(inputs.values()))
547
+ self.merge = merge
548
+ self.post_input = post_input
549
+
550
+ self.middle = middle
551
+
552
+ self.output_keys = list(outputs.keys())
553
+ self.output_ops = nn.ModuleList(list(outputs.values()))
554
+
555
+ def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
556
+ feat = [
557
+ op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
558
+ ]
559
+ if self.merge == "add":
560
+ feat = list_sum(feat)
561
+ elif self.merge == "cat":
562
+ feat = torch.concat(feat, dim=1)
563
+ else:
564
+ raise NotImplementedError
565
+ if self.post_input is not None:
566
+ feat = self.post_input(feat)
567
+ feat = self.middle(feat)
568
+ for key, op in zip(self.output_keys, self.output_ops):
569
+ feature_dict[key] = op(feat)
570
+ return feature_dict
571
+
572
+
573
+ class OpSequential(nn.Module):
574
+ def __init__(self, op_list: list[nn.Module or None]):
575
+ super(OpSequential, self).__init__()
576
+ valid_op_list = []
577
+ for op in op_list:
578
+ if op is not None:
579
+ valid_op_list.append(op)
580
+ self.op_list = nn.ModuleList(valid_op_list)
581
+
582
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
583
+ for op in self.op_list:
584
+ x = op(x)
585
+ return x
src/efficientvit/models/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .list import *
6
+ from .network import *
7
+ from .random import *
src/efficientvit/models/utils/list.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ __all__ = [
6
+ "list_sum",
7
+ "list_mean",
8
+ "weighted_list_sum",
9
+ "list_join",
10
+ "val2list",
11
+ "val2tuple",
12
+ "squeeze_list",
13
+ ]
14
+
15
+
16
+ def list_sum(x: list) -> any:
17
+ return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
18
+
19
+
20
+ def list_mean(x: list) -> any:
21
+ return list_sum(x) / len(x)
22
+
23
+
24
+ def weighted_list_sum(x: list, weights: list) -> any:
25
+ assert len(x) == len(weights)
26
+ return (
27
+ x[0] * weights[0]
28
+ if len(x) == 1
29
+ else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
30
+ )
31
+
32
+
33
+ def list_join(x: list, sep="\t", format_str="%s") -> str:
34
+ return sep.join([format_str % val for val in x])
35
+
36
+
37
+ def val2list(x: list or tuple or any, repeat_time=1) -> list:
38
+ if isinstance(x, (list, tuple)):
39
+ return list(x)
40
+ return [x for _ in range(repeat_time)]
41
+
42
+
43
+ def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
44
+ x = val2list(x)
45
+
46
+ # repeat elements if necessary
47
+ if len(x) > 0:
48
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
49
+
50
+ return tuple(x)
51
+
52
+
53
+ def squeeze_list(x: list or None) -> list or any:
54
+ if x is not None and len(x) == 1:
55
+ return x[0]
56
+ else:
57
+ return x
src/efficientvit/models/utils/network.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+ from inspect import signature
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ __all__ = [
13
+ "is_parallel",
14
+ "get_device",
15
+ "get_same_padding",
16
+ "resize",
17
+ "build_kwargs_from_config",
18
+ "load_state_dict_from_file",
19
+ ]
20
+
21
+
22
+ def is_parallel(model: nn.Module) -> bool:
23
+ return isinstance(
24
+ model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
25
+ )
26
+
27
+
28
+ def get_device(model: nn.Module) -> torch.device:
29
+ return model.parameters().__next__().device
30
+
31
+
32
+ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
33
+ if isinstance(kernel_size, tuple):
34
+ return tuple([get_same_padding(ks) for ks in kernel_size])
35
+ else:
36
+ assert kernel_size % 2 > 0, "kernel size should be odd number"
37
+ return kernel_size // 2
38
+
39
+
40
+ def resize(
41
+ x: torch.Tensor,
42
+ size: any or None = None,
43
+ scale_factor: list[float] or None = None,
44
+ mode: str = "bicubic",
45
+ align_corners: bool or None = False,
46
+ ) -> torch.Tensor:
47
+ if mode in {"bilinear", "bicubic"}:
48
+ return F.interpolate(
49
+ x,
50
+ size=size,
51
+ scale_factor=scale_factor,
52
+ mode=mode,
53
+ align_corners=align_corners,
54
+ )
55
+ elif mode in {"nearest", "area"}:
56
+ return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
57
+ else:
58
+ raise NotImplementedError(f"resize(mode={mode}) not implemented.")
59
+
60
+
61
+ def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]:
62
+ valid_keys = list(signature(target_func).parameters)
63
+ kwargs = {}
64
+ for key in config:
65
+ if key in valid_keys:
66
+ kwargs[key] = config[key]
67
+ return kwargs
68
+
69
+
70
+ def load_state_dict_from_file(
71
+ file: str, only_state_dict=True
72
+ ) -> dict[str, torch.Tensor]:
73
+ file = os.path.realpath(os.path.expanduser(file))
74
+ checkpoint = torch.load(file, map_location="cpu")
75
+ if only_state_dict and "state_dict" in checkpoint:
76
+ checkpoint = checkpoint["state_dict"]
77
+ return checkpoint
src/efficientvit/models/utils/random.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ __all__ = [
9
+ "torch_randint",
10
+ "torch_random",
11
+ "torch_shuffle",
12
+ "torch_uniform",
13
+ "torch_random_choices",
14
+ ]
15
+
16
+
17
+ def torch_randint(
18
+ low: int, high: int, generator: torch.Generator or None = None
19
+ ) -> int:
20
+ """uniform: [low, high)"""
21
+ if low == high:
22
+ return low
23
+ else:
24
+ assert low < high
25
+ return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
26
+
27
+
28
+ def torch_random(generator: torch.Generator or None = None) -> float:
29
+ """uniform distribution on the interval [0, 1)"""
30
+ return float(torch.rand(1, generator=generator))
31
+
32
+
33
+ def torch_shuffle(
34
+ src_list: list[any], generator: torch.Generator or None = None
35
+ ) -> list[any]:
36
+ rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
37
+ return [src_list[i] for i in rand_indexes]
38
+
39
+
40
+ def torch_uniform(
41
+ low: float, high: float, generator: torch.Generator or None = None
42
+ ) -> float:
43
+ """uniform distribution on the interval [low, high)"""
44
+ rand_val = torch_random(generator)
45
+ return (high - low) * rand_val + low
46
+
47
+
48
+ def torch_random_choices(
49
+ src_list: list[any],
50
+ generator: torch.Generator or None = None,
51
+ k=1,
52
+ weight_list: list[float] or None = None,
53
+ ) -> any or list:
54
+ if weight_list is None:
55
+ rand_idx = torch.randint(
56
+ low=0, high=len(src_list), generator=generator, size=(k,)
57
+ )
58
+ out_list = [src_list[i] for i in rand_idx]
59
+ else:
60
+ assert len(weight_list) == len(src_list)
61
+ accumulate_weight_list = np.cumsum(weight_list)
62
+
63
+ out_list = []
64
+ for _ in range(k):
65
+ val = torch_uniform(0, accumulate_weight_list[-1], generator)
66
+ active_id = 0
67
+ for i, weight_val in enumerate(accumulate_weight_list):
68
+ active_id = i
69
+ if weight_val > val:
70
+ break
71
+ out_list.append(src_list[active_id])
72
+
73
+ return out_list[0] if k == 1 else out_list
src/efficientvit/sam_model_zoo.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from src.efficientvit.models.efficientvit import (EfficientViTSam,
6
+ efficientvit_sam_l0,
7
+ efficientvit_sam_l1,
8
+ efficientvit_sam_l2,
9
+ efficientvit_sam_xl0,
10
+ efficientvit_sam_xl1)
11
+ from src.efficientvit.models.nn.norm import set_norm_eps
12
+ from src.efficientvit.models.utils import load_state_dict_from_file
13
+
14
+ __all__ = ["create_sam_model"]
15
+
16
+
17
+ REGISTERED_SAM_MODEL: dict[str, str] = {
18
+ "l0": "assets/checkpoints/sam/l0.pt",
19
+ "l1": "assets/checkpoints/sam/l1.pt",
20
+ "l2": "assets/checkpoints/sam/l2.pt",
21
+ "xl0": "assets/checkpoints/sam/xl0.pt",
22
+ "xl1": "assets/checkpoints/sam/xl1.pt",
23
+ }
24
+
25
+
26
+ def create_sam_model(
27
+ name: str, pretrained=True, weight_url: str or None = None, **kwargs
28
+ ) -> EfficientViTSam:
29
+ model_dict = {
30
+ "l0": efficientvit_sam_l0,
31
+ "l1": efficientvit_sam_l1,
32
+ "l2": efficientvit_sam_l2,
33
+ "xl0": efficientvit_sam_xl0,
34
+ "xl1": efficientvit_sam_xl1,
35
+ }
36
+
37
+ model_id = name.split("-")[0]
38
+ if model_id not in model_dict:
39
+ raise ValueError(
40
+ f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
41
+ )
42
+ else:
43
+ model = model_dict[model_id](**kwargs)
44
+ set_norm_eps(model, 1e-6)
45
+
46
+ if pretrained:
47
+ weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
48
+ if weight_url is None:
49
+ raise ValueError(f"Do not find the pretrained weight of {name}.")
50
+ else:
51
+ weight = load_state_dict_from_file(weight_url)
52
+ model.load_state_dict(weight)
53
+ return model
src/ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ try:
7
+ import xformers
8
+ import xformers.ops
9
+
10
+ xformers_available = True
11
+ except Exception as e:
12
+ xformers_available = False
13
+
14
+
15
+ class AttnProcessor(nn.Module):
16
+ r"""
17
+ Default processor for performing attention-related computations.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ hidden_size=None,
23
+ cross_attention_dim=None,
24
+ ):
25
+ super().__init__()
26
+
27
+ def __call__(
28
+ self,
29
+ attn,
30
+ hidden_states,
31
+ encoder_hidden_states=None,
32
+ attention_mask=None,
33
+ temb=None,
34
+ ):
35
+ residual = hidden_states
36
+
37
+ if attn.spatial_norm is not None:
38
+ hidden_states = attn.spatial_norm(hidden_states, temb)
39
+
40
+ input_ndim = hidden_states.ndim
41
+
42
+ if input_ndim == 4:
43
+ batch_size, channel, height, width = hidden_states.shape
44
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
45
+
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
48
+ )
49
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
50
+
51
+ if attn.group_norm is not None:
52
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
53
+
54
+ query = attn.to_q(hidden_states)
55
+
56
+ if encoder_hidden_states is None:
57
+ encoder_hidden_states = hidden_states
58
+ elif attn.norm_cross:
59
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
60
+
61
+ key = attn.to_k(encoder_hidden_states)
62
+ value = attn.to_v(encoder_hidden_states)
63
+
64
+ query = attn.head_to_batch_dim(query)
65
+ key = attn.head_to_batch_dim(key)
66
+ value = attn.head_to_batch_dim(value)
67
+
68
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
69
+ hidden_states = torch.bmm(attention_probs, value)
70
+ hidden_states = attn.batch_to_head_dim(hidden_states)
71
+
72
+ # linear proj
73
+ hidden_states = attn.to_out[0](hidden_states)
74
+ # dropout
75
+ hidden_states = attn.to_out[1](hidden_states)
76
+
77
+ if input_ndim == 4:
78
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
79
+
80
+ if attn.residual_connection:
81
+ hidden_states = hidden_states + residual
82
+
83
+ hidden_states = hidden_states / attn.rescale_output_factor
84
+
85
+ return hidden_states
86
+
87
+
88
+ class IPAttnProcessor(nn.Module):
89
+ r"""
90
+ Attention processor for IP-Adapater.
91
+ Args:
92
+ hidden_size (`int`):
93
+ The hidden size of the attention layer.
94
+ cross_attention_dim (`int`):
95
+ The number of channels in the `encoder_hidden_states`.
96
+ scale (`float`, defaults to 1.0):
97
+ the weight scale of image prompt.
98
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
99
+ The context length of the image features.
100
+ """
101
+
102
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
103
+ super().__init__()
104
+
105
+ self.hidden_size = hidden_size
106
+ self.cross_attention_dim = cross_attention_dim
107
+ self.scale = scale
108
+ self.num_tokens = num_tokens
109
+
110
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
111
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
112
+
113
+ def __call__(
114
+ self,
115
+ attn,
116
+ hidden_states,
117
+ encoder_hidden_states=None,
118
+ attention_mask=None,
119
+ temb=None,
120
+ ):
121
+ residual = hidden_states
122
+
123
+ if attn.spatial_norm is not None:
124
+ hidden_states = attn.spatial_norm(hidden_states, temb)
125
+
126
+ input_ndim = hidden_states.ndim
127
+
128
+ if input_ndim == 4:
129
+ batch_size, channel, height, width = hidden_states.shape
130
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
131
+
132
+ batch_size, sequence_length, _ = (
133
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
134
+ )
135
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
136
+
137
+ if attn.group_norm is not None:
138
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
139
+
140
+ query = attn.to_q(hidden_states)
141
+
142
+ if encoder_hidden_states is None:
143
+ encoder_hidden_states = hidden_states
144
+ else:
145
+ # get encoder_hidden_states, ip_hidden_states
146
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
147
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:,
148
+ end_pos:, :]
149
+ if attn.norm_cross:
150
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
151
+
152
+ key = attn.to_k(encoder_hidden_states)
153
+ value = attn.to_v(encoder_hidden_states)
154
+
155
+ query = attn.head_to_batch_dim(query)
156
+ key = attn.head_to_batch_dim(key)
157
+ value = attn.head_to_batch_dim(value)
158
+
159
+ if xformers_available:
160
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
161
+ else:
162
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
163
+ hidden_states = torch.bmm(attention_probs, value)
164
+ hidden_states = attn.batch_to_head_dim(hidden_states)
165
+
166
+ # for ip-adapter
167
+ ip_key = self.to_k_ip(ip_hidden_states)
168
+ ip_value = self.to_v_ip(ip_hidden_states)
169
+
170
+ ip_key = attn.head_to_batch_dim(ip_key)
171
+ ip_value = attn.head_to_batch_dim(ip_value)
172
+
173
+ if xformers_available:
174
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
175
+ else:
176
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
177
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
178
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
179
+
180
+ hidden_states = hidden_states + self.scale * ip_hidden_states
181
+
182
+ # linear proj
183
+ hidden_states = attn.to_out[0](hidden_states)
184
+ # dropout
185
+ hidden_states = attn.to_out[1](hidden_states)
186
+
187
+ if input_ndim == 4:
188
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
189
+
190
+ if attn.residual_connection:
191
+ hidden_states = hidden_states + residual
192
+
193
+ hidden_states = hidden_states / attn.rescale_output_factor
194
+
195
+ return hidden_states
196
+
197
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
198
+ # TODO attention_mask
199
+ query = query.contiguous()
200
+ key = key.contiguous()
201
+ value = value.contiguous()
202
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
203
+ # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class AttnProcessor2_0(torch.nn.Module):
208
+ r"""
209
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ hidden_size=None,
215
+ cross_attention_dim=None,
216
+ ):
217
+ super().__init__()
218
+ if not hasattr(F, "scaled_dot_product_attention"):
219
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
220
+
221
+ def __call__(
222
+ self,
223
+ attn,
224
+ hidden_states,
225
+ encoder_hidden_states=None,
226
+ attention_mask=None,
227
+ temb=None,
228
+ ):
229
+ residual = hidden_states
230
+
231
+ if attn.spatial_norm is not None:
232
+ hidden_states = attn.spatial_norm(hidden_states, temb)
233
+
234
+ input_ndim = hidden_states.ndim
235
+
236
+ if input_ndim == 4:
237
+ batch_size, channel, height, width = hidden_states.shape
238
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
239
+
240
+ batch_size, sequence_length, _ = (
241
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
242
+ )
243
+
244
+ if attention_mask is not None:
245
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
246
+ # scaled_dot_product_attention expects attention_mask shape to be
247
+ # (batch, heads, source_length, target_length)
248
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
249
+
250
+ if attn.group_norm is not None:
251
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
252
+
253
+ query = attn.to_q(hidden_states)
254
+
255
+ if encoder_hidden_states is None:
256
+ encoder_hidden_states = hidden_states
257
+ elif attn.norm_cross:
258
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
259
+
260
+ key = attn.to_k(encoder_hidden_states)
261
+ value = attn.to_v(encoder_hidden_states)
262
+
263
+ inner_dim = key.shape[-1]
264
+ head_dim = inner_dim // attn.heads
265
+
266
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
267
+
268
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
269
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
270
+
271
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
272
+ # TODO: add support for attn.scale when we move to Torch 2.1
273
+ hidden_states = F.scaled_dot_product_attention(
274
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
275
+ )
276
+
277
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
278
+ hidden_states = hidden_states.to(query.dtype)
279
+
280
+ # linear proj
281
+ hidden_states = attn.to_out[0](hidden_states)
282
+ # dropout
283
+ hidden_states = attn.to_out[1](hidden_states)
284
+
285
+ if input_ndim == 4:
286
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
287
+
288
+ if attn.residual_connection:
289
+ hidden_states = hidden_states + residual
290
+
291
+ hidden_states = hidden_states / attn.rescale_output_factor
292
+
293
+ return hidden_states
294
+
295
+
296
+ class IPAttnProcessor2_0(torch.nn.Module):
297
+ r"""
298
+ Attention processor for IP-Adapater for PyTorch 2.0.
299
+ Args:
300
+ hidden_size (`int`):
301
+ The hidden size of the attention layer.
302
+ cross_attention_dim (`int`):
303
+ The number of channels in the `encoder_hidden_states`.
304
+ scale (`float`, defaults to 1.0):
305
+ the weight scale of image prompt.
306
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
307
+ The context length of the image features.
308
+ """
309
+
310
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
311
+ super().__init__()
312
+
313
+ if not hasattr(F, "scaled_dot_product_attention"):
314
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
315
+
316
+ self.hidden_size = hidden_size
317
+ self.cross_attention_dim = cross_attention_dim
318
+ self.scale = scale
319
+ self.num_tokens = num_tokens
320
+
321
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
322
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
323
+
324
+ def __call__(
325
+ self,
326
+ attn,
327
+ hidden_states,
328
+ encoder_hidden_states=None,
329
+ attention_mask=None,
330
+ temb=None,
331
+ ):
332
+ residual = hidden_states
333
+
334
+ if attn.spatial_norm is not None:
335
+ hidden_states = attn.spatial_norm(hidden_states, temb)
336
+
337
+ input_ndim = hidden_states.ndim
338
+
339
+ if input_ndim == 4:
340
+ batch_size, channel, height, width = hidden_states.shape
341
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
342
+
343
+ batch_size, sequence_length, _ = (
344
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
345
+ )
346
+
347
+ if attention_mask is not None:
348
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
349
+ # scaled_dot_product_attention expects attention_mask shape to be
350
+ # (batch, heads, source_length, target_length)
351
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
352
+
353
+ if attn.group_norm is not None:
354
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
355
+
356
+ query = attn.to_q(hidden_states)
357
+
358
+ if encoder_hidden_states is None:
359
+ encoder_hidden_states = hidden_states
360
+ else:
361
+ # get encoder_hidden_states, ip_hidden_states
362
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
363
+ encoder_hidden_states, ip_hidden_states = (
364
+ encoder_hidden_states[:, :end_pos, :],
365
+ encoder_hidden_states[:, end_pos:, :],
366
+ )
367
+ if attn.norm_cross:
368
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
369
+
370
+ key = attn.to_k(encoder_hidden_states)
371
+ value = attn.to_v(encoder_hidden_states)
372
+
373
+ inner_dim = key.shape[-1]
374
+ head_dim = inner_dim // attn.heads
375
+
376
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
377
+
378
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+
381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
+ # TODO: add support for attn.scale when we move to Torch 2.1
383
+ hidden_states = F.scaled_dot_product_attention(
384
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
385
+ )
386
+
387
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
388
+ hidden_states = hidden_states.to(query.dtype)
389
+
390
+ # for ip-adapter
391
+ ip_key = self.to_k_ip(ip_hidden_states)
392
+ ip_value = self.to_v_ip(ip_hidden_states)
393
+
394
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
395
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
396
+
397
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
398
+ # TODO: add support for attn.scale when we move to Torch 2.1
399
+ ip_hidden_states = F.scaled_dot_product_attention(
400
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
401
+ )
402
+ with torch.no_grad():
403
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
404
+ # print(self.attn_map.shape)
405
+
406
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
407
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
408
+
409
+ hidden_states = hidden_states + self.scale * ip_hidden_states
410
+
411
+ # linear proj
412
+ hidden_states = attn.to_out[0](hidden_states)
413
+ # dropout
414
+ hidden_states = attn.to_out[1](hidden_states)
415
+
416
+ if input_ndim == 4:
417
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
418
+
419
+ if attn.residual_connection:
420
+ hidden_states = hidden_states + residual
421
+
422
+ hidden_states = hidden_states / attn.rescale_output_factor
423
+
424
+ return hidden_states
src/ip_adapter/resampler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # FFN
9
+ def FeedForward(dim, mult=4):
10
+ inner_dim = int(dim * mult)
11
+ return nn.Sequential(
12
+ nn.LayerNorm(dim),
13
+ nn.Linear(dim, inner_dim, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(inner_dim, dim, bias=False),
16
+ )
17
+
18
+
19
+ def reshape_tensor(x, heads):
20
+ bs, length, width = x.shape
21
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
22
+ x = x.view(bs, length, heads, -1)
23
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24
+ x = x.transpose(1, 2)
25
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
26
+ x = x.reshape(bs, heads, length, -1)
27
+ return x
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head ** -0.5
34
+ self.dim_head = dim_head
35
+ self.heads = heads
36
+ inner_dim = dim_head * heads
37
+
38
+ self.norm1 = nn.LayerNorm(dim)
39
+ self.norm2 = nn.LayerNorm(dim)
40
+
41
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
42
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
44
+
45
+ def forward(self, x, latents):
46
+ """
47
+ Args:
48
+ x (torch.Tensor): image features
49
+ shape (b, n1, D)
50
+ latent (torch.Tensor): latent features
51
+ shape (b, n2, D)
52
+ """
53
+ x = self.norm1(x)
54
+ latents = self.norm2(latents)
55
+
56
+ b, l, _ = latents.shape
57
+
58
+ q = self.to_q(latents)
59
+ kv_input = torch.cat((x, latents), dim=-2)
60
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
61
+
62
+ q = reshape_tensor(q, self.heads)
63
+ k = reshape_tensor(k, self.heads)
64
+ v = reshape_tensor(v, self.heads)
65
+
66
+ # attention
67
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
68
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
69
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
70
+ out = weight @ v
71
+
72
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
73
+
74
+ return self.to_out(out)
75
+
76
+
77
+ class Resampler(nn.Module):
78
+ def __init__(
79
+ self,
80
+ dim=1024,
81
+ depth=8,
82
+ dim_head=64,
83
+ heads=16,
84
+ num_queries=8,
85
+ embedding_dim=768,
86
+ output_dim=1024,
87
+ ff_mult=4,
88
+ ):
89
+ super().__init__()
90
+
91
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
92
+
93
+ self.proj_in = nn.Linear(embedding_dim, dim)
94
+
95
+ self.proj_out = nn.Linear(dim, output_dim)
96
+ self.norm_out = nn.LayerNorm(output_dim)
97
+
98
+ self.layers = nn.ModuleList([])
99
+ for _ in range(depth):
100
+ self.layers.append(
101
+ nn.ModuleList(
102
+ [
103
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
104
+ FeedForward(dim=dim, mult=ff_mult),
105
+ ]
106
+ )
107
+ )
108
+
109
+ def forward(self, x):
110
+
111
+ latents = self.latents.repeat(x.size(0), 1, 1)
112
+
113
+ x = self.proj_in(x)
114
+
115
+ for attn, ff in self.layers:
116
+ latents = attn(x, latents) + latents
117
+ latents = ff(latents) + latents
118
+
119
+ latents = self.proj_out(latents)
120
+ return self.norm_out(latents)
src/ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+
3
+
4
+ def is_torch2_available():
5
+ return hasattr(F, "scaled_dot_product_attention")
src/pipelines/instantid_pipeline.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPTextModel,
11
+ CLIPTextModelWithProjection,
12
+ CLIPTokenizer,
13
+ CLIPVisionModelWithProjection,
14
+ )
15
+
16
+ from diffusers.utils.import_utils import is_invisible_watermark_available
17
+
18
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
19
+ from diffusers.loaders import (
20
+ FromSingleFileMixin,
21
+ IPAdapterMixin,
22
+ StableDiffusionXLLoraLoaderMixin,
23
+ TextualInversionLoaderMixin,
24
+ )
25
+ from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
26
+ from diffusers.models.attention_processor import (
27
+ AttnProcessor2_0,
28
+ LoRAAttnProcessor2_0,
29
+ LoRAXFormersAttnProcessor,
30
+ XFormersAttnProcessor,
31
+ )
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
45
+
46
+
47
+ if is_invisible_watermark_available():
48
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
49
+
50
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
51
+ from diffusers import StableDiffusionXLControlNetPipeline
52
+ from PIL import Image
53
+ from torchvision.transforms.functional import to_tensor
54
+ from einops import rearrange
55
+ from torch import einsum
56
+ import math
57
+ from torchvision.utils import save_image
58
+ from diffusers.utils import load_image
59
+ import cv2
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+ class RegionControlNet_AttnProcessor:
64
+ def __init__(self, attention_op=None, controller=None, place_in_unet=None):
65
+ self.attention_op = attention_op
66
+ self.controller = controller
67
+ self.place_in_unet = place_in_unet
68
+
69
+ def __call__(
70
+ self,
71
+ attn,
72
+ hidden_states: torch.FloatTensor,
73
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
74
+ attention_mask: Optional[torch.FloatTensor] = None,
75
+ temb: Optional[torch.FloatTensor] = None,
76
+ scale: float = 1.0,
77
+ **cross_attention_kwargs
78
+ ) -> torch.Tensor:
79
+ residual = hidden_states
80
+
81
+ args = () if USE_PEFT_BACKEND else (scale,)
82
+
83
+ if attn.spatial_norm is not None:
84
+ hidden_states = attn.spatial_norm(hidden_states, temb)
85
+
86
+ input_ndim = hidden_states.ndim
87
+
88
+ if input_ndim == 4:
89
+ batch_size, channel, height, width = hidden_states.shape
90
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
91
+
92
+ batch_size, sequence_length, _ = (
93
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
94
+ )
95
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
96
+
97
+ if attn.group_norm is not None:
98
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
99
+
100
+ query = attn.to_q(hidden_states, *args)
101
+
102
+ is_cross = True
103
+ if encoder_hidden_states is None:
104
+ is_cross = False
105
+ encoder_hidden_states = hidden_states
106
+ elif attn.norm_cross:
107
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
108
+
109
+ key = attn.to_k(encoder_hidden_states, *args)
110
+ value = attn.to_v(encoder_hidden_states, *args)
111
+
112
+ query = attn.head_to_batch_dim(query)
113
+ key = attn.head_to_batch_dim(key)
114
+ value = attn.head_to_batch_dim(value)
115
+
116
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
117
+ attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
118
+ hidden_states = torch.bmm(attention_probs, value)
119
+
120
+
121
+ hidden_states = attn.batch_to_head_dim(hidden_states)
122
+
123
+ # linear proj
124
+ hidden_states = attn.to_out[0](hidden_states, *args)
125
+ # dropout
126
+ hidden_states = attn.to_out[1](hidden_states)
127
+
128
+ if input_ndim == 4:
129
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
130
+
131
+ if attn.residual_connection:
132
+ hidden_states = hidden_states + residual
133
+
134
+ hidden_states = hidden_states / attn.rescale_output_factor
135
+
136
+ return hidden_states
137
+
138
+
139
+ def revise_regionally_controlnet_forward(unet, controller):
140
+ def change_forward(unet, count, place_in_unet):
141
+ for name, layer in unet.named_children():
142
+ if layer.__class__.__name__ == 'Attention':
143
+ layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
144
+ if 'attn2' in name:
145
+ count += 1
146
+ else:
147
+ count = change_forward(layer, count, place_in_unet)
148
+ return count
149
+
150
+ # use this to ensure the order
151
+ cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
152
+ cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up")
153
+ cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid")
154
+ print(f'Number of attention layer registered {cross_attention_idx}')
155
+ controller.num_att_layers = cross_attention_idx*2
156
+
157
+ class InstantidMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
158
+ # leave controlnet out on purpose because it iterates with unet
159
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
160
+ _optional_components = [
161
+ "tokenizer",
162
+ "tokenizer_2",
163
+ "text_encoder",
164
+ "text_encoder_2",
165
+ "feature_extractor",
166
+ "image_encoder",
167
+ ]
168
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
169
+
170
+ def __init__(
171
+ self,
172
+ vae: AutoencoderKL,
173
+ text_encoder: CLIPTextModel,
174
+ text_encoder_2: CLIPTextModelWithProjection,
175
+ tokenizer: CLIPTokenizer,
176
+ tokenizer_2: CLIPTokenizer,
177
+ unet: UNet2DConditionModel,
178
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
179
+ scheduler: KarrasDiffusionSchedulers,
180
+ force_zeros_for_empty_prompt: bool = True,
181
+ add_watermarker: Optional[bool] = None,
182
+ feature_extractor: CLIPImageProcessor = None,
183
+ image_encoder: CLIPVisionModelWithProjection = None,
184
+ ):
185
+ if isinstance(controlnet, (list, tuple)):
186
+ controlnet = MultiControlNetModel(controlnet)
187
+
188
+ self.register_modules(
189
+ vae=vae,
190
+ text_encoder=text_encoder,
191
+ text_encoder_2=text_encoder_2,
192
+ tokenizer=tokenizer,
193
+ tokenizer_2=tokenizer_2,
194
+ unet=unet,
195
+ controlnet=controlnet,
196
+ scheduler=scheduler,
197
+ feature_extractor=feature_extractor,
198
+ image_encoder=image_encoder,
199
+ )
200
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
201
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
202
+ self.control_image_processor = VaeImageProcessor(
203
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
204
+ )
205
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
206
+
207
+ if add_watermarker:
208
+ self.watermark = StableDiffusionXLWatermarker()
209
+ else:
210
+ self.watermark = None
211
+
212
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
213
+
214
+ @torch.no_grad()
215
+ def __call__(
216
+ self,
217
+ prompt: Union[str, List[str]] = None,
218
+ prompt_2: Optional[Union[str, List[str]]] = None,
219
+ image: PipelineImageInput = None,
220
+ height: Optional[int] = None,
221
+ width: Optional[int] = None,
222
+ num_inference_steps: int = 50,
223
+ guidance_scale: float = 5.0,
224
+ negative_prompt: Optional[Union[str, List[str]]] = None,
225
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
226
+ num_images_per_prompt: Optional[int] = 1,
227
+ eta: float = 0.0,
228
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
229
+ latents: Optional[torch.FloatTensor] = None,
230
+ prompt_embeds: Optional[torch.FloatTensor] = None,
231
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
232
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
233
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
234
+ ip_adapter_image: Optional[PipelineImageInput] = None,
235
+ output_type: Optional[str] = "pil",
236
+ return_dict: bool = True,
237
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
238
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
239
+ guess_mode: bool = False,
240
+ control_guidance_start: Union[float, List[float]] = 0.0,
241
+ control_guidance_end: Union[float, List[float]] = 1.0,
242
+ original_size: Tuple[int, int] = None,
243
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
244
+ target_size: Tuple[int, int] = None,
245
+ negative_original_size: Optional[Tuple[int, int]] = None,
246
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
247
+ negative_target_size: Optional[Tuple[int, int]] = None,
248
+ clip_skip: Optional[int] = None,
249
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
250
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
251
+ controller=None,
252
+ concept_models=None,
253
+ indices_to_alter=None,
254
+ face_app=None,
255
+ stage=None,
256
+ region_masks=None,
257
+ t2i_image=None,
258
+ t2i_controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
259
+ **kwargs,
260
+ ):
261
+ # revise_regionally_controlnet_forward(self.unet, controller)
262
+ callback = kwargs.pop("callback", None)
263
+ callback_steps = kwargs.pop("callback_steps", None)
264
+
265
+ if callback is not None:
266
+ deprecate(
267
+ "callback",
268
+ "1.0.0",
269
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
270
+ )
271
+ if callback_steps is not None:
272
+ deprecate(
273
+ "callback_steps",
274
+ "1.0.0",
275
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
276
+ )
277
+
278
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
279
+
280
+ # align format for control guidance
281
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
282
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
283
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
284
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
285
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
286
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
287
+ control_guidance_start, control_guidance_end = (
288
+ mult * [control_guidance_start],
289
+ mult * [control_guidance_end],
290
+ )
291
+
292
+ # 1. Check inputs. Raise error if not correct
293
+ self.check_inputs(
294
+ prompt,
295
+ prompt_2,
296
+ image,
297
+ callback_steps,
298
+ negative_prompt,
299
+ negative_prompt_2,
300
+ prompt_embeds,
301
+ negative_prompt_embeds,
302
+ pooled_prompt_embeds,
303
+ negative_pooled_prompt_embeds,
304
+ controlnet_conditioning_scale,
305
+ control_guidance_start,
306
+ control_guidance_end,
307
+ callback_on_step_end_tensor_inputs,
308
+ )
309
+
310
+ self._guidance_scale = guidance_scale
311
+ self._clip_skip = clip_skip
312
+ self._cross_attention_kwargs = cross_attention_kwargs
313
+
314
+ # 2. Define call parameters
315
+ batch_size = 2
316
+
317
+ device = self._execution_device
318
+
319
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
320
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
321
+
322
+ global_pool_conditions = (
323
+ controlnet.config.global_pool_conditions
324
+ if isinstance(controlnet, ControlNetModel)
325
+ else controlnet.nets[0].config.global_pool_conditions
326
+ )
327
+ guess_mode = guess_mode or global_pool_conditions
328
+
329
+ # 3.1 Encode input prompt
330
+ text_encoder_lora_scale = (
331
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
332
+ )
333
+
334
+ global_prompt = prompt[0]
335
+ global_negative_prompt = negative_prompt
336
+ region_prompts = [pt[0] for pt in prompt[1]]
337
+ region_negative_prompts = [pt[1] for pt in prompt[1]]
338
+ ref_images = [pt[2] for pt in prompt[1]]
339
+
340
+ concat_prompts = global_prompt + region_prompts
341
+ concat_negative_prompts = global_negative_prompt + region_negative_prompts
342
+
343
+ (
344
+ concat_prompt_embeds,
345
+ concat_negative_prompt_embeds,
346
+ concat_pooled_prompt_embeds,
347
+ concat_negative_pooled_prompt_embeds,
348
+ ) = self.encode_prompt(
349
+ concat_prompts,
350
+ prompt_2,
351
+ device,
352
+ num_images_per_prompt,
353
+ self.do_classifier_free_guidance,
354
+ concat_negative_prompts,
355
+ negative_prompt_2,
356
+ prompt_embeds=prompt_embeds,
357
+ negative_prompt_embeds=negative_prompt_embeds,
358
+ pooled_prompt_embeds=pooled_prompt_embeds,
359
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
360
+ lora_scale=text_encoder_lora_scale,
361
+ clip_skip=self.clip_skip,
362
+ )
363
+
364
+ prompt_embeds = concat_prompt_embeds[:2]
365
+ negative_prompt_embeds = concat_negative_prompt_embeds[:2]
366
+ pooled_prompt_embeds = concat_pooled_prompt_embeds[:2]
367
+ negative_pooled_prompt_embeds = concat_negative_pooled_prompt_embeds[:2]
368
+
369
+ region_prompt_embeds_list = []
370
+ region_add_text_embeds_list = []
371
+ for region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds in zip(concat_prompt_embeds[2:], concat_negative_prompt_embeds[2:], concat_pooled_prompt_embeds[2:], concat_negative_pooled_prompt_embeds[2:]):
372
+ region_prompt_embeds_list.append(
373
+ torch.concat([region_negative_prompt_embeds.unsqueeze(0), region_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))
374
+ region_add_text_embeds_list.append(
375
+ torch.concat([region_negative_pooled_prompt_embeds.unsqueeze(0), region_pooled_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))
376
+
377
+
378
+ if stage==2:
379
+ mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]
380
+ image_embedding_list = get_face_embedding(face_app, ref_images)
381
+ image_prompt_image_emb_list = []
382
+ for image_embeds in image_embedding_list:
383
+ prompt_image_emb = concept_models._encode_prompt_image_emb(image_embeds,
384
+ concept_models._execution_device,
385
+ num_images_per_prompt,
386
+ concept_models.unet.dtype,
387
+ True)
388
+ image_prompt_image_emb_list.append(prompt_image_emb)
389
+
390
+
391
+
392
+ # 4. Prepare image
393
+ if isinstance(controlnet, ControlNetModel) and image is not None:
394
+ image = self.prepare_image(
395
+ image=image,
396
+ width=width,
397
+ height=height,
398
+ batch_size=1 * num_images_per_prompt,
399
+ num_images_per_prompt=num_images_per_prompt,
400
+ device=device,
401
+ dtype=controlnet.dtype,
402
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
403
+ guess_mode=guess_mode,
404
+ )
405
+ height, width = image.shape[-2:]
406
+ elif isinstance(controlnet, MultiControlNetModel) and image is not None:
407
+ images = []
408
+
409
+ for image_ in image:
410
+ image_ = self.prepare_image(
411
+ image=image_,
412
+ width=width,
413
+ height=height,
414
+ batch_size=batch_size * num_images_per_prompt,
415
+ num_images_per_prompt=num_images_per_prompt,
416
+ device=device,
417
+ dtype=controlnet.dtype,
418
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
419
+ guess_mode=guess_mode,
420
+ )
421
+
422
+ images.append(image_)
423
+
424
+ image = images
425
+ height, width = image[0].shape[-2:]
426
+ else:
427
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
428
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
429
+
430
+ if t2i_image is not None:
431
+ t2i_image = self.prepare_image(
432
+ image=t2i_image,
433
+ width=width,
434
+ height=height,
435
+ batch_size=batch_size * num_images_per_prompt,
436
+ num_images_per_prompt=num_images_per_prompt,
437
+ device=device,
438
+ dtype=controlnet.dtype,
439
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
440
+ guess_mode=guess_mode,
441
+ )
442
+ height, width = t2i_image.shape[-2:]
443
+
444
+ # 5. Prepare timesteps
445
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
446
+ timesteps = self.scheduler.timesteps
447
+ self._num_timesteps = len(timesteps)
448
+
449
+ # 6. Prepare latent variables
450
+ num_channels_latents = self.unet.config.in_channels
451
+ latents = self.prepare_latents(
452
+ batch_size//2 * num_images_per_prompt,
453
+ num_channels_latents,
454
+ height,
455
+ width,
456
+ prompt_embeds.dtype,
457
+ device,
458
+ generator,
459
+ latents,
460
+ )
461
+
462
+ # 6.1 repeat latent
463
+ latents = torch.cat([latents, latents.clone()])
464
+
465
+ # 6.5 Optionally get Guidance Scale Embedding
466
+ timestep_cond = None
467
+ if self.unet.config.time_cond_proj_dim is not None:
468
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
469
+ timestep_cond = self.get_guidance_scale_embedding(
470
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
471
+ ).to(device=device, dtype=latents.dtype)
472
+
473
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
474
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
475
+
476
+ # 7.1 Create tensor stating which controlnets to keep
477
+ controlnet_keep = []
478
+ for i in range(len(timesteps)):
479
+ keeps = [
480
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
481
+ for s, e in zip(control_guidance_start, control_guidance_end)
482
+ ]
483
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
484
+
485
+ # 7.2 Prepare added time ids & embeddings
486
+ if isinstance(image, list):
487
+ original_size = original_size or image[0].shape[-2:]
488
+ else:
489
+ original_size = original_size or (height, width)
490
+ target_size = target_size or (height, width)
491
+
492
+ add_text_embeds = pooled_prompt_embeds
493
+ if self.text_encoder_2 is None:
494
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
495
+ else:
496
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
497
+
498
+ add_time_ids = self._get_add_time_ids(
499
+ original_size,
500
+ crops_coords_top_left,
501
+ target_size,
502
+ dtype=prompt_embeds.dtype,
503
+ text_encoder_projection_dim=text_encoder_projection_dim,
504
+ )
505
+
506
+ add_time_ids_list = []
507
+ region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
508
+ for _ in range(len(prompt[1])):
509
+ add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))
510
+
511
+ if negative_original_size is not None and negative_target_size is not None:
512
+ negative_add_time_ids = self._get_add_time_ids(
513
+ negative_original_size,
514
+ negative_crops_coords_top_left,
515
+ negative_target_size,
516
+ dtype=prompt_embeds.dtype,
517
+ text_encoder_projection_dim=text_encoder_projection_dim,
518
+ )
519
+ else:
520
+ negative_add_time_ids = add_time_ids
521
+
522
+ if self.do_classifier_free_guidance:
523
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
524
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
525
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
526
+
527
+ prompt_embeds = prompt_embeds.to(device)
528
+ add_text_embeds = add_text_embeds.to(device)
529
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
530
+
531
+ # 8. Denoising loop
532
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
533
+ is_unet_compiled = is_compiled_module(self.unet)
534
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
535
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
536
+ # hyper-parameters
537
+ scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))
538
+
539
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
540
+ for i, t in enumerate(timesteps):
541
+ # Relevant thread:
542
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
543
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
544
+ torch._inductor.cudagraph_mark_step_begin()
545
+ # expand the latents if we are doing classifier free guidance
546
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
547
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
548
+
549
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
550
+
551
+ # controlnet(s) inference
552
+ if guess_mode and self.do_classifier_free_guidance:
553
+ # Infer ControlNet only for the conditional batch.
554
+ control_model_input = latents
555
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
556
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
557
+ controlnet_added_cond_kwargs = {
558
+ "text_embeds": add_text_embeds.chunk(2)[1],
559
+ "time_ids": add_time_ids.chunk(2)[1],
560
+ }
561
+ else:
562
+ control_model_input = latent_model_input
563
+ controlnet_prompt_embeds = prompt_embeds
564
+ controlnet_added_cond_kwargs = added_cond_kwargs
565
+
566
+ if isinstance(controlnet_keep[i], list):
567
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
568
+ else:
569
+ controlnet_cond_scale = controlnet_conditioning_scale
570
+ if isinstance(controlnet_cond_scale, list):
571
+ controlnet_cond_scale = controlnet_cond_scale[0]
572
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
573
+
574
+ if t2i_image is not None:
575
+ t2i_controlnet_cond_scale = t2i_controlnet_conditioning_scale
576
+ if isinstance(t2i_controlnet_cond_scale, list):
577
+ t2i_controlnet_cond_scale = t2i_controlnet_cond_scale[0]
578
+ t2i_cond_scale = t2i_controlnet_cond_scale * controlnet_keep[i]
579
+
580
+ t2i_down_block_res_samples, t2i_mid_block_res_sample = self.controlnet2(
581
+ control_model_input,
582
+ t,
583
+ encoder_hidden_states=controlnet_prompt_embeds,
584
+ controlnet_cond=t2i_image,
585
+ conditioning_scale=t2i_cond_scale,
586
+ guess_mode=guess_mode,
587
+ added_cond_kwargs=controlnet_added_cond_kwargs,
588
+ return_dict=False,
589
+ )
590
+ else:
591
+ t2i_down_block_res_samples = None
592
+ t2i_mid_block_res_sample = None
593
+
594
+
595
+ if t2i_image is None:
596
+ noise_pred = self.unet(
597
+ latent_model_input,
598
+ t,
599
+ encoder_hidden_states=prompt_embeds,
600
+ timestep_cond=timestep_cond,
601
+ cross_attention_kwargs=self.cross_attention_kwargs,
602
+ added_cond_kwargs=added_cond_kwargs,
603
+ return_dict=False,
604
+ )[0]
605
+ else:
606
+ noise_pred = self.unet(
607
+ latent_model_input,
608
+ t,
609
+ encoder_hidden_states=prompt_embeds,
610
+ timestep_cond=timestep_cond,
611
+ cross_attention_kwargs=self.cross_attention_kwargs,
612
+ down_block_additional_residuals=t2i_down_block_res_samples,
613
+ mid_block_additional_residual=t2i_mid_block_res_sample,
614
+ added_cond_kwargs=added_cond_kwargs,
615
+ return_dict=False,
616
+ )[0]
617
+
618
+ if i > 15 and stage == 2:
619
+ region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
620
+ edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
621
+ new_noise_pred = torch.zeros_like(edit_noise)
622
+ new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
623
+ replace_ratio = 1.0
624
+ new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]
625
+
626
+ for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, region_prompt_image_emb in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, image_prompt_image_emb_list):
627
+ if concept_mask is not None:
628
+ concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
629
+ size=(noise_pred.shape[2], noise_pred.shape[3]),
630
+ mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)
631
+
632
+ region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)
633
+
634
+ region_latent_model_input = torch.cat([region_latent_model_input] * 2)
635
+ region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
636
+ "time_ids": region_add_time_ids}
637
+
638
+ if image is not None:
639
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
640
+ region_latent_model_input,
641
+ t,
642
+ encoder_hidden_states=region_prompt_image_emb,
643
+ controlnet_cond=image,
644
+ conditioning_scale=cond_scale,
645
+ guess_mode=guess_mode,
646
+ added_cond_kwargs=region_added_cond_kwargs,
647
+ return_dict=False,
648
+ )
649
+
650
+ if guess_mode and self.do_classifier_free_guidance:
651
+ # Infered ControlNet only for the conditional batch.
652
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
653
+ # add 0 to the unconditional batch to keep it unchanged.
654
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in
655
+ down_block_res_samples]
656
+ mid_block_res_sample = torch.cat(
657
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
658
+
659
+ else:
660
+ down_block_res_samples = None
661
+ mid_block_res_sample = None
662
+
663
+ region_encoder_hidden_states = torch.cat([region_prompt_embeds, region_prompt_image_emb], dim=1)
664
+
665
+ region_noise_pred = concept_models.unet(
666
+ region_latent_model_input,
667
+ t,
668
+ encoder_hidden_states=region_encoder_hidden_states,
669
+ cross_attention_kwargs=None,
670
+ down_block_additional_residuals=down_block_res_samples,
671
+ mid_block_additional_residual=mid_block_res_sample,
672
+ added_cond_kwargs=region_added_cond_kwargs,
673
+ return_dict=False,
674
+ )[0]
675
+
676
+
677
+ new_noise_pred = new_noise_pred.to(concept_models._execution_device)
678
+ new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device)))
679
+
680
+
681
+ new_noise_pred = new_noise_pred.to(noise_pred.device)
682
+ noise_pred[1, :, :, :] = new_noise_pred[0]
683
+ noise_pred[3, :, :, :] = new_noise_pred[1]
684
+
685
+ if self.do_classifier_free_guidance:
686
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
687
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
688
+
689
+ # compute the previous noisy sample x_t -> x_t-1
690
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
691
+
692
+ if callback_on_step_end is not None:
693
+ callback_kwargs = {}
694
+ for k in callback_on_step_end_tensor_inputs:
695
+ callback_kwargs[k] = locals()[k]
696
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
697
+
698
+ latents = callback_outputs.pop("latents", latents)
699
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
700
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
701
+
702
+ # call the callback, if provided
703
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
704
+ progress_bar.update()
705
+ if callback is not None and i % callback_steps == 0:
706
+ step_idx = i // getattr(self.scheduler, "order", 1)
707
+ callback(step_idx, t, latents)
708
+
709
+ # manually for max memory savings
710
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
711
+ self.upcast_vae()
712
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
713
+
714
+ if not output_type == "latent":
715
+ # make sure the VAE is in float32 mode, as it overflows in float16
716
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
717
+
718
+ if needs_upcasting:
719
+ self.upcast_vae()
720
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
721
+
722
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
723
+
724
+ # cast back to fp16 if needed
725
+ if needs_upcasting:
726
+ self.vae.to(dtype=torch.float16)
727
+ else:
728
+ image = latents
729
+
730
+ if not output_type == "latent":
731
+ # apply watermark if available
732
+ if self.watermark is not None:
733
+ image = self.watermark.apply_watermark(image)
734
+
735
+ image = self.image_processor.postprocess(image, output_type=output_type)
736
+
737
+ # Offload all models
738
+ self.maybe_free_model_hooks()
739
+
740
+ if not return_dict:
741
+ return (image,)
742
+
743
+ return StableDiffusionXLPipelineOutput(images=image)
744
+
745
+ def check_image(self, image, prompt, prompt_embeds):
746
+ pass
747
+
748
+ def get_region_mask(self, mask_list, feat_height, feat_width):
749
+ exclusive_mask = torch.zeros((feat_height, feat_width))
750
+ for mask in mask_list:
751
+ if mask is not None:
752
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
753
+ mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
754
+ exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
755
+ return exclusive_mask
756
+
757
+ def get_face_embedding(face_app, ref_images):
758
+ emb_list = []
759
+ for img_path in ref_images:
760
+ face_image = load_image(img_path)
761
+
762
+ # prepare face emb
763
+ face_info = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
764
+ face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * x['bbox'][3] - x['bbox'][1])[0] # only use the maximum face
765
+ face_emb = face_info['embedding']
766
+ emb_list.append(face_emb)
767
+ # face_kps = draw_kps(face_image, face_info['kps'])
768
+ return emb_list
src/pipelines/instantid_single_pieline.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import cv2
19
+ import math
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from diffusers.image_processor import PipelineImageInput
27
+
28
+ from diffusers.models import ControlNetModel
29
+
30
+ from diffusers.utils import (
31
+ deprecate,
32
+ logging,
33
+ replace_example_docstring,
34
+ )
35
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
36
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
37
+
38
+ from diffusers import StableDiffusionXLControlNetPipeline
39
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+
42
+ from src.ip_adapter.resampler import Resampler
43
+ from src.ip_adapter.utils import is_torch2_available
44
+
45
+ if is_torch2_available():
46
+ from src.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
47
+ else:
48
+ from src.ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> # !pip install opencv-python transformers accelerate insightface
56
+ >>> import diffusers
57
+ >>> from diffusers.utils import load_image
58
+ >>> from diffusers.models import ControlNetModel
59
+
60
+ >>> import cv2
61
+ >>> import torch
62
+ >>> import numpy as np
63
+ >>> from PIL import Image
64
+
65
+ >>> from insightface.app import FaceAnalysis
66
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
67
+
68
+ >>> # download 'antelopev2' under ./models
69
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
70
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
71
+
72
+ >>> # download models under ./checkpoints
73
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
74
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
75
+
76
+ >>> # load IdentityNet
77
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
78
+
79
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
80
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
81
+ ... )
82
+ >>> pipe.cuda()
83
+
84
+ >>> # load adapter
85
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
86
+
87
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
88
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
89
+
90
+ >>> # load an image
91
+ >>> image = load_image("your-example.jpg")
92
+
93
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
94
+ >>> face_emb = face_info['embedding']
95
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
96
+
97
+ >>> pipe.set_ip_adapter_scale(0.8)
98
+
99
+ >>> # generate image
100
+ >>> image = pipe(
101
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
102
+ ... ).images[0]
103
+ ```
104
+ """
105
+
106
+
107
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
108
+ stickwidth = 4
109
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
110
+ kps = np.array(kps)
111
+
112
+ w, h = image_pil.size
113
+ out_img = np.zeros([h, w, 3])
114
+
115
+ for i in range(len(limbSeq)):
116
+ index = limbSeq[i]
117
+ color = color_list[index[0]]
118
+
119
+ x = kps[index][:, 0]
120
+ y = kps[index][:, 1]
121
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
122
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
123
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
124
+ 360, 1)
125
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
126
+ out_img = (out_img * 0.6).astype(np.uint8)
127
+
128
+ for idx_kp, kp in enumerate(kps):
129
+ color = color_list[idx_kp]
130
+ x, y = kp
131
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
132
+
133
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
134
+ return out_img_pil
135
+
136
+
137
+ class InstantidSingleConceptPipeline(StableDiffusionXLControlNetPipeline):
138
+
139
+ def cuda(self, dtype=torch.float16, use_xformers=False):
140
+ self.to('cuda', dtype)
141
+
142
+ if hasattr(self, 'image_proj_model'):
143
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
144
+
145
+ if use_xformers:
146
+ if is_xformers_available():
147
+ import xformers
148
+ from packaging import version
149
+
150
+ xformers_version = version.parse(xformers.__version__)
151
+ if xformers_version == version.parse("0.0.16"):
152
+ logger.warn(
153
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
154
+ )
155
+ self.enable_xformers_memory_efficient_attention()
156
+ else:
157
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
158
+
159
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
160
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
161
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
162
+
163
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
164
+
165
+ image_proj_model = Resampler(
166
+ dim=1280,
167
+ depth=4,
168
+ dim_head=64,
169
+ heads=20,
170
+ num_queries=num_tokens,
171
+ embedding_dim=image_emb_dim,
172
+ output_dim=self.unet.config.cross_attention_dim,
173
+ ff_mult=4,
174
+ )
175
+
176
+ image_proj_model.eval()
177
+
178
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
179
+ state_dict = torch.load(model_ckpt, map_location="cpu")
180
+ if 'image_proj' in state_dict:
181
+ state_dict = state_dict["image_proj"]
182
+ self.image_proj_model.load_state_dict(state_dict)
183
+
184
+ self.image_proj_model_in_features = image_emb_dim
185
+
186
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
187
+
188
+ unet = self.unet
189
+ attn_procs = {}
190
+ for name in unet.attn_processors.keys():
191
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
192
+ if name.startswith("mid_block"):
193
+ hidden_size = unet.config.block_out_channels[-1]
194
+ elif name.startswith("up_blocks"):
195
+ block_id = int(name[len("up_blocks.")])
196
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
197
+ elif name.startswith("down_blocks"):
198
+ block_id = int(name[len("down_blocks.")])
199
+ hidden_size = unet.config.block_out_channels[block_id]
200
+ if cross_attention_dim is None:
201
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
202
+ else:
203
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
204
+ cross_attention_dim=cross_attention_dim,
205
+ scale=scale,
206
+ num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
207
+ unet.set_attn_processor(attn_procs)
208
+
209
+ state_dict = torch.load(model_ckpt, map_location="cpu")
210
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
211
+ if 'ip_adapter' in state_dict:
212
+ state_dict = state_dict['ip_adapter']
213
+ ip_layers.load_state_dict(state_dict)
214
+
215
+ def set_ip_adapter_scale(self, scale):
216
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
217
+ for attn_processor in unet.attn_processors.values():
218
+ if isinstance(attn_processor, IPAttnProcessor):
219
+ attn_processor.scale = scale
220
+
221
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype,
222
+ do_classifier_free_guidance):
223
+
224
+ if isinstance(prompt_image_emb, torch.Tensor):
225
+ prompt_image_emb = prompt_image_emb.clone().detach()
226
+ else:
227
+ prompt_image_emb = torch.tensor(prompt_image_emb)
228
+
229
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
230
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
231
+
232
+ if do_classifier_free_guidance:
233
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
234
+ else:
235
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
236
+
237
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
238
+
239
+ bs_embed, seq_len, _ = prompt_image_emb.shape
240
+ prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
241
+ prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
242
+
243
+ return prompt_image_emb
244
+
245
+ @torch.no_grad()
246
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
247
+ def __call__(
248
+ self,
249
+ prompt: Union[str, List[str]] = None,
250
+ prompt_2: Optional[Union[str, List[str]]] = None,
251
+ image: PipelineImageInput = None,
252
+ height: Optional[int] = None,
253
+ width: Optional[int] = None,
254
+ num_inference_steps: int = 50,
255
+ guidance_scale: float = 5.0,
256
+ negative_prompt: Optional[Union[str, List[str]]] = None,
257
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
258
+ num_images_per_prompt: Optional[int] = 1,
259
+ eta: float = 0.0,
260
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
261
+ latents: Optional[torch.FloatTensor] = None,
262
+ prompt_embeds: Optional[torch.FloatTensor] = None,
263
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
264
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
265
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
266
+ image_embeds: Optional[torch.FloatTensor] = None,
267
+ output_type: Optional[str] = "pil",
268
+ return_dict: bool = True,
269
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
270
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
271
+ guess_mode: bool = False,
272
+ control_guidance_start: Union[float, List[float]] = 0.0,
273
+ control_guidance_end: Union[float, List[float]] = 1.0,
274
+ original_size: Tuple[int, int] = None,
275
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
276
+ target_size: Tuple[int, int] = None,
277
+ negative_original_size: Optional[Tuple[int, int]] = None,
278
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
279
+ negative_target_size: Optional[Tuple[int, int]] = None,
280
+ clip_skip: Optional[int] = None,
281
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
282
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
283
+
284
+ # IP adapter
285
+ ip_adapter_scale=None,
286
+
287
+ **kwargs,
288
+ ):
289
+ r"""
290
+ The call function to the pipeline for generation.
291
+
292
+ Args:
293
+ prompt (`str` or `List[str]`, *optional*):
294
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
295
+ prompt_2 (`str` or `List[str]`, *optional*):
296
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
297
+ used in both text-encoders.
298
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
299
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
300
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
301
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
302
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
303
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
304
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
305
+ input to a single ControlNet.
306
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
307
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
308
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
309
+ and checkpoints that are not specifically fine-tuned on low resolutions.
310
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
311
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
312
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
313
+ and checkpoints that are not specifically fine-tuned on low resolutions.
314
+ num_inference_steps (`int`, *optional*, defaults to 50):
315
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
316
+ expense of slower inference.
317
+ guidance_scale (`float`, *optional*, defaults to 5.0):
318
+ A higher guidance scale value encourages the model to generate images closely linked to the text
319
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
320
+ negative_prompt (`str` or `List[str]`, *optional*):
321
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
322
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
323
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
324
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
325
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
326
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
327
+ The number of images to generate per prompt.
328
+ eta (`float`, *optional*, defaults to 0.0):
329
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
330
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
331
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
332
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
333
+ generation deterministic.
334
+ latents (`torch.FloatTensor`, *optional*):
335
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
336
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
337
+ tensor is generated by sampling using the supplied random `generator`.
338
+ prompt_embeds (`torch.FloatTensor`, *optional*):
339
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
340
+ provided, text embeddings are generated from the `prompt` input argument.
341
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
342
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
343
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
344
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
345
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
346
+ not provided, pooled text embeddings are generated from `prompt` input argument.
347
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
348
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
349
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
350
+ argument.
351
+ image_embeds (`torch.FloatTensor`, *optional*):
352
+ Pre-generated image embeddings.
353
+ output_type (`str`, *optional*, defaults to `"pil"`):
354
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
355
+ return_dict (`bool`, *optional*, defaults to `True`):
356
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
357
+ plain tuple.
358
+ cross_attention_kwargs (`dict`, *optional*):
359
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
360
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
361
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
362
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
363
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
364
+ the corresponding scale as a list.
365
+ guess_mode (`bool`, *optional*, defaults to `False`):
366
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
367
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
368
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
369
+ The percentage of total steps at which the ControlNet starts applying.
370
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
371
+ The percentage of total steps at which the ControlNet stops applying.
372
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
373
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
374
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
375
+ explained in section 2.2 of
376
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
377
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
378
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
379
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
380
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
381
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
382
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
383
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
384
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
385
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
386
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
387
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
388
+ micro-conditioning as explained in section 2.2 of
389
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
390
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
391
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
392
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
393
+ micro-conditioning as explained in section 2.2 of
394
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
395
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
396
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
397
+ To negatively condition the generation process based on a target image resolution. It should be as same
398
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
399
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
400
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
401
+ clip_skip (`int`, *optional*):
402
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
403
+ the output of the pre-final layer will be used for computing the prompt embeddings.
404
+ callback_on_step_end (`Callable`, *optional*):
405
+ A function that calls at the end of each denoising steps during the inference. The function is called
406
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
407
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
408
+ `callback_on_step_end_tensor_inputs`.
409
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
410
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
411
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
412
+ `._callback_tensor_inputs` attribute of your pipeine class.
413
+
414
+ Examples:
415
+
416
+ Returns:
417
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
418
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
419
+ otherwise a `tuple` is returned containing the output images.
420
+ """
421
+
422
+ callback = kwargs.pop("callback", None)
423
+ callback_steps = kwargs.pop("callback_steps", None)
424
+
425
+ if callback is not None:
426
+ deprecate(
427
+ "callback",
428
+ "1.0.0",
429
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
430
+ )
431
+ if callback_steps is not None:
432
+ deprecate(
433
+ "callback_steps",
434
+ "1.0.0",
435
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
436
+ )
437
+
438
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
439
+
440
+ # align format for control guidance
441
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
442
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
443
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
444
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
445
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
446
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
447
+ control_guidance_start, control_guidance_end = (
448
+ mult * [control_guidance_start],
449
+ mult * [control_guidance_end],
450
+ )
451
+
452
+ # 0. set ip_adapter_scale
453
+ if ip_adapter_scale is not None:
454
+ self.set_ip_adapter_scale(ip_adapter_scale)
455
+
456
+ # 1. Check inputs. Raise error if not correct
457
+ self.check_inputs(
458
+ prompt,
459
+ prompt_2,
460
+ image,
461
+ callback_steps,
462
+ negative_prompt,
463
+ negative_prompt_2,
464
+ prompt_embeds,
465
+ negative_prompt_embeds,
466
+ pooled_prompt_embeds,
467
+ negative_pooled_prompt_embeds,
468
+ controlnet_conditioning_scale,
469
+ control_guidance_start,
470
+ control_guidance_end,
471
+ callback_on_step_end_tensor_inputs,
472
+ )
473
+
474
+ self._guidance_scale = guidance_scale
475
+ self._clip_skip = clip_skip
476
+ self._cross_attention_kwargs = cross_attention_kwargs
477
+
478
+ # 2. Define call parameters
479
+ if prompt is not None and isinstance(prompt, str):
480
+ batch_size = 1
481
+ elif prompt is not None and isinstance(prompt, list):
482
+ batch_size = len(prompt)
483
+ else:
484
+ batch_size = prompt_embeds.shape[0]
485
+
486
+ device = self._execution_device
487
+
488
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
489
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
490
+
491
+ global_pool_conditions = (
492
+ controlnet.config.global_pool_conditions
493
+ if isinstance(controlnet, ControlNetModel)
494
+ else controlnet.nets[0].config.global_pool_conditions
495
+ )
496
+ guess_mode = guess_mode or global_pool_conditions
497
+
498
+ # 3.1 Encode input prompt
499
+ text_encoder_lora_scale = (
500
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
501
+ )
502
+ (
503
+ prompt_embeds,
504
+ negative_prompt_embeds,
505
+ pooled_prompt_embeds,
506
+ negative_pooled_prompt_embeds,
507
+ ) = self.encode_prompt(
508
+ prompt,
509
+ prompt_2,
510
+ device,
511
+ num_images_per_prompt,
512
+ self.do_classifier_free_guidance,
513
+ negative_prompt,
514
+ negative_prompt_2,
515
+ prompt_embeds=prompt_embeds,
516
+ negative_prompt_embeds=negative_prompt_embeds,
517
+ pooled_prompt_embeds=pooled_prompt_embeds,
518
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
519
+ lora_scale=text_encoder_lora_scale,
520
+ clip_skip=self.clip_skip,
521
+ )
522
+
523
+ # 3.2 Encode image prompt
524
+ prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
525
+ device,
526
+ num_images_per_prompt,
527
+ self.unet.dtype,
528
+ self.do_classifier_free_guidance)
529
+
530
+ # 4. Prepare image
531
+ if isinstance(controlnet, ControlNetModel):
532
+ image = self.prepare_image(
533
+ image=image,
534
+ width=width,
535
+ height=height,
536
+ batch_size=batch_size * num_images_per_prompt,
537
+ num_images_per_prompt=num_images_per_prompt,
538
+ device=device,
539
+ dtype=controlnet.dtype,
540
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
541
+ guess_mode=guess_mode,
542
+ )
543
+ height, width = image.shape[-2:]
544
+ elif isinstance(controlnet, MultiControlNetModel):
545
+ images = []
546
+
547
+ for image_ in image:
548
+ image_ = self.prepare_image(
549
+ image=image_,
550
+ width=width,
551
+ height=height,
552
+ batch_size=batch_size * num_images_per_prompt,
553
+ num_images_per_prompt=num_images_per_prompt,
554
+ device=device,
555
+ dtype=controlnet.dtype,
556
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
557
+ guess_mode=guess_mode,
558
+ )
559
+
560
+ images.append(image_)
561
+
562
+ image = images
563
+ height, width = image[0].shape[-2:]
564
+ else:
565
+ assert False
566
+
567
+ # 5. Prepare timesteps
568
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
569
+ timesteps = self.scheduler.timesteps
570
+ self._num_timesteps = len(timesteps)
571
+
572
+ # 6. Prepare latent variables
573
+ num_channels_latents = self.unet.config.in_channels
574
+ latents = self.prepare_latents(
575
+ batch_size * num_images_per_prompt,
576
+ num_channels_latents,
577
+ height,
578
+ width,
579
+ prompt_embeds.dtype,
580
+ device,
581
+ generator,
582
+ latents,
583
+ )
584
+
585
+ # 6.5 Optionally get Guidance Scale Embedding
586
+ timestep_cond = None
587
+ if self.unet.config.time_cond_proj_dim is not None:
588
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
589
+ timestep_cond = self.get_guidance_scale_embedding(
590
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
591
+ ).to(device=device, dtype=latents.dtype)
592
+
593
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
594
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
595
+
596
+ # 7.1 Create tensor stating which controlnets to keep
597
+ controlnet_keep = []
598
+ for i in range(len(timesteps)):
599
+ keeps = [
600
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
601
+ for s, e in zip(control_guidance_start, control_guidance_end)
602
+ ]
603
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
604
+
605
+ # 7.2 Prepare added time ids & embeddings
606
+ if isinstance(image, list):
607
+ original_size = original_size or image[0].shape[-2:]
608
+ else:
609
+ original_size = original_size or image.shape[-2:]
610
+ target_size = target_size or (height, width)
611
+
612
+ add_text_embeds = pooled_prompt_embeds
613
+ if self.text_encoder_2 is None:
614
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
615
+ else:
616
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
617
+
618
+ add_time_ids = self._get_add_time_ids(
619
+ original_size,
620
+ crops_coords_top_left,
621
+ target_size,
622
+ dtype=prompt_embeds.dtype,
623
+ text_encoder_projection_dim=text_encoder_projection_dim,
624
+ )
625
+
626
+ if negative_original_size is not None and negative_target_size is not None:
627
+ negative_add_time_ids = self._get_add_time_ids(
628
+ negative_original_size,
629
+ negative_crops_coords_top_left,
630
+ negative_target_size,
631
+ dtype=prompt_embeds.dtype,
632
+ text_encoder_projection_dim=text_encoder_projection_dim,
633
+ )
634
+ else:
635
+ negative_add_time_ids = add_time_ids
636
+
637
+ if self.do_classifier_free_guidance:
638
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
639
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
640
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
641
+
642
+ prompt_embeds = prompt_embeds.to(device)
643
+ add_text_embeds = add_text_embeds.to(device)
644
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
645
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
646
+
647
+ # 8. Denoising loop
648
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
649
+ is_unet_compiled = is_compiled_module(self.unet)
650
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
651
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
652
+
653
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
654
+ for i, t in enumerate(timesteps):
655
+ # Relevant thread:
656
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
657
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
658
+ torch._inductor.cudagraph_mark_step_begin()
659
+ # expand the latents if we are doing classifier free guidance
660
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
661
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
662
+
663
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
664
+
665
+ # controlnet(s) inference
666
+ if guess_mode and self.do_classifier_free_guidance:
667
+ # Infer ControlNet only for the conditional batch.
668
+ control_model_input = latents
669
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
670
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
671
+ controlnet_added_cond_kwargs = {
672
+ "text_embeds": add_text_embeds.chunk(2)[1],
673
+ "time_ids": add_time_ids.chunk(2)[1],
674
+ }
675
+ else:
676
+ control_model_input = latent_model_input
677
+ controlnet_prompt_embeds = prompt_embeds
678
+ controlnet_added_cond_kwargs = added_cond_kwargs
679
+
680
+ if isinstance(controlnet_keep[i], list):
681
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
682
+ else:
683
+ controlnet_cond_scale = controlnet_conditioning_scale
684
+ if isinstance(controlnet_cond_scale, list):
685
+ controlnet_cond_scale = controlnet_cond_scale[0]
686
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
687
+
688
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
689
+ control_model_input,
690
+ t,
691
+ encoder_hidden_states=prompt_image_emb,
692
+ controlnet_cond=image,
693
+ conditioning_scale=cond_scale,
694
+ guess_mode=guess_mode,
695
+ added_cond_kwargs=controlnet_added_cond_kwargs,
696
+ return_dict=False,
697
+ )
698
+
699
+ if guess_mode and self.do_classifier_free_guidance:
700
+ # Infered ControlNet only for the conditional batch.
701
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
702
+ # add 0 to the unconditional batch to keep it unchanged.
703
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
704
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
705
+
706
+ # predict the noise residual
707
+ noise_pred = self.unet(
708
+ latent_model_input,
709
+ t,
710
+ encoder_hidden_states=encoder_hidden_states,
711
+ timestep_cond=timestep_cond,
712
+ cross_attention_kwargs=self.cross_attention_kwargs,
713
+ down_block_additional_residuals=down_block_res_samples,
714
+ mid_block_additional_residual=mid_block_res_sample,
715
+ added_cond_kwargs=added_cond_kwargs,
716
+ return_dict=False,
717
+ )[0]
718
+
719
+ # perform guidance
720
+ if self.do_classifier_free_guidance:
721
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
722
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
723
+
724
+ # compute the previous noisy sample x_t -> x_t-1
725
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
726
+
727
+ if callback_on_step_end is not None:
728
+ callback_kwargs = {}
729
+ for k in callback_on_step_end_tensor_inputs:
730
+ callback_kwargs[k] = locals()[k]
731
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
732
+
733
+ latents = callback_outputs.pop("latents", latents)
734
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
735
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
736
+
737
+ # call the callback, if provided
738
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
739
+ progress_bar.update()
740
+ if callback is not None and i % callback_steps == 0:
741
+ step_idx = i // getattr(self.scheduler, "order", 1)
742
+ callback(step_idx, t, latents)
743
+
744
+ if not output_type == "latent":
745
+ # make sure the VAE is in float32 mode, as it overflows in float16
746
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
747
+ if needs_upcasting:
748
+ self.upcast_vae()
749
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
750
+
751
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
752
+
753
+ # cast back to fp16 if needed
754
+ if needs_upcasting:
755
+ self.vae.to(dtype=torch.float16)
756
+ else:
757
+ image = latents
758
+
759
+ if not output_type == "latent":
760
+ # apply watermark if available
761
+ if self.watermark is not None:
762
+ image = self.watermark.apply_watermark(image)
763
+
764
+ image = self.image_processor.postprocess(image, output_type=output_type)
765
+
766
+ # Offload all models
767
+ self.maybe_free_model_hooks()
768
+
769
+ if not return_dict:
770
+ return (image,)
771
+
772
+ return StableDiffusionXLPipelineOutput(images=image)
src/pipelines/lora_pipeline.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPTextModel,
11
+ CLIPTextModelWithProjection,
12
+ CLIPTokenizer,
13
+ CLIPVisionModelWithProjection,
14
+ )
15
+
16
+ from diffusers.utils.import_utils import is_invisible_watermark_available
17
+
18
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
19
+ from diffusers.loaders import (
20
+ FromSingleFileMixin,
21
+ IPAdapterMixin,
22
+ StableDiffusionXLLoraLoaderMixin,
23
+ TextualInversionLoaderMixin,
24
+ )
25
+ from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
26
+ from diffusers.models.attention_processor import (
27
+ AttnProcessor2_0,
28
+ LoRAAttnProcessor2_0,
29
+ LoRAXFormersAttnProcessor,
30
+ XFormersAttnProcessor,
31
+ )
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
45
+
46
+
47
+ if is_invisible_watermark_available():
48
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
49
+
50
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
51
+ from diffusers import StableDiffusionXLControlNetPipeline
52
+ from PIL import Image
53
+ from torchvision.transforms.functional import to_tensor
54
+ from einops import rearrange
55
+ from torch import einsum
56
+ import math
57
+ from torchvision.utils import save_image
58
+
59
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
+
61
+ class RegionControlNet_AttnProcessor:
62
+ def __init__(self, attention_op=None, controller=None, place_in_unet=None):
63
+ self.attention_op = attention_op
64
+ self.controller = controller
65
+ self.place_in_unet = place_in_unet
66
+
67
+ def __call__(
68
+ self,
69
+ attn,
70
+ hidden_states: torch.FloatTensor,
71
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
72
+ attention_mask: Optional[torch.FloatTensor] = None,
73
+ temb: Optional[torch.FloatTensor] = None,
74
+ scale: float = 1.0,
75
+ **cross_attention_kwargs
76
+ ) -> torch.Tensor:
77
+ residual = hidden_states
78
+
79
+ args = () if USE_PEFT_BACKEND else (scale,)
80
+
81
+ if attn.spatial_norm is not None:
82
+ hidden_states = attn.spatial_norm(hidden_states, temb)
83
+
84
+ input_ndim = hidden_states.ndim
85
+
86
+ if input_ndim == 4:
87
+ batch_size, channel, height, width = hidden_states.shape
88
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
89
+
90
+ batch_size, sequence_length, _ = (
91
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
92
+ )
93
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
94
+
95
+ if attn.group_norm is not None:
96
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
97
+
98
+ query = attn.to_q(hidden_states, *args)
99
+
100
+ is_cross = True
101
+ if encoder_hidden_states is None:
102
+ is_cross = False
103
+ encoder_hidden_states = hidden_states
104
+ elif attn.norm_cross:
105
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
106
+
107
+ key = attn.to_k(encoder_hidden_states, *args)
108
+ value = attn.to_v(encoder_hidden_states, *args)
109
+
110
+ query = attn.head_to_batch_dim(query)
111
+ key = attn.head_to_batch_dim(key)
112
+ value = attn.head_to_batch_dim(value)
113
+
114
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
115
+ attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
116
+ hidden_states = torch.bmm(attention_probs, value)
117
+
118
+ hidden_states = attn.batch_to_head_dim(hidden_states)
119
+
120
+ # linear proj
121
+ hidden_states = attn.to_out[0](hidden_states, *args)
122
+ # dropout
123
+ hidden_states = attn.to_out[1](hidden_states)
124
+
125
+ if input_ndim == 4:
126
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
127
+
128
+ if attn.residual_connection:
129
+ hidden_states = hidden_states + residual
130
+
131
+ hidden_states = hidden_states / attn.rescale_output_factor
132
+
133
+ return hidden_states
134
+
135
+
136
+ def revise_regionally_controlnet_forward(unet, controller):
137
+ def change_forward(unet, count, place_in_unet):
138
+ for name, layer in unet.named_children():
139
+ if layer.__class__.__name__ == 'Attention':
140
+ layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
141
+ if 'attn2' in name:
142
+ count += 1
143
+ else:
144
+ count = change_forward(layer, count, place_in_unet)
145
+ return count
146
+
147
+ # use this to ensure the order
148
+ cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
149
+ cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up")
150
+ cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid")
151
+ print(f'Number of attention layer registered {cross_attention_idx}')
152
+ controller.num_att_layers = cross_attention_idx*2
153
+
154
+ class LoraMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
155
+ # leave controlnet out on purpose because it iterates with unet
156
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
157
+ _optional_components = [
158
+ "tokenizer",
159
+ "tokenizer_2",
160
+ "text_encoder",
161
+ "text_encoder_2",
162
+ "feature_extractor",
163
+ "image_encoder",
164
+ ]
165
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
166
+
167
+ def __init__(
168
+ self,
169
+ vae: AutoencoderKL,
170
+ text_encoder: CLIPTextModel,
171
+ text_encoder_2: CLIPTextModelWithProjection,
172
+ tokenizer: CLIPTokenizer,
173
+ tokenizer_2: CLIPTokenizer,
174
+ unet: UNet2DConditionModel,
175
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
176
+ scheduler: KarrasDiffusionSchedulers,
177
+ force_zeros_for_empty_prompt: bool = True,
178
+ add_watermarker: Optional[bool] = None,
179
+ feature_extractor: CLIPImageProcessor = None,
180
+ image_encoder: CLIPVisionModelWithProjection = None
181
+ ):
182
+ if isinstance(controlnet, (list, tuple)):
183
+ controlnet = MultiControlNetModel(controlnet)
184
+
185
+ self.register_modules(
186
+ vae=vae,
187
+ text_encoder=text_encoder,
188
+ text_encoder_2=text_encoder_2,
189
+ tokenizer=tokenizer,
190
+ tokenizer_2=tokenizer_2,
191
+ unet=unet,
192
+ controlnet=controlnet,
193
+ scheduler=scheduler,
194
+ feature_extractor=feature_extractor,
195
+ image_encoder=image_encoder,
196
+ )
197
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
198
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
199
+ self.control_image_processor = VaeImageProcessor(
200
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
201
+ )
202
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
203
+
204
+ if add_watermarker:
205
+ self.watermark = StableDiffusionXLWatermarker()
206
+ else:
207
+ self.watermark = None
208
+
209
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
210
+
211
+ @torch.no_grad()
212
+ def __call__(
213
+ self,
214
+ prompt: Union[str, List[str]] = None,
215
+ prompt_2: Optional[Union[str, List[str]]] = None,
216
+ image: PipelineImageInput = None,
217
+ height: Optional[int] = None,
218
+ width: Optional[int] = None,
219
+ num_inference_steps: int = 50,
220
+ guidance_scale: float = 5.0,
221
+ negative_prompt: Optional[Union[str, List[str]]] = None,
222
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
223
+ num_images_per_prompt: Optional[int] = 1,
224
+ eta: float = 0.0,
225
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
226
+ latents: Optional[torch.FloatTensor] = None,
227
+ prompt_embeds: Optional[torch.FloatTensor] = None,
228
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
229
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
230
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
231
+ ip_adapter_image: Optional[PipelineImageInput] = None,
232
+ output_type: Optional[str] = "pil",
233
+ return_dict: bool = True,
234
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
235
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
236
+ guess_mode: bool = False,
237
+ control_guidance_start: Union[float, List[float]] = 0.0,
238
+ control_guidance_end: Union[float, List[float]] = 1.0,
239
+ original_size: Tuple[int, int] = None,
240
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
241
+ target_size: Tuple[int, int] = None,
242
+ negative_original_size: Optional[Tuple[int, int]] = None,
243
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
244
+ negative_target_size: Optional[Tuple[int, int]] = None,
245
+ clip_skip: Optional[int] = None,
246
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
247
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
248
+ controller=None,
249
+ concept_models=None,
250
+ stage=None,
251
+ region_masks=None,
252
+ lora_list=None,
253
+ styleL=None,
254
+ **kwargs,
255
+ ):
256
+ callback = kwargs.pop("callback", None)
257
+ callback_steps = kwargs.pop("callback_steps", None)
258
+
259
+ if callback is not None:
260
+ deprecate(
261
+ "callback",
262
+ "1.0.0",
263
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
264
+ )
265
+ if callback_steps is not None:
266
+ deprecate(
267
+ "callback_steps",
268
+ "1.0.0",
269
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
270
+ )
271
+
272
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
273
+
274
+ # align format for control guidance
275
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
276
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
277
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
278
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
279
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
280
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
281
+ control_guidance_start, control_guidance_end = (
282
+ mult * [control_guidance_start],
283
+ mult * [control_guidance_end],
284
+ )
285
+
286
+ self._guidance_scale = guidance_scale
287
+ self._clip_skip = clip_skip
288
+ self._cross_attention_kwargs = cross_attention_kwargs
289
+
290
+ # 2. Define call parameters
291
+ batch_size = 2
292
+
293
+ device = self._execution_device
294
+
295
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
296
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
297
+
298
+ global_pool_conditions = (
299
+ controlnet.config.global_pool_conditions
300
+ if isinstance(controlnet, ControlNetModel)
301
+ else controlnet.nets[0].config.global_pool_conditions
302
+ )
303
+ guess_mode = guess_mode or global_pool_conditions
304
+
305
+ # 3.1 Encode input prompt
306
+ text_encoder_lora_scale = (
307
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
308
+ )
309
+
310
+ global_prompt = prompt[0]
311
+ global_negative_prompt = negative_prompt
312
+ region_prompts = [pt[0] for pt in prompt[1]]
313
+ region_negative_prompts = [pt[1] for pt in prompt[1]]
314
+
315
+ (
316
+ prompt_embeds,
317
+ negative_prompt_embeds,
318
+ pooled_prompt_embeds,
319
+ negative_pooled_prompt_embeds,
320
+ ) = self.encode_prompt(
321
+ global_prompt,
322
+ prompt_2,
323
+ device,
324
+ num_images_per_prompt,
325
+ self.do_classifier_free_guidance,
326
+ global_negative_prompt,
327
+ negative_prompt_2,
328
+ prompt_embeds=prompt_embeds,
329
+ negative_prompt_embeds=negative_prompt_embeds,
330
+ pooled_prompt_embeds=pooled_prompt_embeds,
331
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
332
+ lora_scale=text_encoder_lora_scale,
333
+ clip_skip=self.clip_skip,
334
+ )
335
+
336
+ region_prompt_embeds_list = []
337
+ region_add_text_embeds_list = []
338
+ for lora_param, region_prompt, region_negative_prompt in zip(lora_list, region_prompts, region_negative_prompts):
339
+ if styleL:
340
+ concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
341
+ else:
342
+ concept_models.set_adapters(lora_param)
343
+ region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds = concept_models.encode_prompt(
344
+ prompt=region_prompt, device=concept_models._execution_device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=region_negative_prompt, lora_scale=text_encoder_lora_scale
345
+ )
346
+ region_prompt_embeds_list.append(torch.concat([region_negative_prompt_embeds, region_prompt_embeds], dim=0).to(concept_models._execution_device))
347
+ region_add_text_embeds_list.append(torch.concat([region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds], dim=0).to(concept_models._execution_device))
348
+
349
+ if stage==2:
350
+ mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]
351
+
352
+ # 4. Prepare image
353
+ if isinstance(controlnet, ControlNetModel) and image is not None:
354
+ image = self.prepare_image(
355
+ image=image,
356
+ width=width,
357
+ height=height,
358
+ batch_size=batch_size * num_images_per_prompt,
359
+ num_images_per_prompt=num_images_per_prompt,
360
+ device=device,
361
+ dtype=controlnet.dtype,
362
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
363
+ guess_mode=guess_mode,
364
+ )
365
+ height, width = image.shape[-2:]
366
+ elif isinstance(controlnet, MultiControlNetModel) and image is not None:
367
+ images = []
368
+
369
+ for image_ in image:
370
+ image_ = self.prepare_image(
371
+ image=image_,
372
+ width=width,
373
+ height=height,
374
+ batch_size=batch_size * num_images_per_prompt,
375
+ num_images_per_prompt=num_images_per_prompt,
376
+ device=device,
377
+ dtype=controlnet.dtype,
378
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
379
+ guess_mode=guess_mode,
380
+ )
381
+
382
+ images.append(image_)
383
+
384
+ image = images
385
+ height, width = image[0].shape[-2:]
386
+ else:
387
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
388
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
389
+
390
+ # 5. Prepare timesteps
391
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
392
+ timesteps = self.scheduler.timesteps
393
+ self._num_timesteps = len(timesteps)
394
+
395
+ # 6. Prepare latent variables
396
+ num_channels_latents = self.unet.config.in_channels
397
+ latents = self.prepare_latents(
398
+ batch_size//2 * num_images_per_prompt,
399
+ num_channels_latents,
400
+ height,
401
+ width,
402
+ prompt_embeds.dtype,
403
+ device,
404
+ generator,
405
+ latents,
406
+ )
407
+
408
+ # 6.1 repeat latent
409
+ latents = torch.cat([latents, latents.clone()])
410
+
411
+ timestep_cond = None
412
+ if self.unet.config.time_cond_proj_dim is not None:
413
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
414
+ timestep_cond = self.get_guidance_scale_embedding(
415
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
416
+ ).to(device=device, dtype=latents.dtype)
417
+
418
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
419
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
420
+
421
+ # 7.1 Create tensor stating which controlnets to keep
422
+ controlnet_keep = []
423
+ for i in range(len(timesteps)):
424
+ keeps = [
425
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
426
+ for s, e in zip(control_guidance_start, control_guidance_end)
427
+ ]
428
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
429
+
430
+ # 7.2 Prepare added time ids & embeddings
431
+ if isinstance(image, list):
432
+ original_size = original_size or image[0].shape[-2:]
433
+ else:
434
+ original_size = original_size or (height, width)
435
+ target_size = target_size or (height, width)
436
+
437
+ add_text_embeds = pooled_prompt_embeds
438
+ if self.text_encoder_2 is None:
439
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
440
+ else:
441
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
442
+
443
+ add_time_ids = self._get_add_time_ids(
444
+ original_size,
445
+ crops_coords_top_left,
446
+ target_size,
447
+ dtype=prompt_embeds.dtype,
448
+ text_encoder_projection_dim=text_encoder_projection_dim,
449
+ )
450
+
451
+ add_time_ids_list = []
452
+ for _ in lora_list:
453
+ region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
454
+ add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))
455
+
456
+ if negative_original_size is not None and negative_target_size is not None:
457
+ negative_add_time_ids = self._get_add_time_ids(
458
+ negative_original_size,
459
+ negative_crops_coords_top_left,
460
+ negative_target_size,
461
+ dtype=prompt_embeds.dtype,
462
+ text_encoder_projection_dim=text_encoder_projection_dim,
463
+ )
464
+ else:
465
+ negative_add_time_ids = add_time_ids
466
+
467
+ if self.do_classifier_free_guidance:
468
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
469
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
470
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
471
+
472
+ prompt_embeds = prompt_embeds.to(device)
473
+ add_text_embeds = add_text_embeds.to(device)
474
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
475
+
476
+ # 8. Denoising loop
477
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
478
+ is_unet_compiled = is_compiled_module(self.unet)
479
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
480
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
481
+ # hyper-parameters
482
+ scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))
483
+
484
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
485
+ for i, t in enumerate(timesteps):
486
+ # Relevant thread:
487
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
488
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
489
+ torch._inductor.cudagraph_mark_step_begin()
490
+ # expand the latents if we are doing classifier free guidance
491
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
492
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
493
+
494
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
495
+
496
+ # controlnet(s) inference
497
+ if guess_mode and self.do_classifier_free_guidance:
498
+ # Infer ControlNet only for the conditional batch.
499
+ control_model_input = latents
500
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
501
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
502
+ controlnet_added_cond_kwargs = {
503
+ "text_embeds": add_text_embeds.chunk(2)[1],
504
+ "time_ids": add_time_ids.chunk(2)[1],
505
+ }
506
+ else:
507
+ control_model_input = latent_model_input
508
+ controlnet_prompt_embeds = prompt_embeds
509
+ controlnet_added_cond_kwargs = added_cond_kwargs
510
+
511
+ if isinstance(controlnet_keep[i], list):
512
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
513
+ else:
514
+ controlnet_cond_scale = controlnet_conditioning_scale
515
+ if isinstance(controlnet_cond_scale, list):
516
+ controlnet_cond_scale = controlnet_cond_scale[0]
517
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
518
+
519
+ if image is not None:
520
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
521
+ control_model_input,
522
+ t,
523
+ encoder_hidden_states=controlnet_prompt_embeds,
524
+ controlnet_cond=image,
525
+ conditioning_scale=cond_scale,
526
+ guess_mode=guess_mode,
527
+ added_cond_kwargs=controlnet_added_cond_kwargs,
528
+ return_dict=False,
529
+ )
530
+
531
+ if guess_mode and self.do_classifier_free_guidance:
532
+ # Infered ControlNet only for the conditional batch.
533
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
534
+ # add 0 to the unconditional batch to keep it unchanged.
535
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
536
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
537
+
538
+ else:
539
+ down_block_res_samples = None
540
+ mid_block_res_sample = None
541
+
542
+
543
+
544
+ # predict the noise residual
545
+ if image is not None:
546
+ noise_pred = self.unet(
547
+ latent_model_input,
548
+ t,
549
+ encoder_hidden_states=prompt_embeds,
550
+ timestep_cond=timestep_cond,
551
+ cross_attention_kwargs=self.cross_attention_kwargs,
552
+ down_block_additional_residuals=down_block_res_samples,
553
+ mid_block_additional_residual=mid_block_res_sample,
554
+ added_cond_kwargs=added_cond_kwargs,
555
+ return_dict=False,
556
+ )[0]
557
+ else:
558
+ noise_pred = self.unet(
559
+ latent_model_input,
560
+ t,
561
+ encoder_hidden_states=prompt_embeds,
562
+ timestep_cond=timestep_cond,
563
+ cross_attention_kwargs=self.cross_attention_kwargs,
564
+ added_cond_kwargs=added_cond_kwargs,
565
+ return_dict=False,
566
+ )[0]
567
+
568
+ if i > 15 and stage == 2:
569
+ region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
570
+ edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
571
+ new_noise_pred = torch.zeros_like(edit_noise)
572
+ new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
573
+ replace_ratio = 1.0
574
+ new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]
575
+
576
+ for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, lora_param in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, lora_list):
577
+ if concept_mask is not None:
578
+ concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
579
+ size=(noise_pred.shape[2], noise_pred.shape[3]),
580
+ mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)
581
+
582
+
583
+ region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)
584
+
585
+ region_latent_model_input = torch.cat([region_latent_model_input] * 2)
586
+ region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
587
+ "time_ids": region_add_time_ids}
588
+ if styleL:
589
+ concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
590
+ else:
591
+ concept_models.set_adapters(lora_param)
592
+ region_noise_pred = concept_models.unet(
593
+ region_latent_model_input,
594
+ t,
595
+ encoder_hidden_states=region_prompt_embeds,
596
+ cross_attention_kwargs={'scale': 0.8},
597
+ added_cond_kwargs=region_added_cond_kwargs,
598
+ return_dict=False,
599
+ )[0]
600
+
601
+ new_noise_pred = new_noise_pred.to(concept_models._execution_device)
602
+ new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device)))
603
+
604
+
605
+ new_noise_pred = new_noise_pred.to(noise_pred.device)
606
+ noise_pred[1, :, :, :] = new_noise_pred[0]
607
+ noise_pred[3, :, :, :] = new_noise_pred[1]
608
+
609
+
610
+ if self.do_classifier_free_guidance:
611
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
612
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
613
+
614
+ # compute the previous noisy sample x_t -> x_t-1
615
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
616
+
617
+ if callback_on_step_end is not None:
618
+ callback_kwargs = {}
619
+ for k in callback_on_step_end_tensor_inputs:
620
+ callback_kwargs[k] = locals()[k]
621
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
622
+
623
+ latents = callback_outputs.pop("latents", latents)
624
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
625
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
626
+
627
+ # call the callback, if provided
628
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
629
+ progress_bar.update()
630
+ if callback is not None and i % callback_steps == 0:
631
+ step_idx = i // getattr(self.scheduler, "order", 1)
632
+ callback(step_idx, t, latents)
633
+
634
+ # manually for max memory savings
635
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
636
+ self.upcast_vae()
637
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
638
+ if stage==2:
639
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
640
+ if not output_type == "latent":
641
+ # make sure the VAE is in float32 mode, as it overflows in float16
642
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
643
+
644
+ if needs_upcasting:
645
+ self.upcast_vae()
646
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
647
+
648
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
649
+
650
+ # cast back to fp16 if needed
651
+ if needs_upcasting:
652
+ self.vae.to(dtype=torch.float16)
653
+ else:
654
+ image = latents
655
+
656
+ if not output_type == "latent":
657
+ # apply watermark if available
658
+ if self.watermark is not None:
659
+ image = self.watermark.apply_watermark(image)
660
+
661
+ image = self.image_processor.postprocess(image, output_type=output_type)
662
+
663
+ # Offload all models
664
+ self.maybe_free_model_hooks()
665
+
666
+ if not return_dict:
667
+ return (image,)
668
+
669
+ return StableDiffusionXLPipelineOutput(images=image)
670
+
671
+ def check_image(self, image, prompt, prompt_embeds):
672
+ pass
673
+
674
+ def get_region_mask(self, mask_list, feat_height, feat_width):
675
+ exclusive_mask = torch.zeros((feat_height, feat_width))
676
+ for mask in mask_list:
677
+ if mask is not None:
678
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
679
+ mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
680
+ exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
681
+ return exclusive_mask
src/prompt_attention/p2p_attention.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Tuple, List, Callable, Dict
2
+ import torch
3
+ import torch.nn.functional as nnf
4
+ import numpy as np
5
+ import abc
6
+ import src.prompt_attention.p2p_utils as p2p_utils
7
+ import src.prompt_attention.seq_aligner as seq_aligner
8
+
9
+
10
+
11
+ class AttentionControl(abc.ABC):
12
+
13
+ def step_callback(self, x_t):
14
+ return x_t
15
+
16
+ def between_steps(self):
17
+ return
18
+
19
+ @property
20
+ def num_uncond_att_layers(self):
21
+ # return self.num_att_layers if self.low_resource else 0
22
+ return 0
23
+
24
+ @abc.abstractmethod
25
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
26
+ raise NotImplementedError
27
+
28
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
29
+ if self.cur_att_layer >= self.num_uncond_att_layers:
30
+ if self.low_resource:
31
+ attn = self.forward(attn, is_cross, place_in_unet)
32
+ else:
33
+ h = attn.shape[0]
34
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
35
+ self.cur_att_layer += 1
36
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
37
+ self.cur_att_layer = 0
38
+ self.cur_step += 1
39
+ self.between_steps()
40
+ return attn
41
+
42
+ def reset(self):
43
+ self.cur_step = 0
44
+ self.cur_att_layer = 0
45
+
46
+ def __init__(self, low_resource=False, width=None, height=None):
47
+ self.cur_step = 0
48
+ self.num_att_layers = -1
49
+ self.cur_att_layer = 0
50
+ self.low_resource = low_resource
51
+ self.width = width
52
+ self.height = height
53
+
54
+ class AttentionStore(AttentionControl):
55
+
56
+ @staticmethod
57
+ def get_empty_store():
58
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
59
+ "down_self": [], "mid_self": [], "up_self": []}
60
+
61
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
62
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
63
+ # if attn.shape[1] <= att_size * 64:
64
+ return attn
65
+
66
+ def between_steps(self):
67
+ if self.save_global_store:
68
+ if len(self.attention_store) == 0:
69
+ self.attention_store = self.step_store
70
+ else:
71
+ for key in self.attention_store:
72
+ for i in range(len(self.attention_store[key])):
73
+ self.attention_store[key][i] += self.step_store[key][i]
74
+ else:
75
+ self.attention_store = self.step_store
76
+ self.step_store = self.get_empty_store()
77
+
78
+ def get_average_attention(self):
79
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
80
+ self.attention_store}
81
+ return average_attention
82
+
83
+ def reset(self):
84
+ super(AttentionStore, self).reset()
85
+ self.step_store = self.get_empty_store()
86
+ self.attention_store = {}
87
+
88
+ def __init__(self, width, height, low_resolution=False, save_global_store=False):
89
+ super(AttentionStore, self).__init__(low_resolution, width, height)
90
+ self.step_store = self.get_empty_store()
91
+ self.attention_store = {}
92
+ self.save_global_store = save_global_store
93
+
94
+ class AttentionControlEdit(AttentionStore, abc.ABC):
95
+ def __init__(self, prompts, num_steps: int,
96
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
97
+ self_replace_steps: Union[float, Tuple[float, float]],
98
+ local_blend=None, width=None, height=None, tokenizer=None, device=None):
99
+ super(AttentionControlEdit, self).__init__(width, height)
100
+ self.batch_size = len(prompts)
101
+ self.cross_replace_alpha = p2p_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
102
+ tokenizer).to(device)
103
+ if type(self_replace_steps) is float:
104
+ self_replace_steps = 0, self_replace_steps
105
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
106
+ self.local_blend = local_blend
107
+
108
+ def step_callback(self, x_t):
109
+ print("step_callback")
110
+ if self.local_blend is not None:
111
+ x_t = self.local_blend(x_t, self.attention_store)
112
+ return x_t
113
+
114
+ def replace_self_attention(self, attn_base, att_replace):
115
+ if att_replace.shape[2] <= self.width * self.height:
116
+ return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
117
+ else:
118
+ return att_replace
119
+
120
+ @abc.abstractmethod
121
+ def replace_cross_attention(self, attn_base, att_replace):
122
+ raise NotImplementedError
123
+
124
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
125
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
126
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
127
+ h = attn.shape[0] // (self.batch_size)
128
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
129
+ attn_base, attn_repalce = attn[0], attn[1:]
130
+ if is_cross:
131
+ alpha_words = self.cross_replace_alpha[self.cur_step]
132
+ attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
133
+ 1 - alpha_words) * attn_repalce
134
+ attn[1:] = attn_repalce_new
135
+ else:
136
+ attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
137
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
138
+ return attn
139
+
140
+ class AttentionReplace(AttentionControlEdit):
141
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, width, height,
142
+ local_blend = None, tokenizer=None, device=None, dtype=None):
143
+ super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, width, height, tokenizer=tokenizer, device=device)
144
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(dtype=dtype, device=device)
145
+
146
+ def replace_cross_attention(self, attn_base, att_replace):
147
+ return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
148
+
src/prompt_attention/p2p_utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ import cv2
19
+ from typing import Optional, Union, Tuple, List, Callable, Dict
20
+
21
+
22
+
23
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
24
+ word_inds: Optional[torch.Tensor] = None):
25
+ if type(bounds) is float:
26
+ bounds = 0, bounds
27
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
28
+ if word_inds is None:
29
+ word_inds = torch.arange(alpha.shape[2])
30
+ alpha[: start, prompt_ind, word_inds] = 0
31
+ alpha[start: end, prompt_ind, word_inds] = 1
32
+ alpha[end:, prompt_ind, word_inds] = 0
33
+ return alpha
34
+
35
+ def get_word_inds(text: str, word_place: int, tokenizer):
36
+ split_text = text.split(" ")
37
+ if type(word_place) is str:
38
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
39
+ elif type(word_place) is int:
40
+ word_place = [word_place]
41
+ out = []
42
+ if len(word_place) > 0:
43
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
44
+ cur_len, ptr = 0, 0
45
+
46
+ for i in range(len(words_encode)):
47
+ cur_len += len(words_encode[i])
48
+ if ptr in word_place:
49
+ out.append(i + 1)
50
+ if cur_len >= len(split_text[ptr]):
51
+ ptr += 1
52
+ cur_len = 0
53
+ return np.array(out)
54
+
55
+ def get_time_words_attention_alpha(prompts, num_steps,
56
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
57
+ tokenizer, max_num_words=77):
58
+ if type(cross_replace_steps) is not dict:
59
+ cross_replace_steps = {"default_": cross_replace_steps}
60
+ if "default_" not in cross_replace_steps:
61
+ cross_replace_steps["default_"] = (0., 1.)
62
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
63
+ for i in range(len(prompts) - 1):
64
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
65
+ i)
66
+ for key, item in cross_replace_steps.items():
67
+ if key != "default_":
68
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
69
+ for i, ind in enumerate(inds):
70
+ if len(ind) > 0:
71
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
72
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
73
+ return alpha_time_words
74
+
src/prompt_attention/seq_aligner.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def get_word_inds(text: str, word_place: int, tokenizer):
6
+ split_text = text.split(" ")
7
+ if type(word_place) is str:
8
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
9
+ elif type(word_place) is int:
10
+ word_place = [word_place]
11
+ out = []
12
+ if len(word_place) > 0:
13
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
14
+ cur_len, ptr = 0, 0
15
+
16
+ for i in range(len(words_encode)):
17
+ cur_len += len(words_encode[i])
18
+ if ptr in word_place:
19
+ out.append(i + 1)
20
+ if cur_len >= len(split_text[ptr]):
21
+ ptr += 1
22
+ cur_len = 0
23
+ return np.array(out)
24
+
25
+ def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
26
+ words_x = x.split(' ')
27
+ words_y = y.split(' ')
28
+ if len(words_x) != len(words_y):
29
+ raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
30
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
31
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
32
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
33
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
34
+ mapper = np.zeros((max_len, max_len))
35
+ i = j = 0
36
+ cur_inds = 0
37
+ while i < max_len and j < max_len:
38
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
39
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
40
+ if len(inds_source_) == len(inds_target_):
41
+ mapper[inds_source_, inds_target_] = 1
42
+ else:
43
+ ratio = 1 / len(inds_target_)
44
+ for i_t in inds_target_:
45
+ mapper[inds_source_, i_t] = ratio
46
+ cur_inds += 1
47
+ i += len(inds_source_)
48
+ j += len(inds_target_)
49
+ elif cur_inds < len(inds_source):
50
+ mapper[i, j] = 1
51
+ i += 1
52
+ j += 1
53
+ else:
54
+ mapper[j, j] = 1
55
+ i += 1
56
+ j += 1
57
+
58
+ return torch.from_numpy(mapper).float()
59
+
60
+ def get_replacement_mapper(prompts, tokenizer, max_len=77):
61
+ x_seq = prompts[0]
62
+ mappers = []
63
+ for i in range(1, len(prompts)):
64
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
65
+ mappers.append(mapper)
66
+ return torch.stack(mappers)