aki-0421 committed
Commit a3a3ae4 · unverified · 0 Parent(s)
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +1 -0
  3. README.md +14 -0
  4. annotator/base_annotator.py +57 -0
  5. annotator/canny.py +63 -0
  6. annotator/color.py +59 -0
  7. annotator/hed.py +155 -0
  8. annotator/identity.py +25 -0
  9. annotator/invert.py +25 -0
  10. annotator/midas/__init__.py +0 -0
  11. annotator/midas/api.py +165 -0
  12. annotator/midas/base_model.py +17 -0
  13. annotator/midas/blocks.py +390 -0
  14. annotator/midas/dpt_depth.py +106 -0
  15. annotator/midas/midas_net.py +79 -0
  16. annotator/midas/midas_net_custom.py +166 -0
  17. annotator/midas/transforms.py +230 -0
  18. annotator/midas/utils.py +192 -0
  19. annotator/midas/vit.py +509 -0
  20. annotator/midas_op.py +79 -0
  21. annotator/mlsd/__init__.py +0 -0
  22. annotator/mlsd/mbv2_mlsd_large.py +303 -0
  23. annotator/mlsd/mbv2_mlsd_tiny.py +287 -0
  24. annotator/mlsd/utils.py +638 -0
  25. annotator/mlsd_op.py +74 -0
  26. annotator/openpose.py +812 -0
  27. annotator/registry.py +30 -0
  28. annotator/utils.py +114 -0
  29. app.py +32 -0
  30. dataset/.gitignore +1 -0
  31. dataset/opencv_transforms/__init__.py +0 -0
  32. dataset/opencv_transforms/functional.py +598 -0
  33. dataset/opencv_transforms/transforms.py +1044 -0
  34. dataset/setup.py +23 -0
  35. dataset/tests/compare_to_pil_for_testing.ipynb +241 -0
  36. dataset/tests/setup_testing_directory.py +50 -0
  37. dataset/tests/test_color.py +68 -0
  38. dataset/tests/test_spatial.py +52 -0
  39. dataset/tests/utils.py +8 -0
  40. inference.yaml +166 -0
  41. packages.txt +2 -0
  42. pipeline.py +168 -0
  43. requirements.txt +49 -0
  44. sgm/__init__.py +4 -0
  45. sgm/data/__init__.py +1 -0
  46. sgm/data/dataset.py +80 -0
  47. sgm/data/video_dataset.py +191 -0
  48. sgm/data/video_dataset_stage2_degradeImages.py +303 -0
  49. sgm/inference/api.py +385 -0
  50. sgm/inference/helpers.py +305 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ .idea
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: Character 360
3
+ emoji: 🏆
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unknown
11
+ short_description: Would you like to see your character in 360°?
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
annotator/base_annotator.py ADDED
@@ -0,0 +1,57 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from abc import ABCMeta
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from scepter.modules.annotator.registry import ANNOTATORS
9
+ from scepter.modules.model.base_model import BaseModel
10
+ from scepter.modules.utils.config import dict_to_yaml
11
+
12
+
13
+ @ANNOTATORS.register_class()
14
+ class BaseAnnotator(BaseModel, metaclass=ABCMeta):
15
+ para_dict = {}
16
+
17
+ def __init__(self, cfg, logger=None):
18
+ super().__init__(cfg, logger=logger)
19
+
20
+ @torch.no_grad()
21
+ @torch.inference_mode()
22
+ def forward(self, *args, **kwargs):
23
+ raise NotImplementedError
24
+
25
+ @staticmethod
26
+ def get_config_template():
27
+ return dict_to_yaml('ANNOTATORS',
28
+ __class__.__name__,
29
+ BaseAnnotator.para_dict,
30
+ set_name=True)
31
+
32
+
33
+ @ANNOTATORS.register_class()
34
+ class GeneralAnnotator(BaseAnnotator, metaclass=ABCMeta):
35
+ def __init__(self, cfg, logger=None):
36
+ super().__init__(cfg, logger=logger)
37
+ anno_models = cfg.get('ANNOTATORS', [])
38
+ self.annotators = nn.ModuleList()
39
+ for n, anno_config in enumerate(anno_models):
40
+ annotator = ANNOTATORS.build(anno_config, logger=logger)
41
+ annotator.input_keys = anno_config.get('INPUT_KEYS', [])
42
+ if isinstance(annotator.input_keys, str):
43
+ annotator.input_keys = [annotator.input_keys]
44
+ annotator.output_keys = anno_config.get('OUTPUT_KEYS', [])
45
+ if isinstance(annotator.output_keys, str):
46
+ annotator.output_keys = [annotator.output_keys]
47
+ assert len(annotator.input_keys) == len(annotator.output_keys)
48
+ self.annotators.append(annotator)
49
+
50
+ def forward(self, input_dict):
51
+ output_dict = {}
52
+ for annotator in self.annotators:
53
+ for idx, in_key in enumerate(annotator.input_keys):
54
+ if in_key in input_dict:
55
+ image = annotator(input_dict[in_key])
56
+ output_dict[annotator.output_keys[idx]] = image
57
+ return output_dict
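
The GeneralAnnotator above simply routes each configured input key to its annotator and writes the result under the paired output key. A minimal sketch of that routing logic using plain callables — the fake_canny/fake_depth names and the routes list are hypothetical stand-ins, not part of scepter's API:

import numpy as np

# Hypothetical stand-ins for registered annotators; any image -> image callable works.
def fake_canny(img):
    return (img > 127).astype(np.uint8) * 255

def fake_depth(img):
    return img[..., ::-1]

# (input_key, output_key, annotator) triples mirror INPUT_KEYS / OUTPUT_KEYS.
routes = [('image', 'canny_map', fake_canny), ('image', 'depth_map', fake_depth)]

def route(input_dict):
    output_dict = {}
    for in_key, out_key, fn in routes:
        if in_key in input_dict:
            output_dict[out_key] = fn(input_dict[in_key])
    return output_dict

print(route({'image': np.zeros((4, 4, 3), dtype=np.uint8)}).keys())
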
annotator/canny.py ADDED
@@ -0,0 +1,63 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from abc import ABCMeta
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
11
+ from scepter.modules.annotator.registry import ANNOTATORS
12
+ from scepter.modules.utils.config import dict_to_yaml
13
+
14
+
15
+ @ANNOTATORS.register_class()
16
+ class CannyAnnotator(BaseAnnotator, metaclass=ABCMeta):
17
+ para_dict = {}
18
+
19
+ def __init__(self, cfg, logger=None):
20
+ super().__init__(cfg, logger=logger)
21
+ self.low_threshold = cfg.get('LOW_THRESHOLD', 100)
22
+ self.high_threshold = cfg.get('HIGH_THRESHOLD', 200)
23
+ self.random_cfg = cfg.get('RANDOM_CFG', None)
24
+
25
+ def forward(self, image):
26
+ if isinstance(image, Image.Image):
27
+ image = np.array(image)
28
+ elif isinstance(image, torch.Tensor):
29
+ image = image.detach().cpu().numpy()
30
+ elif isinstance(image, np.ndarray):
31
+ image = image.copy()
32
+ else:
33
+ raise TypeError(f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.')
34
+ assert len(image.shape) < 4
35
+
36
+ if self.random_cfg is None:
37
+ image = cv2.Canny(image, self.low_threshold, self.high_threshold)
38
+ else:
39
+ proba = self.random_cfg.get('PROBA', 1.0)
40
+ if np.random.random() < proba:
41
+ min_low_threshold = self.random_cfg.get(
42
+ 'MIN_LOW_THRESHOLD', 50)
43
+ max_low_threshold = self.random_cfg.get(
44
+ 'MAX_LOW_THRESHOLD', 100)
45
+ min_high_threshold = self.random_cfg.get(
46
+ 'MIN_HIGH_THRESHOLD', 200)
47
+ max_high_threshold = self.random_cfg.get(
48
+ 'MAX_HIGH_THRESHOLD', 350)
49
+ low_th = np.random.randint(min_low_threshold,
50
+ max_low_threshold)
51
+ high_th = np.random.randint(min_high_threshold,
52
+ max_high_threshold)
53
+ else:
54
+ low_th, high_th = self.low_threshold, self.high_threshold
55
+ image = cv2.Canny(image, low_th, high_th)
56
+ return image[..., None].repeat(3, 2)
57
+
58
+ @staticmethod
59
+ def get_config_template():
60
+ return dict_to_yaml('ANNOTATORS',
61
+ __class__.__name__,
62
+ CannyAnnotator.para_dict,
63
+ set_name=True)
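
For reference, the randomized-threshold branch of CannyAnnotator boils down to the following standalone OpenCV sketch (the default ranges are taken from the config keys above; no scepter config involved):

import cv2
import numpy as np

def random_canny(image, proba=1.0, low_range=(50, 100), high_range=(200, 350), defaults=(100, 200)):
    # With probability `proba`, sample the thresholds; otherwise fall back to the defaults.
    if np.random.random() < proba:
        low_th = np.random.randint(*low_range)
        high_th = np.random.randint(*high_range)
    else:
        low_th, high_th = defaults
    edge = cv2.Canny(image, low_th, high_th)
    return edge[..., None].repeat(3, 2)  # replicate to 3 channels, as the annotator does

edges = random_canny(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
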
annotator/color.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from abc import ABCMeta
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
11
+ from scepter.modules.annotator.registry import ANNOTATORS
12
+ from scepter.modules.utils.config import dict_to_yaml
13
+
14
+
15
+ @ANNOTATORS.register_class()
16
+ class ColorAnnotator(BaseAnnotator, metaclass=ABCMeta):
17
+ para_dict = {}
18
+
19
+ def __init__(self, cfg, logger=None):
20
+ super().__init__(cfg, logger=logger)
21
+ self.ratio = cfg.get('RATIO', 64)
22
+ self.random_cfg = cfg.get('RANDOM_CFG', None)
23
+
24
+ def forward(self, image):
25
+ if isinstance(image, Image.Image):
26
+ image = np.array(image)
27
+ elif isinstance(image, torch.Tensor):
28
+ image = image.detach().cpu().numpy()
29
+ elif isinstance(image, np.ndarray):
30
+ image = image.copy()
31
+ else:
32
+ raise TypeError(f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.')
33
+ h, w = image.shape[:2]
34
+
35
+ if self.random_cfg is None:
36
+ ratio = self.ratio
37
+ else:
38
+ proba = self.random_cfg.get('PROBA', 1.0)
39
+ if np.random.random() < proba:
40
+ if 'CHOICE_RATIO' in self.random_cfg:
41
+ ratio = np.random.choice(self.random_cfg['CHOICE_RATIO'])
42
+ else:
43
+ min_ratio = self.random_cfg.get('MIN_RATIO', 48)
44
+ max_ratio = self.random_cfg.get('MAX_RATIO', 96)
45
+ ratio = np.random.randint(min_ratio, max_ratio)
46
+ else:
47
+ ratio = self.ratio
48
+ image = cv2.resize(image, (int(w // ratio), int(h // ratio)),
49
+ interpolation=cv2.INTER_CUBIC)
50
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_NEAREST)
51
+ assert len(image.shape) < 4
52
+ return image
53
+
54
+ @staticmethod
55
+ def get_config_template():
56
+ return dict_to_yaml('ANNOTATORS',
57
+ __class__.__name__,
58
+ ColorAnnotator.para_dict,
59
+ set_name=True)
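
The ColorAnnotator above is effectively a mosaic filter: shrink the image by `ratio`, then scale it back up with nearest-neighbor interpolation so each block carries a single color. A minimal equivalent, assuming a fixed ratio of 64:

import cv2
import numpy as np

def color_blocks(image, ratio=64):
    h, w = image.shape[:2]
    small = cv2.resize(image, (w // ratio, h // ratio), interpolation=cv2.INTER_CUBIC)
    return cv2.resize(small, (w, h), interpolation=cv2.INTER_NEAREST)

out = color_blocks(np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8))
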
annotator/hed.py ADDED
@@ -0,0 +1,155 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # Please use this implementation in your products
4
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
5
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
6
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
7
+ # and in this way it works better for gradio's RGB protocol
8
+
9
+ from abc import ABCMeta
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ from einops import rearrange
15
+
16
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
17
+ from scepter.modules.annotator.registry import ANNOTATORS
18
+ from scepter.modules.utils.config import dict_to_yaml
19
+ from scepter.modules.utils.distribute import we
20
+ from scepter.modules.utils.file_system import FS
21
+
22
+
23
+ def nms(x, t, s):
24
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
25
+
26
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
27
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
28
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
29
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
30
+
31
+ y = np.zeros_like(x)
32
+
33
+ for f in [f1, f2, f3, f4]:
34
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
35
+
36
+ z = np.zeros_like(y, dtype=np.uint8)
37
+ z[y > t] = 255
38
+ return z
39
+
40
+
41
+ class DoubleConvBlock(torch.nn.Module):
42
+ def __init__(self, input_channel, output_channel, layer_number):
43
+ super().__init__()
44
+ self.convs = torch.nn.Sequential()
45
+ self.convs.append(
46
+ torch.nn.Conv2d(in_channels=input_channel,
47
+ out_channels=output_channel,
48
+ kernel_size=(3, 3),
49
+ stride=(1, 1),
50
+ padding=1))
51
+ for i in range(1, layer_number):
52
+ self.convs.append(
53
+ torch.nn.Conv2d(in_channels=output_channel,
54
+ out_channels=output_channel,
55
+ kernel_size=(3, 3),
56
+ stride=(1, 1),
57
+ padding=1))
58
+ self.projection = torch.nn.Conv2d(in_channels=output_channel,
59
+ out_channels=1,
60
+ kernel_size=(1, 1),
61
+ stride=(1, 1),
62
+ padding=0)
63
+
64
+ def __call__(self, x, down_sampling=False):
65
+ h = x
66
+ if down_sampling:
67
+ h = torch.nn.functional.max_pool2d(h,
68
+ kernel_size=(2, 2),
69
+ stride=(2, 2))
70
+ for conv in self.convs:
71
+ h = conv(h)
72
+ h = torch.nn.functional.relu(h)
73
+ return h, self.projection(h)
74
+
75
+
76
+ class ControlNetHED_Apache2(torch.nn.Module):
77
+ def __init__(self):
78
+ super().__init__()
79
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
80
+ self.block1 = DoubleConvBlock(input_channel=3,
81
+ output_channel=64,
82
+ layer_number=2)
83
+ self.block2 = DoubleConvBlock(input_channel=64,
84
+ output_channel=128,
85
+ layer_number=2)
86
+ self.block3 = DoubleConvBlock(input_channel=128,
87
+ output_channel=256,
88
+ layer_number=3)
89
+ self.block4 = DoubleConvBlock(input_channel=256,
90
+ output_channel=512,
91
+ layer_number=3)
92
+ self.block5 = DoubleConvBlock(input_channel=512,
93
+ output_channel=512,
94
+ layer_number=3)
95
+
96
+ def __call__(self, x):
97
+ h = x - self.norm
98
+ h, projection1 = self.block1(h)
99
+ h, projection2 = self.block2(h, down_sampling=True)
100
+ h, projection3 = self.block3(h, down_sampling=True)
101
+ h, projection4 = self.block4(h, down_sampling=True)
102
+ h, projection5 = self.block5(h, down_sampling=True)
103
+ return projection1, projection2, projection3, projection4, projection5
104
+
105
+
106
+ @ANNOTATORS.register_class()
107
+ class HedAnnotator(BaseAnnotator, metaclass=ABCMeta):
108
+ para_dict = {}
109
+
110
+ def __init__(self, cfg, logger=None):
111
+ super().__init__(cfg, logger=logger)
112
+ self.netNetwork = ControlNetHED_Apache2().float().eval()
113
+ pretrained_model = cfg.get('PRETRAINED_MODEL', None)
114
+ if pretrained_model:
115
+ with FS.get_from(pretrained_model, wait_finish=True) as local_path:
116
+ self.netNetwork.load_state_dict(torch.load(local_path))
117
+
118
+ @torch.no_grad()
119
+ @torch.inference_mode()
120
+ @torch.autocast('cuda', enabled=False)
121
+ def forward(self, image):
122
+ if isinstance(image, torch.Tensor):
123
+ if len(image.shape) == 3:
124
+ image = rearrange(image, 'h w c -> 1 c h w')
125
+ B, C, H, W = image.shape
126
+ else:
127
+ raise ValueError("Unsupported input image shape")
128
+ elif isinstance(image, np.ndarray):
129
+ image = torch.from_numpy(image.copy()).float()
130
+ if len(image.shape) == 3:
131
+ image = rearrange(image, 'h w c -> 1 c h w')
132
+ B, C, H, W = image.shape
133
+ else:
134
+ raise ValueError("Unsupported input image shape")
135
+ else:
136
+ raise TypeError("Unsupported input image type")
137
+ edges = self.netNetwork(image.to(we.device_id))
138
+ edges = [
139
+ e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges
140
+ ]
141
+ edges = [
142
+ cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR)
143
+ for e in edges
144
+ ]
145
+ edges = np.stack(edges, axis=2)
146
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
147
+ edge = 255 - (edge * 255.0).clip(0, 255).astype(np.uint8)
148
+ return edge[..., None].repeat(3, 2)
149
+
150
+ @staticmethod
151
+ def get_config_template():
152
+ return dict_to_yaml('ANNOTATORS',
153
+ __class__.__name__,
154
+ HedAnnotator.para_dict,
155
+ set_name=True)
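
The post-processing at the end of HedAnnotator.forward averages the five side outputs, squashes them with a sigmoid, inverts to a white background, and replicates to three channels. A standalone NumPy sketch of that step (the edge logits below are random placeholders):

import numpy as np

H, W = 64, 64
edges = np.stack([np.random.randn(H, W) for _ in range(5)], axis=2)  # five per-scale edge logits

edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))  # sigmoid of the mean
edge = 255 - (edge * 255.0).clip(0, 255).astype(np.uint8)            # invert: edges dark, background white
edge_rgb = edge[..., None].repeat(3, 2)                              # H x W x 3
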
annotator/identity.py ADDED
@@ -0,0 +1,25 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from abc import ABCMeta
4
+
5
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
6
+ from scepter.modules.annotator.registry import ANNOTATORS
7
+ from scepter.modules.utils.config import dict_to_yaml
8
+
9
+
10
+ @ANNOTATORS.register_class()
11
+ class IdentityAnnotator(BaseAnnotator, metaclass=ABCMeta):
12
+ para_dict = {}
13
+
14
+ def __init__(self, cfg, logger=None):
15
+ super().__init__(cfg, logger=logger)
16
+
17
+ def forward(self, image):
18
+ return image
19
+
20
+ @staticmethod
21
+ def get_config_template():
22
+ return dict_to_yaml('ANNOTATORS',
23
+ __class__.__name__,
24
+ IdentityAnnotator.para_dict,
25
+ set_name=True)
annotator/invert.py ADDED
@@ -0,0 +1,25 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from abc import ABCMeta
4
+
5
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
6
+ from scepter.modules.annotator.registry import ANNOTATORS
7
+ from scepter.modules.utils.config import dict_to_yaml
8
+
9
+
10
+ @ANNOTATORS.register_class()
11
+ class InvertAnnotator(BaseAnnotator, metaclass=ABCMeta):
12
+ para_dict = {}
13
+
14
+ def __init__(self, cfg, logger=None):
15
+ super().__init__(cfg, logger=logger)
16
+
17
+ def forward(self, image):
18
+ return 255 - image
19
+
20
+ @staticmethod
21
+ def get_config_template():
22
+ return dict_to_yaml('ANNOTATORS',
23
+ __class__.__name__,
24
+ InvertAnnotator.para_dict,
25
+ set_name=True)
annotator/midas/__init__.py ADDED
File without changes
annotator/midas/api.py ADDED
@@ -0,0 +1,165 @@
1
+ # -*- coding: utf-8 -*-
2
+ # based on https://github.com/isl-org/MiDaS
3
+
4
+ import cv2
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision.transforms import Compose
8
+
9
+ from .dpt_depth import DPTDepthModel
10
+ from .midas_net import MidasNet
11
+ from .midas_net_custom import MidasNet_small
12
+ from .transforms import NormalizeImage, PrepareForNet, Resize
13
+
14
+ # ISL_PATHS = {
15
+ # "dpt_large": "dpt_large-midas-2f21e586.pt",
16
+ # "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
17
+ # "midas_v21": "",
18
+ # "midas_v21_small": "",
19
+ # }
20
+
21
+ # remote_model_path =
22
+ # "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
23
+
24
+
25
+ def disabled_train(self, mode=True):
26
+ """Overwrite model.train with this function to make sure train/eval mode
27
+ does not change anymore."""
28
+ return self
29
+
30
+
31
+ def load_midas_transform(model_type):
32
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
33
+ # load transform only
34
+ if model_type == 'dpt_large': # DPT-Large
35
+ net_w, net_h = 384, 384
36
+ resize_mode = 'minimal'
37
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
38
+ std=[0.5, 0.5, 0.5])
39
+
40
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
41
+ net_w, net_h = 384, 384
42
+ resize_mode = 'minimal'
43
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
44
+ std=[0.5, 0.5, 0.5])
45
+
46
+ elif model_type == 'midas_v21':
47
+ net_w, net_h = 384, 384
48
+ resize_mode = 'upper_bound'
49
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
50
+ std=[0.229, 0.224, 0.225])
51
+
52
+ elif model_type == 'midas_v21_small':
53
+ net_w, net_h = 256, 256
54
+ resize_mode = 'upper_bound'
55
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
56
+ std=[0.229, 0.224, 0.225])
57
+
58
+ else:
59
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
60
+
61
+ transform = Compose([
62
+ Resize(
63
+ net_w,
64
+ net_h,
65
+ resize_target=None,
66
+ keep_aspect_ratio=True,
67
+ ensure_multiple_of=32,
68
+ resize_method=resize_mode,
69
+ image_interpolation_method=cv2.INTER_CUBIC,
70
+ ),
71
+ normalization,
72
+ PrepareForNet(),
73
+ ])
74
+
75
+ return transform
76
+
77
+
78
+ def load_model(model_type, model_path):
79
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
80
+ # load network
81
+ # model_path = ISL_PATHS[model_type]
82
+ if model_type == 'dpt_large': # DPT-Large
83
+ model = DPTDepthModel(
84
+ path=model_path,
85
+ backbone='vitl16_384',
86
+ non_negative=True,
87
+ )
88
+ net_w, net_h = 384, 384
89
+ resize_mode = 'minimal'
90
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
91
+ std=[0.5, 0.5, 0.5])
92
+
93
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
94
+ model = DPTDepthModel(
95
+ path=model_path,
96
+ backbone='vitb_rn50_384',
97
+ non_negative=True,
98
+ )
99
+ net_w, net_h = 384, 384
100
+ resize_mode = 'minimal'
101
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
102
+ std=[0.5, 0.5, 0.5])
103
+
104
+ elif model_type == 'midas_v21':
105
+ model = MidasNet(model_path, non_negative=True)
106
+ net_w, net_h = 384, 384
107
+ resize_mode = 'upper_bound'
108
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
109
+ std=[0.229, 0.224, 0.225])
110
+
111
+ elif model_type == 'midas_v21_small':
112
+ model = MidasNet_small(model_path,
113
+ features=64,
114
+ backbone='efficientnet_lite3',
115
+ exportable=True,
116
+ non_negative=True,
117
+ blocks={'expand': True})
118
+ net_w, net_h = 256, 256
119
+ resize_mode = 'upper_bound'
120
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
121
+ std=[0.229, 0.224, 0.225])
122
+
123
+ else:
124
+ print(
125
+ f"model_type '{model_type}' not implemented, use: --model_type large"
126
+ )
127
+ assert False
128
+
129
+ transform = Compose([
130
+ Resize(
131
+ net_w,
132
+ net_h,
133
+ resize_target=None,
134
+ keep_aspect_ratio=True,
135
+ ensure_multiple_of=32,
136
+ resize_method=resize_mode,
137
+ image_interpolation_method=cv2.INTER_CUBIC,
138
+ ),
139
+ normalization,
140
+ PrepareForNet(),
141
+ ])
142
+
143
+ return model.eval(), transform
144
+
145
+
146
+ class MiDaSInference(nn.Module):
147
+ MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
148
+ MODEL_TYPES_ISL = [
149
+ 'dpt_large',
150
+ 'dpt_hybrid',
151
+ 'midas_v21',
152
+ 'midas_v21_small',
153
+ ]
154
+
155
+ def __init__(self, model_type, model_path):
156
+ super().__init__()
157
+ assert (model_type in self.MODEL_TYPES_ISL)
158
+ model, _ = load_model(model_type, model_path)
159
+ self.model = model
160
+ self.model.train = disabled_train
161
+
162
+ def forward(self, x):
163
+ with torch.no_grad():
164
+ prediction = self.model(x)
165
+ return prediction
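
A hedged usage sketch for the API above; the checkpoint filename is hypothetical and must point at a real dpt_hybrid weight file:

import numpy as np
import torch
# from annotator.midas.api import MiDaSInference, load_midas_transform

model = MiDaSInference('dpt_hybrid', 'dpt_hybrid-midas-501f0c75.pt')  # hypothetical local path
transform = load_midas_transform('dpt_hybrid')

img = np.random.rand(480, 640, 3).astype(np.float32)   # RGB image scaled to [0, 1]
sample = transform({'image': img})                      # resize + normalize + HWC -> CHW
x = torch.from_numpy(sample['image']).unsqueeze(0)      # 1 x 3 x H' x W'
depth = model(x)                                        # 1 x H' x W' relative (inverse) depth
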
annotator/midas/base_model.py ADDED
@@ -0,0 +1,17 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+
4
+
5
+ class BaseModel(torch.nn.Module):
6
+ def load(self, path):
7
+ """Load model from file.
8
+
9
+ Args:
10
+ path (str): file path
11
+ """
12
+ parameters = torch.load(path, map_location=torch.device('cpu'))
13
+
14
+ if 'optimizer' in parameters:
15
+ parameters = parameters['model']
16
+
17
+ self.load_state_dict(parameters)
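
The `'optimizer' in parameters` check above exists because training checkpoints are commonly saved as a dict holding both model and optimizer state, while inference checkpoints store the state_dict directly. A sketch of the two layouts the loader accepts (keys are illustrative):

import torch

state_dict = {'layer.weight': torch.zeros(1)}                      # illustrative
torch.save(state_dict, 'inference.pt')                             # layout 1: plain state_dict
torch.save({'model': state_dict, 'optimizer': {}}, 'training.pt')  # layout 2: loader picks out 'model'
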
annotator/midas/blocks.py ADDED
@@ -0,0 +1,390 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384)
7
+
8
+
9
+ def _make_encoder(
10
+ backbone,
11
+ features,
12
+ use_pretrained,
13
+ groups=1,
14
+ expand=False,
15
+ exportable=True,
16
+ hooks=None,
17
+ use_vit_only=False,
18
+ use_readout='ignore',
19
+ ):
20
+ if backbone == 'vitl16_384':
21
+ pretrained = _make_pretrained_vitl16_384(use_pretrained,
22
+ hooks=hooks,
23
+ use_readout=use_readout)
24
+ scratch = _make_scratch(
25
+ [256, 512, 1024, 1024], features, groups=groups,
26
+ expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
27
+ elif backbone == 'vitb_rn50_384':
28
+ pretrained = _make_pretrained_vitb_rn50_384(
29
+ use_pretrained,
30
+ hooks=hooks,
31
+ use_vit_only=use_vit_only,
32
+ use_readout=use_readout,
33
+ )
34
+ scratch = _make_scratch(
35
+ [256, 512, 768, 768], features, groups=groups,
36
+ expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
37
+ elif backbone == 'vitb16_384':
38
+ pretrained = _make_pretrained_vitb16_384(use_pretrained,
39
+ hooks=hooks,
40
+ use_readout=use_readout)
41
+ scratch = _make_scratch(
42
+ [96, 192, 384, 768], features, groups=groups,
43
+ expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
44
+ elif backbone == 'resnext101_wsl':
45
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
46
+ scratch = _make_scratch([256, 512, 1024, 2048],
47
+ features,
48
+ groups=groups,
49
+ expand=expand) # efficientnet_lite3
50
+ elif backbone == 'efficientnet_lite3':
51
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
52
+ exportable=exportable)
53
+ scratch = _make_scratch([32, 48, 136, 384],
54
+ features,
55
+ groups=groups,
56
+ expand=expand) # efficientnet_lite3
57
+ else:
58
+ print(f"Backbone '{backbone}' not implemented")
59
+ assert False
60
+
61
+ return pretrained, scratch
62
+
63
+
64
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
65
+ scratch = nn.Module()
66
+
67
+ out_shape1 = out_shape
68
+ out_shape2 = out_shape
69
+ out_shape3 = out_shape
70
+ out_shape4 = out_shape
71
+ if expand is True:
72
+ out_shape1 = out_shape
73
+ out_shape2 = out_shape * 2
74
+ out_shape3 = out_shape * 4
75
+ out_shape4 = out_shape * 8
76
+
77
+ scratch.layer1_rn = nn.Conv2d(in_shape[0],
78
+ out_shape1,
79
+ kernel_size=3,
80
+ stride=1,
81
+ padding=1,
82
+ bias=False,
83
+ groups=groups)
84
+ scratch.layer2_rn = nn.Conv2d(in_shape[1],
85
+ out_shape2,
86
+ kernel_size=3,
87
+ stride=1,
88
+ padding=1,
89
+ bias=False,
90
+ groups=groups)
91
+ scratch.layer3_rn = nn.Conv2d(in_shape[2],
92
+ out_shape3,
93
+ kernel_size=3,
94
+ stride=1,
95
+ padding=1,
96
+ bias=False,
97
+ groups=groups)
98
+ scratch.layer4_rn = nn.Conv2d(in_shape[3],
99
+ out_shape4,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1,
103
+ bias=False,
104
+ groups=groups)
105
+
106
+ return scratch
107
+
108
+
109
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
110
+ efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
111
+ 'tf_efficientnet_lite3',
112
+ pretrained=use_pretrained,
113
+ exportable=exportable)
114
+ return _make_efficientnet_backbone(efficientnet)
115
+
116
+
117
+ def _make_efficientnet_backbone(effnet):
118
+ pretrained = nn.Module()
119
+
120
+ pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
121
+ effnet.act1, *effnet.blocks[0:2])
122
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
123
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
124
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
125
+
126
+ return pretrained
127
+
128
+
129
+ def _make_resnet_backbone(resnet):
130
+ pretrained = nn.Module()
131
+ pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
132
+ resnet.maxpool, resnet.layer1)
133
+
134
+ pretrained.layer2 = resnet.layer2
135
+ pretrained.layer3 = resnet.layer3
136
+ pretrained.layer4 = resnet.layer4
137
+
138
+ return pretrained
139
+
140
+
141
+ def _make_pretrained_resnext101_wsl(use_pretrained):
142
+ resnet = torch.hub.load('facebookresearch/WSL-Images',
143
+ 'resnext101_32x8d_wsl')
144
+ return _make_resnet_backbone(resnet)
145
+
146
+
147
+ class Interpolate(nn.Module):
148
+ """Interpolation module.
149
+ """
150
+ def __init__(self, scale_factor, mode, align_corners=False):
151
+ """Init.
152
+
153
+ Args:
154
+ scale_factor (float): scaling
155
+ mode (str): interpolation mode
156
+ """
157
+ super(Interpolate, self).__init__()
158
+
159
+ self.interp = nn.functional.interpolate
160
+ self.scale_factor = scale_factor
161
+ self.mode = mode
162
+ self.align_corners = align_corners
163
+
164
+ def forward(self, x):
165
+ """Forward pass.
166
+
167
+ Args:
168
+ x (tensor): input
169
+
170
+ Returns:
171
+ tensor: interpolated data
172
+ """
173
+
174
+ x = self.interp(x,
175
+ scale_factor=self.scale_factor,
176
+ mode=self.mode,
177
+ align_corners=self.align_corners)
178
+
179
+ return x
180
+
181
+
182
+ class ResidualConvUnit(nn.Module):
183
+ """Residual convolution module.
184
+ """
185
+ def __init__(self, features):
186
+ """Init.
187
+
188
+ Args:
189
+ features (int): number of features
190
+ """
191
+ super().__init__()
192
+
193
+ self.conv1 = nn.Conv2d(features,
194
+ features,
195
+ kernel_size=3,
196
+ stride=1,
197
+ padding=1,
198
+ bias=True)
199
+
200
+ self.conv2 = nn.Conv2d(features,
201
+ features,
202
+ kernel_size=3,
203
+ stride=1,
204
+ padding=1,
205
+ bias=True)
206
+
207
+ self.relu = nn.ReLU(inplace=True)
208
+
209
+ def forward(self, x):
210
+ """Forward pass.
211
+
212
+ Args:
213
+ x (tensor): input
214
+
215
+ Returns:
216
+ tensor: output
217
+ """
218
+ out = self.relu(x)
219
+ out = self.conv1(out)
220
+ out = self.relu(out)
221
+ out = self.conv2(out)
222
+
223
+ return out + x
224
+
225
+
226
+ class FeatureFusionBlock(nn.Module):
227
+ """Feature fusion block.
228
+ """
229
+ def __init__(self, features):
230
+ """Init.
231
+
232
+ Args:
233
+ features (int): number of features
234
+ """
235
+ super(FeatureFusionBlock, self).__init__()
236
+
237
+ self.resConfUnit1 = ResidualConvUnit(features)
238
+ self.resConfUnit2 = ResidualConvUnit(features)
239
+
240
+ def forward(self, *xs):
241
+ """Forward pass.
242
+
243
+ Returns:
244
+ tensor: output
245
+ """
246
+ output = xs[0]
247
+
248
+ if len(xs) == 2:
249
+ output += self.resConfUnit1(xs[1])
250
+
251
+ output = self.resConfUnit2(output)
252
+
253
+ output = nn.functional.interpolate(output,
254
+ scale_factor=2,
255
+ mode='bilinear',
256
+ align_corners=True)
257
+
258
+ return output
259
+
260
+
261
+ class ResidualConvUnit_custom(nn.Module):
262
+ """Residual convolution module.
263
+ """
264
+ def __init__(self, features, activation, bn):
265
+ """Init.
266
+
267
+ Args:
268
+ features (int): number of features
269
+ """
270
+ super().__init__()
271
+
272
+ self.bn = bn
273
+
274
+ self.groups = 1
275
+
276
+ self.conv1 = nn.Conv2d(features,
277
+ features,
278
+ kernel_size=3,
279
+ stride=1,
280
+ padding=1,
281
+ bias=True,
282
+ groups=self.groups)
283
+
284
+ self.conv2 = nn.Conv2d(features,
285
+ features,
286
+ kernel_size=3,
287
+ stride=1,
288
+ padding=1,
289
+ bias=True,
290
+ groups=self.groups)
291
+
292
+ if self.bn is True:
293
+ self.bn1 = nn.BatchNorm2d(features)
294
+ self.bn2 = nn.BatchNorm2d(features)
295
+
296
+ self.activation = activation
297
+
298
+ self.skip_add = nn.quantized.FloatFunctional()
299
+
300
+ def forward(self, x):
301
+ """Forward pass.
302
+
303
+ Args:
304
+ x (tensor): input
305
+
306
+ Returns:
307
+ tensor: output
308
+ """
309
+
310
+ out = self.activation(x)
311
+ out = self.conv1(out)
312
+ if self.bn is True:
313
+ out = self.bn1(out)
314
+
315
+ out = self.activation(out)
316
+ out = self.conv2(out)
317
+ if self.bn is True:
318
+ out = self.bn2(out)
319
+
320
+ if self.groups > 1:
321
+ out = self.conv_merge(out)
322
+
323
+ return self.skip_add.add(out, x)
324
+
325
+ # return out + x
326
+
327
+
328
+ class FeatureFusionBlock_custom(nn.Module):
329
+ """Feature fusion block.
330
+ """
331
+ def __init__(self,
332
+ features,
333
+ activation,
334
+ deconv=False,
335
+ bn=False,
336
+ expand=False,
337
+ align_corners=True):
338
+ """Init.
339
+
340
+ Args:
341
+ features (int): number of features
342
+ """
343
+ super(FeatureFusionBlock_custom, self).__init__()
344
+
345
+ self.deconv = deconv
346
+ self.align_corners = align_corners
347
+
348
+ self.groups = 1
349
+
350
+ self.expand = expand
351
+ out_features = features
352
+ if self.expand is True:
353
+ out_features = features // 2
354
+
355
+ self.out_conv = nn.Conv2d(features,
356
+ out_features,
357
+ kernel_size=1,
358
+ stride=1,
359
+ padding=0,
360
+ bias=True,
361
+ groups=1)
362
+
363
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
364
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
365
+
366
+ self.skip_add = nn.quantized.FloatFunctional()
367
+
368
+ def forward(self, *xs):
369
+ """Forward pass.
370
+
371
+ Returns:
372
+ tensor: output
373
+ """
374
+ output = xs[0]
375
+
376
+ if len(xs) == 2:
377
+ res = self.resConfUnit1(xs[1])
378
+ output = self.skip_add.add(output, res)
379
+ # output += res
380
+
381
+ output = self.resConfUnit2(output)
382
+
383
+ output = nn.functional.interpolate(output,
384
+ scale_factor=2,
385
+ mode='bilinear',
386
+ align_corners=self.align_corners)
387
+
388
+ output = self.out_conv(output)
389
+
390
+ return output
annotator/midas/dpt_depth.py ADDED
@@ -0,0 +1,106 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
7
+ from .vit import forward_vit
8
+
9
+
10
+ def _make_fusion_block(features, use_bn):
11
+ return FeatureFusionBlock_custom(
12
+ features,
13
+ nn.ReLU(False),
14
+ deconv=False,
15
+ bn=use_bn,
16
+ expand=False,
17
+ align_corners=True,
18
+ )
19
+
20
+
21
+ class DPT(BaseModel):
22
+ def __init__(
23
+ self,
24
+ head,
25
+ features=256,
26
+ backbone='vitb_rn50_384',
27
+ readout='project',
28
+ channels_last=False,
29
+ use_bn=False,
30
+ ):
31
+
32
+ super(DPT, self).__init__()
33
+
34
+ self.channels_last = channels_last
35
+
36
+ hooks = {
37
+ 'vitb_rn50_384': [0, 1, 8, 11],
38
+ 'vitb16_384': [2, 5, 8, 11],
39
+ 'vitl16_384': [5, 11, 17, 23],
40
+ }
41
+
42
+ # Instantiate backbone and reassemble blocks
43
+ self.pretrained, self.scratch = _make_encoder(
44
+ backbone,
45
+ features,
46
+ False, # Set to True if you want to train from scratch (it then uses ImageNet weights)
47
+ groups=1,
48
+ expand=False,
49
+ exportable=False,
50
+ hooks=hooks[backbone],
51
+ use_readout=readout,
52
+ )
53
+
54
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
55
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
56
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
57
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
58
+
59
+ self.scratch.output_conv = head
60
+
61
+ def forward(self, x):
62
+ if self.channels_last is True:
63
+ x.contiguous(memory_format=torch.channels_last)
64
+
65
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
66
+
67
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
68
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
69
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
70
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
71
+
72
+ path_4 = self.scratch.refinenet4(layer_4_rn)
73
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
74
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
75
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
76
+
77
+ out = self.scratch.output_conv(path_1)
78
+
79
+ return out
80
+
81
+
82
+ class DPTDepthModel(DPT):
83
+ def __init__(self, path=None, non_negative=True, **kwargs):
84
+ features = kwargs['features'] if 'features' in kwargs else 256
85
+
86
+ head = nn.Sequential(
87
+ nn.Conv2d(features,
88
+ features // 2,
89
+ kernel_size=3,
90
+ stride=1,
91
+ padding=1),
92
+ Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
93
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
94
+ nn.ReLU(True),
95
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
96
+ nn.ReLU(True) if non_negative else nn.Identity(),
97
+ nn.Identity(),
98
+ )
99
+
100
+ super().__init__(head, **kwargs)
101
+
102
+ if path is not None:
103
+ self.load(path)
104
+
105
+ def forward(self, x):
106
+ return super().forward(x).squeeze(dim=1)
annotator/midas/midas_net.py ADDED
@@ -0,0 +1,79 @@
1
+ # -*- coding: utf-8 -*-
2
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
3
+ This file contains code that is adapted from
4
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .base_model import BaseModel
10
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
11
+
12
+
13
+ class MidasNet(BaseModel):
14
+ """Network for monocular depth estimation.
15
+ """
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print('Loading weights: ', path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(
31
+ backbone='resnext101_wsl',
32
+ features=features,
33
+ use_pretrained=use_pretrained)
34
+
35
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
36
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
37
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
38
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
39
+
40
+ self.scratch.output_conv = nn.Sequential(
41
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
42
+ Interpolate(scale_factor=2, mode='bilinear'),
43
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
44
+ nn.ReLU(True),
45
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
46
+ nn.ReLU(True) if non_negative else nn.Identity(),
47
+ )
48
+
49
+ if path:
50
+ self.load(path)
51
+
52
+ def forward(self, x):
53
+ """Forward pass.
54
+
55
+ Args:
56
+ x (tensor): input data (image)
57
+
58
+ Returns:
59
+ tensor: depth
60
+ """
61
+
62
+ layer_1 = self.pretrained.layer1(x)
63
+ layer_2 = self.pretrained.layer2(layer_1)
64
+ layer_3 = self.pretrained.layer3(layer_2)
65
+ layer_4 = self.pretrained.layer4(layer_3)
66
+
67
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
68
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
69
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
70
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
71
+
72
+ path_4 = self.scratch.refinenet4(layer_4_rn)
73
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
74
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
75
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
76
+
77
+ out = self.scratch.output_conv(path_1)
78
+
79
+ return torch.squeeze(out, dim=1)
annotator/midas/midas_net_custom.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
3
+ This file contains code that is adapted from
4
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .base_model import BaseModel
10
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
11
+
12
+
13
+ class MidasNet_small(BaseModel):
14
+ """Network for monocular depth estimation.
15
+ """
16
+ def __init__(self,
17
+ path=None,
18
+ features=64,
19
+ backbone='efficientnet_lite3',
20
+ non_negative=True,
21
+ exportable=True,
22
+ channels_last=False,
23
+ align_corners=True,
24
+ blocks={'expand': True}):
25
+ """Init.
26
+
27
+ Args:
28
+ path (str, optional): Path to saved model. Defaults to None.
29
+ features (int, optional): Number of features. Defaults to 256.
30
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
31
+ """
32
+ print('Loading weights: ', path)
33
+
34
+ super(MidasNet_small, self).__init__()
35
+
36
+ use_pretrained = False if path else True
37
+
38
+ self.channels_last = channels_last
39
+ self.blocks = blocks
40
+ self.backbone = backbone
41
+
42
+ self.groups = 1
43
+
44
+ features1 = features
45
+ features2 = features
46
+ features3 = features
47
+ features4 = features
48
+ self.expand = False
49
+ if 'expand' in self.blocks and self.blocks['expand'] is True:
50
+ self.expand = True
51
+ features1 = features
52
+ features2 = features * 2
53
+ features3 = features * 4
54
+ features4 = features * 8
55
+
56
+ self.pretrained, self.scratch = _make_encoder(self.backbone,
57
+ features,
58
+ use_pretrained,
59
+ groups=self.groups,
60
+ expand=self.expand,
61
+ exportable=exportable)
62
+
63
+ self.scratch.activation = nn.ReLU(False)
64
+
65
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(
66
+ features4,
67
+ self.scratch.activation,
68
+ deconv=False,
69
+ bn=False,
70
+ expand=self.expand,
71
+ align_corners=align_corners)
72
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(
73
+ features3,
74
+ self.scratch.activation,
75
+ deconv=False,
76
+ bn=False,
77
+ expand=self.expand,
78
+ align_corners=align_corners)
79
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(
80
+ features2,
81
+ self.scratch.activation,
82
+ deconv=False,
83
+ bn=False,
84
+ expand=self.expand,
85
+ align_corners=align_corners)
86
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(
87
+ features1,
88
+ self.scratch.activation,
89
+ deconv=False,
90
+ bn=False,
91
+ align_corners=align_corners)
92
+
93
+ self.scratch.output_conv = nn.Sequential(
94
+ nn.Conv2d(features,
95
+ features // 2,
96
+ kernel_size=3,
97
+ stride=1,
98
+ padding=1,
99
+ groups=self.groups),
100
+ Interpolate(scale_factor=2, mode='bilinear'),
101
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
102
+ self.scratch.activation,
103
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
104
+ nn.ReLU(True) if non_negative else nn.Identity(),
105
+ nn.Identity(),
106
+ )
107
+
108
+ if path:
109
+ self.load(path)
110
+
111
+ def forward(self, x):
112
+ """Forward pass.
113
+
114
+ Args:
115
+ x (tensor): input data (image)
116
+
117
+ Returns:
118
+ tensor: depth
119
+ """
120
+ if self.channels_last is True:
121
+ print('self.channels_last = ', self.channels_last)
122
+ x.contiguous(memory_format=torch.channels_last)
123
+
124
+ layer_1 = self.pretrained.layer1(x)
125
+ layer_2 = self.pretrained.layer2(layer_1)
126
+ layer_3 = self.pretrained.layer3(layer_2)
127
+ layer_4 = self.pretrained.layer4(layer_3)
128
+
129
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
130
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
131
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
132
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
133
+
134
+ path_4 = self.scratch.refinenet4(layer_4_rn)
135
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
136
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
137
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
138
+
139
+ out = self.scratch.output_conv(path_1)
140
+
141
+ return torch.squeeze(out, dim=1)
142
+
143
+
144
+ def fuse_model(m):
145
+ prev_previous_type = nn.Identity()
146
+ prev_previous_name = ''
147
+ previous_type = nn.Identity()
148
+ previous_name = ''
149
+ for name, module in m.named_modules():
150
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
151
+ module) == nn.ReLU:
152
+ # print("FUSED ", prev_previous_name, previous_name, name)
153
+ torch.quantization.fuse_modules(
154
+ m, [prev_previous_name, previous_name, name], inplace=True)
155
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
156
+ # print("FUSED ", prev_previous_name, previous_name)
157
+ torch.quantization.fuse_modules(
158
+ m, [prev_previous_name, previous_name], inplace=True)
159
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
160
+ # print("FUSED ", previous_name, name)
161
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
162
+
163
+ prev_previous_type = previous_type
164
+ prev_previous_name = previous_name
165
+ previous_type = type(module)
166
+ previous_name = name
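
fuse_model above relies on torch.quantization.fuse_modules to merge Conv2d + BatchNorm2d (+ ReLU) runs in place. A hedged sketch on a toy module (module names are illustrative; fusing with folded batch norm requires eval mode):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = Toy().eval()
torch.quantization.fuse_modules(m, [['conv', 'bn', 'relu']], inplace=True)
print(m.conv)  # fused ConvReLU2d; m.bn and m.relu are replaced by nn.Identity
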
annotator/midas/transforms.py ADDED
@@ -0,0 +1,230 @@
1
+ # -*- coding: utf-8 -*-
2
+ import math
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+
8
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
9
+ """Resize the sample to ensure the given size. Keeps aspect ratio.
10
+
11
+ Args:
12
+ sample (dict): sample
13
+ size (tuple): image size
14
+
15
+ Returns:
16
+ tuple: new size
17
+ """
18
+ shape = list(sample['disparity'].shape)
19
+
20
+ if shape[0] >= size[0] and shape[1] >= size[1]:
21
+ return sample
22
+
23
+ scale = [0, 0]
24
+ scale[0] = size[0] / shape[0]
25
+ scale[1] = size[1] / shape[1]
26
+
27
+ scale = max(scale)
28
+
29
+ shape[0] = math.ceil(scale * shape[0])
30
+ shape[1] = math.ceil(scale * shape[1])
31
+
32
+ # resize
33
+ sample['image'] = cv2.resize(sample['image'],
34
+ tuple(shape[::-1]),
35
+ interpolation=image_interpolation_method)
36
+
37
+ sample['disparity'] = cv2.resize(sample['disparity'],
38
+ tuple(shape[::-1]),
39
+ interpolation=cv2.INTER_NEAREST)
40
+ sample['mask'] = cv2.resize(
41
+ sample['mask'].astype(np.float32),
42
+ tuple(shape[::-1]),
43
+ interpolation=cv2.INTER_NEAREST,
44
+ )
45
+ sample['mask'] = sample['mask'].astype(bool)
46
+
47
+ return tuple(shape)
48
+
49
+
50
+ class Resize(object):
51
+ """Resize sample to given size (width, height).
52
+ """
53
+ def __init__(
54
+ self,
55
+ width,
56
+ height,
57
+ resize_target=True,
58
+ keep_aspect_ratio=False,
59
+ ensure_multiple_of=1,
60
+ resize_method='lower_bound',
61
+ image_interpolation_method=cv2.INTER_AREA,
62
+ ):
63
+ """Init.
64
+
65
+ Args:
66
+ width (int): desired output width
67
+ height (int): desired output height
68
+ resize_target (bool, optional):
69
+ True: Resize the full sample (image, mask, target).
70
+ False: Resize image only.
71
+ Defaults to True.
72
+ keep_aspect_ratio (bool, optional):
73
+ True: Keep the aspect ratio of the input sample.
74
+ Output sample might not have the given width and height, and
75
+ resize behaviour depends on the parameter 'resize_method'.
76
+ Defaults to False.
77
+ ensure_multiple_of (int, optional):
78
+ Output width and height is constrained to be multiple of this parameter.
79
+ Defaults to 1.
80
+ resize_method (str, optional):
81
+ "lower_bound": Output will be at least as large as the given size.
82
+ "upper_bound": Output will be at most as large as the given size.
83
+ (Output size might be smaller than given size.)
84
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
85
+ Defaults to "lower_bound".
86
+ """
87
+ self.__width = width
88
+ self.__height = height
89
+
90
+ self.__resize_target = resize_target
91
+ self.__keep_aspect_ratio = keep_aspect_ratio
92
+ self.__multiple_of = ensure_multiple_of
93
+ self.__resize_method = resize_method
94
+ self.__image_interpolation_method = image_interpolation_method
95
+
96
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
97
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if max_val is not None and y > max_val:
100
+ y = (np.floor(x / self.__multiple_of) *
101
+ self.__multiple_of).astype(int)
102
+
103
+ if y < min_val:
104
+ y = (np.ceil(x / self.__multiple_of) *
105
+ self.__multiple_of).astype(int)
106
+
107
+ return y
108
+
109
+ def get_size(self, width, height):
110
+ # determine new height and width
111
+ scale_height = self.__height / height
112
+ scale_width = self.__width / width
113
+
114
+ if self.__keep_aspect_ratio:
115
+ if self.__resize_method == 'lower_bound':
116
+ # scale such that output size is lower bound
117
+ if scale_width > scale_height:
118
+ # fit width
119
+ scale_height = scale_width
120
+ else:
121
+ # fit height
122
+ scale_width = scale_height
123
+ elif self.__resize_method == 'upper_bound':
124
+ # scale such that output size is upper bound
125
+ if scale_width < scale_height:
126
+ # fit width
127
+ scale_height = scale_width
128
+ else:
129
+ # fit height
130
+ scale_width = scale_height
131
+ elif self.__resize_method == 'minimal':
132
+ # scale as little as possible
133
+ if abs(1 - scale_width) < abs(1 - scale_height):
134
+ # fit width
135
+ scale_height = scale_width
136
+ else:
137
+ # fit height
138
+ scale_width = scale_height
139
+ else:
140
+ raise ValueError(
141
+ f'resize_method {self.__resize_method} not implemented')
142
+
143
+ if self.__resize_method == 'lower_bound':
144
+ new_height = self.constrain_to_multiple_of(scale_height * height,
145
+ min_val=self.__height)
146
+ new_width = self.constrain_to_multiple_of(scale_width * width,
147
+ min_val=self.__width)
148
+ elif self.__resize_method == 'upper_bound':
149
+ new_height = self.constrain_to_multiple_of(scale_height * height,
150
+ max_val=self.__height)
151
+ new_width = self.constrain_to_multiple_of(scale_width * width,
152
+ max_val=self.__width)
153
+ elif self.__resize_method == 'minimal':
154
+ new_height = self.constrain_to_multiple_of(scale_height * height)
155
+ new_width = self.constrain_to_multiple_of(scale_width * width)
156
+ else:
157
+ raise ValueError(
158
+ f'resize_method {self.__resize_method} not implemented')
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(sample['image'].shape[1],
164
+ sample['image'].shape[0])
165
+
166
+ # resize sample
167
+ sample['image'] = cv2.resize(
168
+ sample['image'],
169
+ (width, height),
170
+ interpolation=self.__image_interpolation_method,
171
+ )
172
+
173
+ if self.__resize_target:
174
+ if 'disparity' in sample:
175
+ sample['disparity'] = cv2.resize(
176
+ sample['disparity'],
177
+ (width, height),
178
+ interpolation=cv2.INTER_NEAREST,
179
+ )
180
+
181
+ if 'depth' in sample:
182
+ sample['depth'] = cv2.resize(sample['depth'], (width, height),
183
+ interpolation=cv2.INTER_NEAREST)
184
+
185
+ sample['mask'] = cv2.resize(
186
+ sample['mask'].astype(np.float32),
187
+ (width, height),
188
+ interpolation=cv2.INTER_NEAREST,
189
+ )
190
+ sample['mask'] = sample['mask'].astype(bool)
191
+
192
+ return sample
193
+
194
+
195
+ class NormalizeImage(object):
196
+ """Normlize image by given mean and std.
197
+ """
198
+ def __init__(self, mean, std):
199
+ self.__mean = mean
200
+ self.__std = std
201
+
202
+ def __call__(self, sample):
203
+ sample['image'] = (sample['image'] - self.__mean) / self.__std
204
+
205
+ return sample
206
+
207
+
208
+ class PrepareForNet(object):
209
+ """Prepare sample for usage as network input.
210
+ """
211
+ def __init__(self):
212
+ pass
213
+
214
+ def __call__(self, sample):
215
+ image = np.transpose(sample['image'], (2, 0, 1))
216
+ sample['image'] = np.ascontiguousarray(image).astype(np.float32)
217
+
218
+ if 'mask' in sample:
219
+ sample['mask'] = sample['mask'].astype(np.float32)
220
+ sample['mask'] = np.ascontiguousarray(sample['mask'])
221
+
222
+ if 'disparity' in sample:
223
+ disparity = sample['disparity'].astype(np.float32)
224
+ sample['disparity'] = np.ascontiguousarray(disparity)
225
+
226
+ if 'depth' in sample:
227
+ depth = sample['depth'].astype(np.float32)
228
+ sample['depth'] = np.ascontiguousarray(depth)
229
+
230
+ return sample
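
A worked example of the sizing rules above, assuming the MiDaS 'upper_bound' configuration (384 x 384 target, keep_aspect_ratio=True, ensure_multiple_of=32):

import cv2
# from annotator.midas.transforms import Resize  (defined above)

resize = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
                ensure_multiple_of=32, resize_method='upper_bound',
                image_interpolation_method=cv2.INTER_CUBIC)

# 640 x 480 input: scale_width = 384/640 = 0.6 < scale_height = 384/480 = 0.8, so width is fitted.
# new_width  = round(0.6 * 640 / 32) * 32 = 384
# new_height = round(0.6 * 480 / 32) * 32 = 288
print(resize.get_size(640, 480))  # (384, 288)
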
annotator/midas/utils.py ADDED
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Utils for monoDepth."""
3
+ import re
4
+ import sys
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def read_pfm(path):
12
+ """Read pfm file.
13
+
14
+ Args:
15
+ path (str): path to file
16
+
17
+ Returns:
18
+ tuple: (data, scale)
19
+ """
20
+ with open(path, 'rb') as file:
21
+
22
+ color = None
23
+ width = None
24
+ height = None
25
+ scale = None
26
+ endian = None
27
+
28
+ header = file.readline().rstrip()
29
+ if header.decode('ascii') == 'PF':
30
+ color = True
31
+ elif header.decode('ascii') == 'Pf':
32
+ color = False
33
+ else:
34
+ raise Exception('Not a PFM file: ' + path)
35
+
36
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$',
37
+ file.readline().decode('ascii'))
38
+ if dim_match:
39
+ width, height = list(map(int, dim_match.groups()))
40
+ else:
41
+ raise Exception('Malformed PFM header.')
42
+
43
+ scale = float(file.readline().decode('ascii').rstrip())
44
+ if scale < 0:
45
+ # little-endian
46
+ endian = '<'
47
+ scale = -scale
48
+ else:
49
+ # big-endian
50
+ endian = '>'
51
+
52
+ data = np.fromfile(file, endian + 'f')
53
+ shape = (height, width, 3) if color else (height, width)
54
+
55
+ data = np.reshape(data, shape)
56
+ data = np.flipud(data)
57
+
58
+ return data, scale
59
+
60
+
61
+ def write_pfm(path, image, scale=1):
62
+ """Write pfm file.
63
+
64
+ Args:
65
+ path (str): path to file
66
+ image (array): data
67
+ scale (int, optional): Scale. Defaults to 1.
68
+ """
69
+
70
+ with open(path, 'wb') as file:
71
+ color = None
72
+
73
+ if image.dtype.name != 'float32':
74
+ raise Exception('Image dtype must be float32.')
75
+
76
+ image = np.flipud(image)
77
+
78
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
79
+ color = True
80
+ elif (len(image.shape) == 2
81
+ or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
82
+ color = False
83
+ else:
84
+ raise Exception(
85
+ 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
86
+
87
+ file.write('PF\n' if color else 'Pf\n'.encode())
88
+ file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
89
+
90
+ endian = image.dtype.byteorder
91
+
92
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
93
+ scale = -scale
94
+
95
+ file.write('%f\n'.encode() % scale)
96
+
97
+ image.tofile(file)
98
+
99
+
100
+ def read_image(path):
101
+ """Read image and output RGB image (0-1).
102
+
103
+ Args:
104
+ path (str): path to file
105
+
106
+ Returns:
107
+ array: RGB image (0-1)
108
+ """
109
+ img = cv2.imread(path)
110
+
111
+ if img.ndim == 2:
112
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
113
+
114
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
115
+
116
+ return img
117
+
118
+
119
+ def resize_image(img):
120
+ """Resize image and make it fit for network.
121
+
122
+ Args:
123
+ img (array): image
124
+
125
+ Returns:
126
+ tensor: data ready for network
127
+ """
128
+ height_orig = img.shape[0]
129
+ width_orig = img.shape[1]
130
+
131
+ if width_orig > height_orig:
132
+ scale = width_orig / 384
133
+ else:
134
+ scale = height_orig / 384
135
+
136
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
137
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
138
+
139
+ img_resized = cv2.resize(img, (width, height),
140
+ interpolation=cv2.INTER_AREA)
141
+
142
+ img_resized = (torch.from_numpy(np.transpose(
143
+ img_resized, (2, 0, 1))).contiguous().float())
144
+ img_resized = img_resized.unsqueeze(0)
145
+
146
+ return img_resized
147
+
148
+
149
+ def resize_depth(depth, width, height):
150
+ """Resize depth map and bring to CPU (numpy).
151
+
152
+ Args:
153
+ depth (tensor): depth
154
+ width (int): image width
155
+ height (int): image height
156
+
157
+ Returns:
158
+ array: processed depth
159
+ """
160
+ depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
161
+
162
+ depth_resized = cv2.resize(depth.numpy(), (width, height),
163
+ interpolation=cv2.INTER_CUBIC)
164
+
165
+ return depth_resized
166
+
167
+
168
+ def write_depth(path, depth, bits=1):
169
+ """Write depth map to pfm and png file.
170
+
171
+ Args:
172
+ path (str): filepath without extension
173
+ depth (array): depth
174
+ """
175
+ write_pfm(path + '.pfm', depth.astype(np.float32))
176
+
177
+ depth_min = depth.min()
178
+ depth_max = depth.max()
179
+
180
+ max_val = (2**(8 * bits)) - 1
181
+
182
+ if depth_max - depth_min > np.finfo('float').eps:
183
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
184
+ else:
185
+             out = np.zeros(depth.shape, dtype=depth.dtype)
186
+
187
+ if bits == 1:
188
+ cv2.imwrite(path + '.png', out.astype('uint8'))
189
+ elif bits == 2:
190
+ cv2.imwrite(path + '.png', out.astype('uint16'))
191
+
192
+ return
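As a reference for how these helpers fit together, here is a small sketch (not part of this commit) of the same min-max normalization write_depth() applies before quantizing a depth map, written out for the 16-bit PNG case (bits=2) with a random stand-in prediction.

    import cv2
    import numpy as np

    depth = np.random.rand(240, 320).astype(np.float32)  # stand-in for a MiDaS prediction

    depth_min, depth_max = depth.min(), depth.max()
    max_val = (2**(8 * 2)) - 1  # 65535 for 16-bit output

    if depth_max - depth_min > np.finfo('float').eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    cv2.imwrite('depth_sketch.png', out.astype('uint16'))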
annotator/midas/vit.py ADDED
@@ -0,0 +1,509 @@
1
+ # -*- coding: utf-8 -*-
2
+ import math
3
+ import types
4
+
5
+ import timm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class Slice(nn.Module):
12
+ def __init__(self, start_index=1):
13
+ super(Slice, self).__init__()
14
+ self.start_index = start_index
15
+
16
+ def forward(self, x):
17
+ return x[:, self.start_index:]
18
+
19
+
20
+ class AddReadout(nn.Module):
21
+ def __init__(self, start_index=1):
22
+ super(AddReadout, self).__init__()
23
+ self.start_index = start_index
24
+
25
+ def forward(self, x):
26
+ if self.start_index == 2:
27
+ readout = (x[:, 0] + x[:, 1]) / 2
28
+ else:
29
+ readout = x[:, 0]
30
+ return x[:, self.start_index:] + readout.unsqueeze(1)
31
+
32
+
33
+ class ProjectReadout(nn.Module):
34
+ def __init__(self, in_features, start_index=1):
35
+ super(ProjectReadout, self).__init__()
36
+ self.start_index = start_index
37
+
38
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
39
+ nn.GELU())
40
+
41
+ def forward(self, x):
42
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
43
+ features = torch.cat((x[:, self.start_index:], readout), -1)
44
+
45
+ return self.project(features)
46
+
47
+
48
+ class Transpose(nn.Module):
49
+ def __init__(self, dim0, dim1):
50
+ super(Transpose, self).__init__()
51
+ self.dim0 = dim0
52
+ self.dim1 = dim1
53
+
54
+ def forward(self, x):
55
+ x = x.transpose(self.dim0, self.dim1)
56
+ return x
57
+
58
+
59
+ def forward_vit(pretrained, x):
60
+ b, c, h, w = x.shape
61
+
62
+ _ = pretrained.model.forward_flex(x)
63
+
64
+ layer_1 = pretrained.activations['1']
65
+ layer_2 = pretrained.activations['2']
66
+ layer_3 = pretrained.activations['3']
67
+ layer_4 = pretrained.activations['4']
68
+
69
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
70
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
71
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
72
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
73
+
74
+ unflatten = nn.Sequential(
75
+ nn.Unflatten(
76
+ 2,
77
+ torch.Size([
78
+ h // pretrained.model.patch_size[1],
79
+ w // pretrained.model.patch_size[0],
80
+ ]),
81
+ ))
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
93
+ layer_1)
94
+ layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
95
+ layer_2)
96
+ layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
97
+ layer_3)
98
+ layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
99
+ layer_4)
100
+
101
+ return layer_1, layer_2, layer_3, layer_4
102
+
103
+
104
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
105
+ posemb_tok, posemb_grid = (
106
+ posemb[:, :self.start_index],
107
+ posemb[0, self.start_index:],
108
+ )
109
+
110
+ gs_old = int(math.sqrt(len(posemb_grid)))
111
+
112
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
113
+ -1).permute(0, 3, 1, 2)
114
+ posemb_grid = F.interpolate(posemb_grid,
115
+ size=(gs_h, gs_w),
116
+ mode='bilinear')
117
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
118
+
119
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
120
+
121
+ return posemb
122
+
123
+
124
+ def forward_flex(self, x):
125
+ b, c, h, w = x.shape
126
+
127
+ pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
128
+ w // self.patch_size[0])
129
+
130
+ B = x.shape[0]
131
+
132
+ if hasattr(self.patch_embed, 'backbone'):
133
+ x = self.patch_embed.backbone(x)
134
+ if isinstance(x, (list, tuple)):
135
+ x = x[
136
+ -1] # last feature if backbone outputs list/tuple of features
137
+
138
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
139
+
140
+ if getattr(self, 'dist_token', None) is not None:
141
+ cls_tokens = self.cls_token.expand(
142
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
143
+ dist_token = self.dist_token.expand(B, -1, -1)
144
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
145
+ else:
146
+ cls_tokens = self.cls_token.expand(
147
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
148
+ x = torch.cat((cls_tokens, x), dim=1)
149
+
150
+ x = x + pos_embed
151
+ x = self.pos_drop(x)
152
+
153
+ for blk in self.blocks:
154
+ x = blk(x)
155
+
156
+ x = self.norm(x)
157
+
158
+ return x
159
+
160
+
161
+ activations = {}
162
+
163
+
164
+ def get_activation(name):
165
+ def hook(model, input, output):
166
+ activations[name] = output
167
+
168
+ return hook
169
+
170
+
171
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
172
+ if use_readout == 'ignore':
173
+ readout_oper = [Slice(start_index)] * len(features)
174
+ elif use_readout == 'add':
175
+ readout_oper = [AddReadout(start_index)] * len(features)
176
+ elif use_readout == 'project':
177
+ readout_oper = [
178
+ ProjectReadout(vit_features, start_index) for out_feat in features
179
+ ]
180
+ else:
181
+ assert (
182
+ False
183
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
184
+
185
+ return readout_oper
186
+
187
+
188
+ def _make_vit_b16_backbone(
189
+ model,
190
+ features=[96, 192, 384, 768],
191
+ size=[384, 384],
192
+ hooks=[2, 5, 8, 11],
193
+ vit_features=768,
194
+ use_readout='ignore',
195
+ start_index=1,
196
+ ):
197
+ pretrained = nn.Module()
198
+
199
+ pretrained.model = model
200
+ pretrained.model.blocks[hooks[0]].register_forward_hook(
201
+ get_activation('1'))
202
+ pretrained.model.blocks[hooks[1]].register_forward_hook(
203
+ get_activation('2'))
204
+ pretrained.model.blocks[hooks[2]].register_forward_hook(
205
+ get_activation('3'))
206
+ pretrained.model.blocks[hooks[3]].register_forward_hook(
207
+ get_activation('4'))
208
+
209
+ pretrained.activations = activations
210
+
211
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
212
+ start_index)
213
+
214
+ # 32, 48, 136, 384
215
+ pretrained.act_postprocess1 = nn.Sequential(
216
+ readout_oper[0],
217
+ Transpose(1, 2),
218
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
219
+ nn.Conv2d(
220
+ in_channels=vit_features,
221
+ out_channels=features[0],
222
+ kernel_size=1,
223
+ stride=1,
224
+ padding=0,
225
+ ),
226
+ nn.ConvTranspose2d(
227
+ in_channels=features[0],
228
+ out_channels=features[0],
229
+ kernel_size=4,
230
+ stride=4,
231
+ padding=0,
232
+ bias=True,
233
+ dilation=1,
234
+ groups=1,
235
+ ),
236
+ )
237
+
238
+ pretrained.act_postprocess2 = nn.Sequential(
239
+ readout_oper[1],
240
+ Transpose(1, 2),
241
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
242
+ nn.Conv2d(
243
+ in_channels=vit_features,
244
+ out_channels=features[1],
245
+ kernel_size=1,
246
+ stride=1,
247
+ padding=0,
248
+ ),
249
+ nn.ConvTranspose2d(
250
+ in_channels=features[1],
251
+ out_channels=features[1],
252
+ kernel_size=2,
253
+ stride=2,
254
+ padding=0,
255
+ bias=True,
256
+ dilation=1,
257
+ groups=1,
258
+ ),
259
+ )
260
+
261
+ pretrained.act_postprocess3 = nn.Sequential(
262
+ readout_oper[2],
263
+ Transpose(1, 2),
264
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
265
+ nn.Conv2d(
266
+ in_channels=vit_features,
267
+ out_channels=features[2],
268
+ kernel_size=1,
269
+ stride=1,
270
+ padding=0,
271
+ ),
272
+ )
273
+
274
+ pretrained.act_postprocess4 = nn.Sequential(
275
+ readout_oper[3],
276
+ Transpose(1, 2),
277
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
278
+ nn.Conv2d(
279
+ in_channels=vit_features,
280
+ out_channels=features[3],
281
+ kernel_size=1,
282
+ stride=1,
283
+ padding=0,
284
+ ),
285
+ nn.Conv2d(
286
+ in_channels=features[3],
287
+ out_channels=features[3],
288
+ kernel_size=3,
289
+ stride=2,
290
+ padding=1,
291
+ ),
292
+ )
293
+
294
+ pretrained.model.start_index = start_index
295
+ pretrained.model.patch_size = [16, 16]
296
+
297
+ # We inject this function into the VisionTransformer instances so that
298
+ # we can use it with interpolated position embeddings without modifying the library source.
299
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
300
+ pretrained.model)
301
+ pretrained.model._resize_pos_embed = types.MethodType(
302
+ _resize_pos_embed, pretrained.model)
303
+
304
+ return pretrained
305
+
306
+
307
+ def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
308
+ model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)
309
+
310
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
311
+ return _make_vit_b16_backbone(
312
+ model,
313
+ features=[256, 512, 1024, 1024],
314
+ hooks=hooks,
315
+ vit_features=1024,
316
+ use_readout=use_readout,
317
+ )
318
+
319
+
320
+ def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
321
+ model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)
322
+
323
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
324
+ return _make_vit_b16_backbone(model,
325
+ features=[96, 192, 384, 768],
326
+ hooks=hooks,
327
+ use_readout=use_readout)
328
+
329
+
330
+ def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
331
+ model = timm.create_model('vit_deit_base_patch16_384',
332
+ pretrained=pretrained)
333
+
334
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
335
+ return _make_vit_b16_backbone(model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout)
339
+
340
+
341
+ def _make_pretrained_deitb16_distil_384(pretrained,
342
+ use_readout='ignore',
343
+ hooks=None):
344
+ model = timm.create_model('vit_deit_base_distilled_patch16_384',
345
+ pretrained=pretrained)
346
+
347
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
348
+ return _make_vit_b16_backbone(
349
+ model,
350
+ features=[96, 192, 384, 768],
351
+ hooks=hooks,
352
+ use_readout=use_readout,
353
+ start_index=2,
354
+ )
355
+
356
+
357
+ def _make_vit_b_rn50_backbone(
358
+ model,
359
+ features=[256, 512, 768, 768],
360
+ size=[384, 384],
361
+ hooks=[0, 1, 8, 11],
362
+ vit_features=768,
363
+ use_vit_only=False,
364
+ use_readout='ignore',
365
+ start_index=1,
366
+ ):
367
+ pretrained = nn.Module()
368
+
369
+ pretrained.model = model
370
+
371
+ if use_vit_only is True:
372
+ pretrained.model.blocks[hooks[0]].register_forward_hook(
373
+ get_activation('1'))
374
+ pretrained.model.blocks[hooks[1]].register_forward_hook(
375
+ get_activation('2'))
376
+ else:
377
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
378
+ get_activation('1'))
379
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
380
+ get_activation('2'))
381
+
382
+ pretrained.model.blocks[hooks[2]].register_forward_hook(
383
+ get_activation('3'))
384
+ pretrained.model.blocks[hooks[3]].register_forward_hook(
385
+ get_activation('4'))
386
+
387
+ pretrained.activations = activations
388
+
389
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
390
+ start_index)
391
+
392
+ if use_vit_only is True:
393
+ pretrained.act_postprocess1 = nn.Sequential(
394
+ readout_oper[0],
395
+ Transpose(1, 2),
396
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
397
+ nn.Conv2d(
398
+ in_channels=vit_features,
399
+ out_channels=features[0],
400
+ kernel_size=1,
401
+ stride=1,
402
+ padding=0,
403
+ ),
404
+ nn.ConvTranspose2d(
405
+ in_channels=features[0],
406
+ out_channels=features[0],
407
+ kernel_size=4,
408
+ stride=4,
409
+ padding=0,
410
+ bias=True,
411
+ dilation=1,
412
+ groups=1,
413
+ ),
414
+ )
415
+
416
+ pretrained.act_postprocess2 = nn.Sequential(
417
+ readout_oper[1],
418
+ Transpose(1, 2),
419
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
420
+ nn.Conv2d(
421
+ in_channels=vit_features,
422
+ out_channels=features[1],
423
+ kernel_size=1,
424
+ stride=1,
425
+ padding=0,
426
+ ),
427
+ nn.ConvTranspose2d(
428
+ in_channels=features[1],
429
+ out_channels=features[1],
430
+ kernel_size=2,
431
+ stride=2,
432
+ padding=0,
433
+ bias=True,
434
+ dilation=1,
435
+ groups=1,
436
+ ),
437
+ )
438
+ else:
439
+ pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
440
+ nn.Identity(),
441
+ nn.Identity())
442
+ pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
443
+ nn.Identity(),
444
+ nn.Identity())
445
+
446
+ pretrained.act_postprocess3 = nn.Sequential(
447
+ readout_oper[2],
448
+ Transpose(1, 2),
449
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
450
+ nn.Conv2d(
451
+ in_channels=vit_features,
452
+ out_channels=features[2],
453
+ kernel_size=1,
454
+ stride=1,
455
+ padding=0,
456
+ ),
457
+ )
458
+
459
+ pretrained.act_postprocess4 = nn.Sequential(
460
+ readout_oper[3],
461
+ Transpose(1, 2),
462
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
463
+ nn.Conv2d(
464
+ in_channels=vit_features,
465
+ out_channels=features[3],
466
+ kernel_size=1,
467
+ stride=1,
468
+ padding=0,
469
+ ),
470
+ nn.Conv2d(
471
+ in_channels=features[3],
472
+ out_channels=features[3],
473
+ kernel_size=3,
474
+ stride=2,
475
+ padding=1,
476
+ ),
477
+ )
478
+
479
+ pretrained.model.start_index = start_index
480
+ pretrained.model.patch_size = [16, 16]
481
+
482
+ # We inject this function into the VisionTransformer instances so that
483
+ # we can use it with interpolated position embeddings without modifying the library source.
484
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
485
+ pretrained.model)
486
+
487
+ # We inject this function into the VisionTransformer instances so that
488
+ # we can use it with interpolated position embeddings without modifying the library source.
489
+ pretrained.model._resize_pos_embed = types.MethodType(
490
+ _resize_pos_embed, pretrained.model)
491
+
492
+ return pretrained
493
+
494
+
495
+ def _make_pretrained_vitb_rn50_384(pretrained,
496
+ use_readout='ignore',
497
+ hooks=None,
498
+ use_vit_only=False):
499
+ model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)
500
+
501
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
502
+ return _make_vit_b_rn50_backbone(
503
+ model,
504
+ features=[256, 512, 768, 768],
505
+ size=[384, 384],
506
+ hooks=hooks,
507
+ use_vit_only=use_vit_only,
508
+ use_readout=use_readout,
509
+ )
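The forward_flex/_resize_pos_embed functions above are attached to the timm model at runtime with types.MethodType rather than by subclassing. A minimal sketch of that injection pattern (not part of this commit), shown on a toy module:

    import types

    import torch
    import torch.nn as nn


    def scaled_forward(self, x, scale=2.0):
        # Bound at runtime, so it behaves like a regular method with access to self.
        return self.linear(x) * scale


    model = nn.Module()
    model.linear = nn.Linear(4, 4)
    model.scaled_forward = types.MethodType(scaled_forward, model)

    print(model.scaled_forward(torch.randn(1, 4)).shape)  # torch.Size([1, 4])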
annotator/midas_op.py ADDED
@@ -0,0 +1,79 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # MiDaS Depth Estimation
4
+ # From https://github.com/isl-org/MiDaS
5
+ # MIT LICENSE
6
+ from abc import ABCMeta
7
+
8
+ import numpy as np
9
+ import torch
10
+ from einops import rearrange
11
+ from PIL import Image
12
+
13
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
14
+ from scepter.modules.annotator.midas.api import MiDaSInference
15
+ from scepter.modules.annotator.registry import ANNOTATORS
16
+ from scepter.modules.annotator.utils import resize_image, resize_image_ori
17
+ from scepter.modules.utils.config import dict_to_yaml
18
+ from scepter.modules.utils.distribute import we
19
+ from scepter.modules.utils.file_system import FS
20
+
21
+
22
+ @ANNOTATORS.register_class()
23
+ class MidasDetector(BaseAnnotator, metaclass=ABCMeta):
24
+ def __init__(self, cfg, logger=None):
25
+ super().__init__(cfg, logger=logger)
26
+ pretrained_model = cfg.get('PRETRAINED_MODEL', None)
27
+ if pretrained_model:
28
+ with FS.get_from(pretrained_model, wait_finish=True) as local_path:
29
+ self.model = MiDaSInference(model_type='dpt_hybrid',
30
+ model_path=local_path)
31
+ self.a = cfg.get('A', np.pi * 2.0)
32
+ self.bg_th = cfg.get('BG_TH', 0.1)
33
+
34
+ @torch.no_grad()
35
+ @torch.inference_mode()
36
+ @torch.autocast('cuda', enabled=False)
37
+ def forward(self, image):
38
+ if isinstance(image, Image.Image):
39
+ image = np.array(image)
40
+ elif isinstance(image, torch.Tensor):
41
+ image = image.detach().cpu().numpy()
42
+ elif isinstance(image, np.ndarray):
43
+ image = image.copy()
44
+ else:
45
+             raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
46
+ image_depth = image
47
+ h, w, c = image.shape
48
+ image_depth, k = resize_image(image_depth,
49
+ 1024 if min(h, w) > 1024 else min(h, w))
50
+ image_depth = torch.from_numpy(image_depth).float().to(we.device_id)
51
+ image_depth = image_depth / 127.5 - 1.0
52
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
53
+ depth = self.model(image_depth)[0]
54
+
55
+ depth_pt = depth.clone()
56
+ depth_pt -= torch.min(depth_pt)
57
+ depth_pt /= torch.max(depth_pt)
58
+ depth_pt = depth_pt.cpu().numpy()
59
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
60
+ depth_image = depth_image[..., None].repeat(3, 2)
61
+
62
+ # depth_np = depth.cpu().numpy() # float16 error
63
+ # x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
64
+ # y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
65
+ # z = np.ones_like(x) * self.a
66
+ # x[depth_pt < self.bg_th] = 0
67
+ # y[depth_pt < self.bg_th] = 0
68
+ # normal = np.stack([x, y, z], axis=2)
69
+ # normal /= np.sum(normal**2.0, axis=2, keepdims=True)**0.5
70
+ # normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
71
+ depth_image = resize_image_ori(h, w, depth_image, k)
72
+ return depth_image
73
+
74
+ @staticmethod
75
+ def get_config_template():
76
+ return dict_to_yaml('ANNOTATORS',
77
+ __class__.__name__,
78
+ MidasDetector.para_dict,
79
+ set_name=True)
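A short sketch (not part of this commit) of the post-processing step in MidasDetector.forward: the raw depth prediction is min-max normalized, scaled to uint8 and replicated to three channels. A random tensor stands in for the model output.

    import numpy as np
    import torch

    depth = torch.rand(480, 640)  # stand-in for self.model(image_depth)[0]

    depth_pt = depth.clone()
    depth_pt -= torch.min(depth_pt)
    depth_pt /= torch.max(depth_pt)
    depth_np = depth_pt.cpu().numpy()

    depth_image = (depth_np * 255.0).clip(0, 255).astype(np.uint8)
    depth_image = depth_image[..., None].repeat(3, 2)  # H x W x 3

    print(depth_image.shape, depth_image.dtype)  # (480, 640, 3) uint8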
annotator/mlsd/__init__.py ADDED
File without changes
annotator/mlsd/mbv2_mlsd_large.py ADDED
@@ -0,0 +1,303 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class BlockTypeA(nn.Module):
10
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
11
+ super(BlockTypeA, self).__init__()
12
+ self.conv1 = nn.Sequential(nn.Conv2d(in_c2, out_c2, kernel_size=1),
13
+ nn.BatchNorm2d(out_c2),
14
+ nn.ReLU(inplace=True))
15
+ self.conv2 = nn.Sequential(nn.Conv2d(in_c1, out_c1, kernel_size=1),
16
+ nn.BatchNorm2d(out_c1),
17
+ nn.ReLU(inplace=True))
18
+ self.upscale = upscale
19
+
20
+ def forward(self, a, b):
21
+ b = self.conv1(b)
22
+ a = self.conv2(a)
23
+ if self.upscale:
24
+ b = F.interpolate(b,
25
+ scale_factor=2.0,
26
+ mode='bilinear',
27
+ align_corners=True)
28
+ return torch.cat((a, b), dim=1)
29
+
30
+
31
+ class BlockTypeB(nn.Module):
32
+ def __init__(self, in_c, out_c):
33
+ super(BlockTypeB, self).__init__()
34
+ self.conv1 = nn.Sequential(
35
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
36
+ nn.BatchNorm2d(in_c), nn.ReLU())
37
+ self.conv2 = nn.Sequential(
38
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
39
+ nn.BatchNorm2d(out_c), nn.ReLU())
40
+
41
+ def forward(self, x):
42
+ x = self.conv1(x) + x
43
+ x = self.conv2(x)
44
+ return x
45
+
46
+
47
+ class BlockTypeC(nn.Module):
48
+ def __init__(self, in_c, out_c):
49
+ super(BlockTypeC, self).__init__()
50
+ self.conv1 = nn.Sequential(
51
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
52
+ nn.BatchNorm2d(in_c), nn.ReLU())
53
+ self.conv2 = nn.Sequential(
54
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
55
+ nn.BatchNorm2d(in_c), nn.ReLU())
56
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
57
+
58
+ def forward(self, x):
59
+ x = self.conv1(x)
60
+ x = self.conv2(x)
61
+ x = self.conv3(x)
62
+ return x
63
+
64
+
65
+ def _make_divisible(v, divisor, min_value=None):
66
+ """
67
+ This function is taken from the original tf repo.
68
+ It ensures that all layers have a channel number that is divisible by 8
69
+ It can be seen here:
70
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
71
+ :param v:
72
+ :param divisor:
73
+ :param min_value:
74
+ :return:
75
+ """
76
+ if min_value is None:
77
+ min_value = divisor
78
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
79
+ # Make sure that round down does not go down by more than 10%.
80
+ if new_v < 0.9 * v:
81
+ new_v += divisor
82
+ return new_v
83
+
84
+
85
+ class ConvBNReLU(nn.Sequential):
86
+ def __init__(self,
87
+ in_planes,
88
+ out_planes,
89
+ kernel_size=3,
90
+ stride=1,
91
+ groups=1):
92
+ self.channel_pad = out_planes - in_planes
93
+ self.stride = stride
94
+ # padding = (kernel_size - 1) // 2
95
+
96
+ # TFLite uses slightly different padding than PyTorch
97
+ if stride == 2:
98
+ padding = 0
99
+ else:
100
+ padding = (kernel_size - 1) // 2
101
+
102
+ super(ConvBNReLU, self).__init__(
103
+ nn.Conv2d(in_planes,
104
+ out_planes,
105
+ kernel_size,
106
+ stride,
107
+ padding,
108
+ groups=groups,
109
+ bias=False), nn.BatchNorm2d(out_planes),
110
+ nn.ReLU6(inplace=True))
111
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
112
+
113
+ def forward(self, x):
114
+ # TFLite uses different padding
115
+ if self.stride == 2:
116
+ x = F.pad(x, (0, 1, 0, 1), 'constant', 0)
117
+ # print(x.shape)
118
+
119
+ for module in self:
120
+ if not isinstance(module, nn.MaxPool2d):
121
+ x = module(x)
122
+ return x
123
+
124
+
125
+ class InvertedResidual(nn.Module):
126
+ def __init__(self, inp, oup, stride, expand_ratio):
127
+ super(InvertedResidual, self).__init__()
128
+ self.stride = stride
129
+ assert stride in [1, 2]
130
+
131
+ hidden_dim = int(round(inp * expand_ratio))
132
+ self.use_res_connect = self.stride == 1 and inp == oup
133
+
134
+ layers = []
135
+ if expand_ratio != 1:
136
+ # pw
137
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
138
+ layers.extend([
139
+ # dw
140
+ ConvBNReLU(hidden_dim,
141
+ hidden_dim,
142
+ stride=stride,
143
+ groups=hidden_dim),
144
+ # pw-linear
145
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
146
+ nn.BatchNorm2d(oup),
147
+ ])
148
+ self.conv = nn.Sequential(*layers)
149
+
150
+ def forward(self, x):
151
+ if self.use_res_connect:
152
+ return x + self.conv(x)
153
+ else:
154
+ return self.conv(x)
155
+
156
+
157
+ class MobileNetV2(nn.Module):
158
+ def __init__(self, pretrained=True):
159
+ """
160
+ MobileNet V2 main class
161
+ Args:
162
+ num_classes (int): Number of classes
163
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
164
+ inverted_residual_setting: Network structure
165
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
166
+ Set to 1 to turn off rounding
167
+ block: Module specifying inverted residual building block for mobilenet
168
+ """
169
+ super(MobileNetV2, self).__init__()
170
+
171
+ block = InvertedResidual
172
+ input_channel = 32
173
+ last_channel = 1280
174
+ width_mult = 1.0
175
+ round_nearest = 8
176
+
177
+ inverted_residual_setting = [
178
+ # t, c, n, s
179
+ [1, 16, 1, 1],
180
+ [6, 24, 2, 2],
181
+ [6, 32, 3, 2],
182
+ [6, 64, 4, 2],
183
+ [6, 96, 3, 1],
184
+ # [6, 160, 3, 2],
185
+ # [6, 320, 1, 1],
186
+ ]
187
+
188
+ # only check the first element, assuming user knows t,c,n,s are required
189
+ if len(inverted_residual_setting) == 0 or len(
190
+ inverted_residual_setting[0]) != 4:
191
+             raise ValueError('inverted_residual_setting should be non-empty '
192
+                              'and each element must be a 4-element list, '
193
+                              'got {}'.format(inverted_residual_setting))
194
+
195
+ # building first layer
196
+ input_channel = _make_divisible(input_channel * width_mult,
197
+ round_nearest)
198
+ self.last_channel = _make_divisible(
199
+ last_channel * max(1.0, width_mult), round_nearest)
200
+ features = [ConvBNReLU(4, input_channel, stride=2)]
201
+ # building inverted residual blocks
202
+ for t, c, n, s in inverted_residual_setting:
203
+ output_channel = _make_divisible(c * width_mult, round_nearest)
204
+ for i in range(n):
205
+ stride = s if i == 0 else 1
206
+ features.append(
207
+ block(input_channel,
208
+ output_channel,
209
+ stride,
210
+ expand_ratio=t))
211
+ input_channel = output_channel
212
+
213
+ self.features = nn.Sequential(*features)
214
+ self.fpn_selected = [1, 3, 6, 10, 13]
215
+ # weight initialization
216
+ for m in self.modules():
217
+ if isinstance(m, nn.Conv2d):
218
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
219
+ if m.bias is not None:
220
+ nn.init.zeros_(m.bias)
221
+ elif isinstance(m, nn.BatchNorm2d):
222
+ nn.init.ones_(m.weight)
223
+ nn.init.zeros_(m.bias)
224
+ elif isinstance(m, nn.Linear):
225
+ nn.init.normal_(m.weight, 0, 0.01)
226
+ nn.init.zeros_(m.bias)
227
+ if pretrained:
228
+ self._load_pretrained_model()
229
+
230
+ def _forward_impl(self, x):
231
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
232
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
233
+ fpn_features = []
234
+ for i, f in enumerate(self.features):
235
+ if i > self.fpn_selected[-1]:
236
+ break
237
+ x = f(x)
238
+ if i in self.fpn_selected:
239
+ fpn_features.append(x)
240
+
241
+ c1, c2, c3, c4, c5 = fpn_features
242
+ return c1, c2, c3, c4, c5
243
+
244
+ def forward(self, x):
245
+ return self._forward_impl(x)
246
+
247
+ def _load_pretrained_model(self):
248
+ pretrain_dict = model_zoo.load_url(
249
+ 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
250
+ model_dict = {}
251
+ state_dict = self.state_dict()
252
+ for k, v in pretrain_dict.items():
253
+ if k in state_dict:
254
+ model_dict[k] = v
255
+ state_dict.update(model_dict)
256
+ self.load_state_dict(state_dict)
257
+
258
+
259
+ class MobileV2_MLSD_Large(nn.Module):
260
+ def __init__(self):
261
+ super(MobileV2_MLSD_Large, self).__init__()
262
+
263
+ self.backbone = MobileNetV2(pretrained=False)
264
+ # A, B
265
+ self.block15 = BlockTypeA(in_c1=64,
266
+ in_c2=96,
267
+ out_c1=64,
268
+ out_c2=64,
269
+ upscale=False)
270
+ self.block16 = BlockTypeB(128, 64)
271
+
272
+ # A, B
273
+ self.block17 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
274
+ self.block18 = BlockTypeB(128, 64)
275
+
276
+ # A, B
277
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64, out_c1=64, out_c2=64)
278
+ self.block20 = BlockTypeB(128, 64)
279
+
280
+ # A, B, C
281
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64, out_c1=64, out_c2=64)
282
+ self.block22 = BlockTypeB(128, 64)
283
+
284
+ self.block23 = BlockTypeC(64, 16)
285
+
286
+ def forward(self, x):
287
+ c1, c2, c3, c4, c5 = self.backbone(x)
288
+
289
+ x = self.block15(c4, c5)
290
+ x = self.block16(x)
291
+
292
+ x = self.block17(c3, x)
293
+ x = self.block18(x)
294
+
295
+ x = self.block19(c2, x)
296
+ x = self.block20(x)
297
+
298
+ x = self.block21(c1, x)
299
+ x = self.block22(x)
300
+ x = self.block23(x)
301
+ x = x[:, 7:, :, :]
302
+
303
+ return x
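A few worked values (not part of this commit) for the _make_divisible helper above, which rounds channel counts to a multiple of the divisor while never losing more than 10% of the original width; the import path assumes the repository root is on PYTHONPATH.

    from annotator.mlsd.mbv2_mlsd_large import _make_divisible

    print(_make_divisible(32, 8))         # 32: already a multiple of 8
    print(_make_divisible(21, 8))         # 24: rounds up to the nearest multiple
    print(_make_divisible(35, 8))         # 32: rounds down, still within 10% of 35
    print(_make_divisible(16 * 0.65, 8))  # 16: 8 would lose more than 10%, so it is bumped up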
annotator/mlsd/mbv2_mlsd_tiny.py ADDED
@@ -0,0 +1,287 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.utils.model_zoo as model_zoo
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class BlockTypeA(nn.Module):
9
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
10
+ super(BlockTypeA, self).__init__()
11
+ self.conv1 = nn.Sequential(nn.Conv2d(in_c2, out_c2, kernel_size=1),
12
+ nn.BatchNorm2d(out_c2),
13
+ nn.ReLU(inplace=True))
14
+ self.conv2 = nn.Sequential(nn.Conv2d(in_c1, out_c1, kernel_size=1),
15
+ nn.BatchNorm2d(out_c1),
16
+ nn.ReLU(inplace=True))
17
+ self.upscale = upscale
18
+
19
+ def forward(self, a, b):
20
+ b = self.conv1(b)
21
+ a = self.conv2(a)
22
+ b = F.interpolate(b,
23
+ scale_factor=2.0,
24
+ mode='bilinear',
25
+ align_corners=True)
26
+ return torch.cat((a, b), dim=1)
27
+
28
+
29
+ class BlockTypeB(nn.Module):
30
+ def __init__(self, in_c, out_c):
31
+ super(BlockTypeB, self).__init__()
32
+ self.conv1 = nn.Sequential(
33
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
34
+ nn.BatchNorm2d(in_c), nn.ReLU())
35
+ self.conv2 = nn.Sequential(
36
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
37
+ nn.BatchNorm2d(out_c), nn.ReLU())
38
+
39
+ def forward(self, x):
40
+ x = self.conv1(x) + x
41
+ x = self.conv2(x)
42
+ return x
43
+
44
+
45
+ class BlockTypeC(nn.Module):
46
+ def __init__(self, in_c, out_c):
47
+ super(BlockTypeC, self).__init__()
48
+ self.conv1 = nn.Sequential(
49
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
50
+ nn.BatchNorm2d(in_c), nn.ReLU())
51
+ self.conv2 = nn.Sequential(
52
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
53
+ nn.BatchNorm2d(in_c), nn.ReLU())
54
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
55
+
56
+ def forward(self, x):
57
+ x = self.conv1(x)
58
+ x = self.conv2(x)
59
+ x = self.conv3(x)
60
+ return x
61
+
62
+
63
+ def _make_divisible(v, divisor, min_value=None):
64
+ """
65
+ This function is taken from the original tf repo.
66
+ It ensures that all layers have a channel number that is divisible by 8
67
+ It can be seen here:
68
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
69
+ :param v:
70
+ :param divisor:
71
+ :param min_value:
72
+ :return:
73
+ """
74
+ if min_value is None:
75
+ min_value = divisor
76
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
77
+ # Make sure that round down does not go down by more than 10%.
78
+ if new_v < 0.9 * v:
79
+ new_v += divisor
80
+ return new_v
81
+
82
+
83
+ class ConvBNReLU(nn.Sequential):
84
+ def __init__(self,
85
+ in_planes,
86
+ out_planes,
87
+ kernel_size=3,
88
+ stride=1,
89
+ groups=1):
90
+ self.channel_pad = out_planes - in_planes
91
+ self.stride = stride
92
+ # padding = (kernel_size - 1) // 2
93
+
94
+ # TFLite uses slightly different padding than PyTorch
95
+ if stride == 2:
96
+ padding = 0
97
+ else:
98
+ padding = (kernel_size - 1) // 2
99
+
100
+ super(ConvBNReLU, self).__init__(
101
+ nn.Conv2d(in_planes,
102
+ out_planes,
103
+ kernel_size,
104
+ stride,
105
+ padding,
106
+ groups=groups,
107
+ bias=False), nn.BatchNorm2d(out_planes),
108
+ nn.ReLU6(inplace=True))
109
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
110
+
111
+ def forward(self, x):
112
+ # TFLite uses different padding
113
+ if self.stride == 2:
114
+ x = F.pad(x, (0, 1, 0, 1), 'constant', 0)
115
+ # print(x.shape)
116
+
117
+ for module in self:
118
+ if not isinstance(module, nn.MaxPool2d):
119
+ x = module(x)
120
+ return x
121
+
122
+
123
+ class InvertedResidual(nn.Module):
124
+ def __init__(self, inp, oup, stride, expand_ratio):
125
+ super(InvertedResidual, self).__init__()
126
+ self.stride = stride
127
+ assert stride in [1, 2]
128
+
129
+ hidden_dim = int(round(inp * expand_ratio))
130
+ self.use_res_connect = self.stride == 1 and inp == oup
131
+
132
+ layers = []
133
+ if expand_ratio != 1:
134
+ # pw
135
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
136
+ layers.extend([
137
+ # dw
138
+ ConvBNReLU(hidden_dim,
139
+ hidden_dim,
140
+ stride=stride,
141
+ groups=hidden_dim),
142
+ # pw-linear
143
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
144
+ nn.BatchNorm2d(oup),
145
+ ])
146
+ self.conv = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ if self.use_res_connect:
150
+ return x + self.conv(x)
151
+ else:
152
+ return self.conv(x)
153
+
154
+
155
+ class MobileNetV2(nn.Module):
156
+ def __init__(self, pretrained=True):
157
+ """
158
+ MobileNet V2 main class
159
+ Args:
160
+ num_classes (int): Number of classes
161
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
162
+ inverted_residual_setting: Network structure
163
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
164
+ Set to 1 to turn off rounding
165
+ block: Module specifying inverted residual building block for mobilenet
166
+ """
167
+ super(MobileNetV2, self).__init__()
168
+
169
+ block = InvertedResidual
170
+ input_channel = 32
171
+ last_channel = 1280
172
+ width_mult = 1.0
173
+ round_nearest = 8
174
+
175
+ inverted_residual_setting = [
176
+ # t, c, n, s
177
+ [1, 16, 1, 1],
178
+ [6, 24, 2, 2],
179
+ [6, 32, 3, 2],
180
+ [6, 64, 4, 2],
181
+ # [6, 96, 3, 1],
182
+ # [6, 160, 3, 2],
183
+ # [6, 320, 1, 1],
184
+ ]
185
+
186
+ # only check the first element, assuming user knows t,c,n,s are required
187
+ if len(inverted_residual_setting) == 0 or len(
188
+ inverted_residual_setting[0]) != 4:
189
+             raise ValueError('inverted_residual_setting should be non-empty '
190
+                              'and each element must be a 4-element list, '
191
+                              'got {}'.format(inverted_residual_setting))
192
+
193
+ # building first layer
194
+ input_channel = _make_divisible(input_channel * width_mult,
195
+ round_nearest)
196
+ self.last_channel = _make_divisible(
197
+ last_channel * max(1.0, width_mult), round_nearest)
198
+ features = [ConvBNReLU(4, input_channel, stride=2)]
199
+ # building inverted residual blocks
200
+ for t, c, n, s in inverted_residual_setting:
201
+ output_channel = _make_divisible(c * width_mult, round_nearest)
202
+ for i in range(n):
203
+ stride = s if i == 0 else 1
204
+ features.append(
205
+ block(input_channel,
206
+ output_channel,
207
+ stride,
208
+ expand_ratio=t))
209
+ input_channel = output_channel
210
+ self.features = nn.Sequential(*features)
211
+
212
+ self.fpn_selected = [3, 6, 10]
213
+ # weight initialization
214
+ for m in self.modules():
215
+ if isinstance(m, nn.Conv2d):
216
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
217
+ if m.bias is not None:
218
+ nn.init.zeros_(m.bias)
219
+ elif isinstance(m, nn.BatchNorm2d):
220
+ nn.init.ones_(m.weight)
221
+ nn.init.zeros_(m.bias)
222
+ elif isinstance(m, nn.Linear):
223
+ nn.init.normal_(m.weight, 0, 0.01)
224
+ nn.init.zeros_(m.bias)
225
+
226
+ # if pretrained:
227
+ # self._load_pretrained_model()
228
+
229
+ def _forward_impl(self, x):
230
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
231
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
232
+ fpn_features = []
233
+ for i, f in enumerate(self.features):
234
+ if i > self.fpn_selected[-1]:
235
+ break
236
+ x = f(x)
237
+ if i in self.fpn_selected:
238
+ fpn_features.append(x)
239
+
240
+ c2, c3, c4 = fpn_features
241
+ return c2, c3, c4
242
+
243
+ def forward(self, x):
244
+ return self._forward_impl(x)
245
+
246
+ def _load_pretrained_model(self):
247
+ pretrain_dict = model_zoo.load_url(
248
+ 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
249
+ model_dict = {}
250
+ state_dict = self.state_dict()
251
+ for k, v in pretrain_dict.items():
252
+ if k in state_dict:
253
+ model_dict[k] = v
254
+ state_dict.update(model_dict)
255
+ self.load_state_dict(state_dict)
256
+
257
+
258
+ class MobileV2_MLSD_Tiny(nn.Module):
259
+ def __init__(self):
260
+ super(MobileV2_MLSD_Tiny, self).__init__()
261
+
262
+ self.backbone = MobileNetV2(pretrained=True)
263
+
264
+ self.block12 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
265
+ self.block13 = BlockTypeB(128, 64)
266
+
267
+ self.block14 = BlockTypeA(in_c1=24, in_c2=64, out_c1=32, out_c2=32)
268
+ self.block15 = BlockTypeB(64, 64)
269
+
270
+ self.block16 = BlockTypeC(64, 16)
271
+
272
+ def forward(self, x):
273
+ c2, c3, c4 = self.backbone(x)
274
+
275
+ x = self.block12(c3, c4)
276
+ x = self.block13(x)
277
+ x = self.block14(c2, x)
278
+ x = self.block15(x)
279
+ x = self.block16(x)
280
+ x = x[:, 7:, :, :]
281
+ # print(x.shape)
282
+ x = F.interpolate(x,
283
+ scale_factor=2.0,
284
+ mode='bilinear',
285
+ align_corners=True)
286
+
287
+ return x
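A quick shape check (not part of this commit) for the tiny model: MLSD expects a 4-channel input (RGB plus the constant channel that pred_lines/pred_squares concatenate), and the head returns the multi-channel tpMap consumed by the decoder in annotator/mlsd/utils.py. The import path assumes the repository root is on PYTHONPATH.

    import torch

    from annotator.mlsd.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny

    model = MobileV2_MLSD_Tiny().eval()
    with torch.no_grad():
        tpMap = model(torch.randn(1, 4, 512, 512))

    print(tpMap.shape)  # expected: torch.Size([1, 9, 256, 256]) for a 512x512 input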
annotator/mlsd/utils.py ADDED
@@ -0,0 +1,638 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # modified by lihaoweicv
4
+ # pytorch version
5
+ #
6
+ # M-LSD
7
+ # Copyright 2021-present NAVER Corp.
8
+ # Apache License v2.0
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from torch.nn import functional as F
14
+
15
+
16
+ def deccode_output_score_and_ptss(tpMap, topk_n=200, ksize=5):
17
+ '''
18
+ tpMap:
19
+         center: tpMap[:, 0, :, :]
20
+         displacement: tpMap[:, 1:5, :, :]
21
+ '''
22
+ b, c, h, w = tpMap.shape
23
+ assert b == 1, 'only support bsize==1'
24
+ displacement = tpMap[:, 1:5, :, :][0]
25
+ center = tpMap[:, 0, :, :]
26
+ heat = torch.sigmoid(center)
27
+ hmax = F.max_pool2d(heat, (ksize, ksize),
28
+ stride=1,
29
+ padding=(ksize - 1) // 2)
30
+ keep = (hmax == heat).float()
31
+ heat = heat * keep
32
+ heat = heat.reshape(-1, )
33
+
34
+ scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
35
+ yy = torch.floor_divide(indices, w).unsqueeze(-1)
36
+ xx = torch.fmod(indices, w).unsqueeze(-1)
37
+ ptss = torch.cat((yy, xx), dim=-1)
38
+
39
+ ptss = ptss.detach().cpu().numpy()
40
+ scores = scores.detach().cpu().numpy()
41
+ displacement = displacement.detach().cpu().numpy()
42
+ displacement = displacement.transpose((1, 2, 0))
43
+ return ptss, scores, displacement
44
+
45
+
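A standalone sketch (not part of this commit) of the decoding idea above: sigmoid the center map, suppress non-maxima with a max-pool equality mask, then take the top-k responses of the flattened heatmap. A random tensor stands in for tpMap[:, 0, :, :].

    import torch
    import torch.nn.functional as F

    center = torch.randn(1, 256, 256)      # stand-in for tpMap[:, 0, :, :]
    heat = torch.sigmoid(center)
    hmax = F.max_pool2d(heat, (5, 5), stride=1, padding=2)
    heat = heat * (hmax == heat).float()   # keep only local maxima
    heat = heat.reshape(-1)

    scores, indices = torch.topk(heat, 200, dim=-1, largest=True)
    yy = torch.div(indices, 256, rounding_mode='floor')
    xx = torch.fmod(indices, 256)
    ptss = torch.stack((yy, xx), dim=-1)   # (200, 2) candidate (y, x) locations

    print(ptss.shape, scores.shape)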
46
+ def pred_lines(image,
47
+ model,
48
+ input_shape=[512, 512],
49
+ score_thr=0.10,
50
+ dist_thr=20.0,
51
+ device='cuda'):
52
+ h, w, _ = image.shape
53
+ h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
54
+
55
+ resized_image = np.concatenate([
56
+ cv2.resize(image, (input_shape[1], input_shape[0]),
57
+ interpolation=cv2.INTER_AREA),
58
+ np.ones([input_shape[0], input_shape[1], 1])
59
+ ],
60
+ axis=-1)
61
+
62
+ resized_image = resized_image.transpose((2, 0, 1))
63
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
64
+ batch_image = (batch_image / 127.5) - 1.0
65
+
66
+ batch_image = torch.from_numpy(batch_image).float().to(device)
67
+ outputs = model(batch_image)
68
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
69
+ start = vmap[:, :, :2]
70
+ end = vmap[:, :, 2:]
71
+ dist_map = np.sqrt(np.sum((start - end)**2, axis=-1))
72
+
73
+ segments_list = []
74
+ for center, score in zip(pts, pts_score):
75
+ y, x = center
76
+ distance = dist_map[y, x]
77
+ if score > score_thr and distance > dist_thr:
78
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
79
+ x_start = x + disp_x_start
80
+ y_start = y + disp_y_start
81
+ x_end = x + disp_x_end
82
+ y_end = y + disp_y_end
83
+ segments_list.append([x_start, y_start, x_end, y_end])
84
+
85
+ lines = 2 * np.array(segments_list) # 256 > 512
86
+ lines[:, 0] = lines[:, 0] * w_ratio
87
+ lines[:, 1] = lines[:, 1] * h_ratio
88
+ lines[:, 2] = lines[:, 2] * w_ratio
89
+ lines[:, 3] = lines[:, 3] * h_ratio
90
+
91
+ return lines
92
+
93
+
94
+ def pred_squares(
95
+ image,
96
+ model,
97
+ input_shape=[512, 512],
98
+ device='cuda',
99
+ params={
100
+ 'score': 0.06,
101
+ 'outside_ratio': 0.28,
102
+ 'inside_ratio': 0.45,
103
+ 'w_overlap': 0.0,
104
+ 'w_degree': 1.95,
105
+ 'w_length': 0.0,
106
+ 'w_area': 1.86,
107
+ 'w_center': 0.14
108
+ }): # noqa
109
+ # shape = [height, width]
110
+ h, w, _ = image.shape
111
+ original_shape = [h, w]
112
+
113
+ resized_image = np.concatenate([
114
+ cv2.resize(image, (input_shape[0], input_shape[1]),
115
+ interpolation=cv2.INTER_AREA),
116
+ np.ones([input_shape[0], input_shape[1], 1])
117
+ ],
118
+ axis=-1)
119
+ resized_image = resized_image.transpose((2, 0, 1))
120
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
121
+ batch_image = (batch_image / 127.5) - 1.0
122
+
123
+ batch_image = torch.from_numpy(batch_image).float().to(device)
124
+ outputs = model(batch_image)
125
+
126
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
127
+ start = vmap[:, :, :2] # (x, y)
128
+ end = vmap[:, :, 2:] # (x, y)
129
+ dist_map = np.sqrt(np.sum((start - end)**2, axis=-1))
130
+
131
+ junc_list = []
132
+ segments_list = []
133
+ for junc, score in zip(pts, pts_score):
134
+ y, x = junc
135
+ distance = dist_map[y, x]
136
+ if score > params['score'] and distance > 20.0:
137
+ junc_list.append([x, y])
138
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
139
+ d_arrow = 1.0
140
+ x_start = x + d_arrow * disp_x_start
141
+ y_start = y + d_arrow * disp_y_start
142
+ x_end = x + d_arrow * disp_x_end
143
+ y_end = y + d_arrow * disp_y_end
144
+ segments_list.append([x_start, y_start, x_end, y_end])
145
+
146
+ segments = np.array(segments_list)
147
+
148
+ # post processing for squares
149
+ # 1. get unique lines
150
+ point = np.array([[0, 0]])
151
+ point = point[0]
152
+ start = segments[:, :2]
153
+ end = segments[:, 2:]
154
+ diff = start - end
155
+ a = diff[:, 1]
156
+ b = -diff[:, 0]
157
+ c = a * start[:, 0] + b * start[:, 1]
158
+
159
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a**2 + b**2 + 1e-10)
160
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
161
+ theta[theta < 0.0] += 180
162
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
163
+
164
+ d_quant = 1
165
+ theta_quant = 2
166
+ hough[:, 0] //= d_quant
167
+ hough[:, 1] //= theta_quant
168
+ _, indices, counts = np.unique(hough,
169
+ axis=0,
170
+ return_index=True,
171
+ return_counts=True)
172
+
173
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1],
174
+ dtype='float32')
175
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1],
176
+ dtype='int32') - 1
177
+ yx_indices = hough[indices, :].astype('int32')
178
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
179
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
180
+
181
+ acc_map_np = acc_map
182
+ # acc_map = acc_map[None, :, :, None]
183
+ #
184
+ # ### fast suppression using tensorflow op
185
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
186
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
187
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
188
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
189
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
190
+ # _, h, w, _ = acc_map.shape
191
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
192
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
193
+ # yx = tf.concat([y, x], axis=-1)
194
+
195
+ # fast suppression using pytorch op
196
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
197
+ _, _, h, w = acc_map.shape
198
+ max_acc_map = F.max_pool2d(acc_map, kernel_size=5, stride=1, padding=2)
199
+ acc_map = acc_map * ((acc_map == max_acc_map).float())
200
+ flatten_acc_map = acc_map.reshape([
201
+ -1,
202
+ ])
203
+
204
+ scores, indices = torch.topk(flatten_acc_map,
205
+ len(pts),
206
+ dim=-1,
207
+ largest=True)
208
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
209
+ xx = torch.fmod(indices, w).unsqueeze(-1)
210
+ yx = torch.cat((yy, xx), dim=-1)
211
+
212
+ yx = yx.detach().cpu().numpy()
213
+
214
+ topk_values = scores.detach().cpu().numpy()
215
+ indices = idx_map[yx[:, 0], yx[:, 1]]
216
+ basis = 5 // 2
217
+
218
+ merged_segments = []
219
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
220
+ y, x = yx_pt
221
+ if max_indice == -1 or value == 0:
222
+ continue
223
+ segment_list = []
224
+ for y_offset in range(-basis, basis + 1):
225
+ for x_offset in range(-basis, basis + 1):
226
+ indice = idx_map[y + y_offset, x + x_offset]
227
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
228
+ if indice != -1:
229
+ segment_list.append(segments[indice])
230
+ if cnt > 1:
231
+ check_cnt = 1
232
+ current_hough = hough[indice]
233
+ for new_indice, new_hough in enumerate(hough):
234
+ if (current_hough
235
+ == new_hough).all() and indice != new_indice:
236
+ segment_list.append(segments[new_indice])
237
+ check_cnt += 1
238
+ if check_cnt == cnt:
239
+ break
240
+ group_segments = np.array(segment_list).reshape([-1, 2])
241
+ sorted_group_segments = np.sort(group_segments, axis=0)
242
+ x_min, y_min = sorted_group_segments[0, :]
243
+ x_max, y_max = sorted_group_segments[-1, :]
244
+
245
+ deg = theta[max_indice]
246
+ if deg >= 90:
247
+ merged_segments.append([x_min, y_max, x_max, y_min])
248
+ else:
249
+ merged_segments.append([x_min, y_min, x_max, y_max])
250
+
251
+ # 2. get intersections
252
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
253
+ start = new_segments[:, :2] # (x1, y1)
254
+ end = new_segments[:, 2:] # (x2, y2)
255
+ new_centers = (start + end) / 2.0
256
+ diff = start - end
257
+ dist_segments = np.sqrt(np.sum(diff**2, axis=-1))
258
+
259
+ # ax + by = c
260
+ a = diff[:, 1]
261
+ b = -diff[:, 0]
262
+ c = a * start[:, 0] + b * start[:, 1]
263
+ pre_det = a[:, None] * b[None, :]
264
+ det = pre_det - np.transpose(pre_det)
265
+
266
+ pre_inter_y = a[:, None] * c[None, :]
267
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
268
+ pre_inter_x = c[:, None] * b[None, :]
269
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
270
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]],
271
+ axis=-1).astype('int32')
272
+
273
+ # 3. get corner information
274
+ # 3.1 get distance
275
+ '''
276
+ dist_segments:
277
+ | dist(0), dist(1), dist(2), ...|
278
+ dist_inter_to_segment1:
279
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
280
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
281
+ ...
282
+ dist_inter_to_semgnet2:
283
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
284
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
285
+ ...
286
+ '''
287
+
288
+ dist_inter_to_segment1_start = np.sqrt(
289
+ np.sum(((inter_pts - start[:, None, :])**2), axis=-1,
290
+ keepdims=True)) # [n_batch, n_batch, 1]
291
+ dist_inter_to_segment1_end = np.sqrt(
292
+ np.sum(((inter_pts - end[:, None, :])**2), axis=-1,
293
+ keepdims=True)) # [n_batch, n_batch, 1]
294
+ dist_inter_to_segment2_start = np.sqrt(
295
+ np.sum(((inter_pts - start[None, :, :])**2), axis=-1,
296
+ keepdims=True)) # [n_batch, n_batch, 1]
297
+ dist_inter_to_segment2_end = np.sqrt(
298
+ np.sum(((inter_pts - end[None, :, :])**2), axis=-1,
299
+ keepdims=True)) # [n_batch, n_batch, 1]
300
+
301
+ # sort ascending
302
+ dist_inter_to_segment1 = np.sort(np.concatenate(
303
+ [dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
304
+ axis=-1) # [n_batch, n_batch, 2]
305
+ dist_inter_to_segment2 = np.sort(np.concatenate(
306
+ [dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
307
+ axis=-1) # [n_batch, n_batch, 2]
308
+
309
+ # 3.2 get degree
310
+ inter_to_start = new_centers[:, None, :] - inter_pts
311
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1],
312
+ inter_to_start[:, :, 0]) * 180 / np.pi
313
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
314
+ inter_to_end = new_centers[None, :, :] - inter_pts
315
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1],
316
+ inter_to_end[:, :, 0]) * 180 / np.pi
317
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
318
+ '''
319
+ B -- G
320
+ | |
321
+ C -- R
322
+ B : blue / G: green / C: cyan / R: red
323
+
324
+ 0 -- 1
325
+ | |
326
+ 3 -- 2
327
+ '''
328
+ # rename variables
329
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
330
+ # sort deg ascending
331
+ deg_sort = np.sort(np.concatenate(
332
+ [deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1),
333
+ axis=-1)
334
+
335
+ deg_diff_map = np.abs(deg1_map - deg2_map)
336
+ # we only consider the smallest degree of intersect
337
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
338
+
339
+ # define available degree range
340
+ deg_range = [60, 120]
341
+
342
+ corner_dict = {corner_info: [] for corner_info in range(4)}
343
+ inter_points = []
344
+ for i in range(inter_pts.shape[0]):
345
+ for j in range(i + 1, inter_pts.shape[1]):
346
+ # i, j > line index, always i < j
347
+ x, y = inter_pts[i, j, :]
348
+ deg1, deg2 = deg_sort[i, j, :]
349
+ deg_diff = deg_diff_map[i, j]
350
+
351
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
352
+
353
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
354
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
355
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and
356
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or
357
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and
358
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
359
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and
360
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or
361
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and
362
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
363
+
364
+ if check_degree and check_distance:
365
+ corner_info = None # noqa
366
+
367
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
368
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
369
+ corner_info, color_info = 0, 'blue'
370
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125
371
+ and deg2 <= 225):
372
+ corner_info, color_info = 1, 'green'
373
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225
374
+ and deg2 <= 315):
375
+ corner_info, color_info = 2, 'black'
376
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
377
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
378
+ corner_info, color_info = 3, 'cyan'
379
+ else:
380
+ corner_info, color_info = 4, 'red' # we don't use it # noqa
381
+ continue
382
+
383
+ corner_dict[corner_info].append([x, y, i, j])
384
+ inter_points.append([x, y])
385
+
386
+ square_list = []
387
+ connect_list = []
388
+ segments_list = []
389
+ for corner0 in corner_dict[0]:
390
+ for corner1 in corner_dict[1]:
391
+ connect01 = False
392
+ for corner0_line in corner0[2:]:
393
+ if corner0_line in corner1[2:]:
394
+ connect01 = True
395
+ break
396
+ if connect01:
397
+ for corner2 in corner_dict[2]:
398
+ connect12 = False
399
+ for corner1_line in corner1[2:]:
400
+ if corner1_line in corner2[2:]:
401
+ connect12 = True
402
+ break
403
+ if connect12:
404
+ for corner3 in corner_dict[3]:
405
+ connect23 = False
406
+ for corner2_line in corner2[2:]:
407
+ if corner2_line in corner3[2:]:
408
+ connect23 = True
409
+ break
410
+ if connect23:
411
+ for corner3_line in corner3[2:]:
412
+ if corner3_line in corner0[2:]:
413
+ # SQUARE!!!
414
+ '''
415
+ 0 -- 1
416
+ | |
417
+ 3 -- 2
418
+ square_list:
419
+ order: 0 > 1 > 2 > 3
420
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
421
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
422
+ ...
423
+ connect_list:
424
+ order: 01 > 12 > 23 > 30
425
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
426
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
427
+ ...
428
+ segments_list:
429
+ order: 0 > 1 > 2 > 3
430
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i,
431
+ line_idx2_j, line_idx3_i, line_idx3_j |
432
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i,
433
+ line_idx2_j, line_idx3_i, line_idx3_j |
434
+ ...
435
+ '''
436
+ square_list.append(corner0[:2] +
437
+ corner1[:2] +
438
+ corner2[:2] +
439
+ corner3[:2])
440
+ connect_list.append([
441
+ corner0_line, corner1_line,
442
+ corner2_line, corner3_line
443
+ ])
444
+ segments_list.append(corner0[2:] +
445
+ corner1[2:] +
446
+ corner2[2:] +
447
+ corner3[2:])
448
+
449
+ def check_outside_inside(segments_info, connect_idx):
450
+ # return 'outside or inside', min distance, cover_param, peri_param
451
+ if connect_idx == segments_info[0]:
452
+ check_dist_mat = dist_inter_to_segment1
453
+ else:
454
+ check_dist_mat = dist_inter_to_segment2
455
+
456
+ i, j = segments_info
457
+ min_dist, max_dist = check_dist_mat[i, j, :]
458
+ connect_dist = dist_segments[connect_idx]
459
+ if max_dist > connect_dist:
460
+ return 'outside', min_dist, 0, 1
461
+ else:
462
+ return 'inside', min_dist, -1, -1
463
+
464
+ top_square = None # noqa
465
+
466
+ try:
467
+ map_size = input_shape[0] / 2
468
+ squares = np.array(square_list).reshape([-1, 4, 2])
469
+ score_array = []
470
+ connect_array = np.array(connect_list)
471
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
472
+
473
+ # get degree of corners:
474
+ squares_rollup = np.roll(squares, 1, axis=1)
475
+ squares_rolldown = np.roll(squares, -1, axis=1)
476
+ vec1 = squares_rollup - squares
477
+ normalized_vec1 = vec1 / (
478
+ np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
479
+ vec2 = squares_rolldown - squares
480
+ normalized_vec2 = vec2 / (
481
+ np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
482
+ inner_products = np.sum(normalized_vec1 * normalized_vec2,
483
+ axis=-1) # [n_squares, 4]
484
+ squares_degree = np.arccos(
485
+ inner_products) * 180 / np.pi # [n_squares, 4]
486
+
487
+ # get square score
488
+ overlap_scores = []
489
+ degree_scores = []
490
+ length_scores = []
491
+
492
+ for connects, segments, square, degree in zip(connect_array,
493
+ segments_array, squares,
494
+ squares_degree):
495
+ '''
496
+ 0 -- 1
497
+ | |
498
+ 3 -- 2
499
+
500
+ # segments: [4, 2]
501
+ # connects: [4]
502
+ '''
503
+
504
+ # OVERLAP SCORES
505
+ cover = 0
506
+ perimeter = 0
507
+ # check 0 > 1 > 2 > 3
508
+ square_length = []
509
+
510
+ for start_idx in range(4):
511
+ end_idx = (start_idx + 1) % 4
512
+
513
+ connect_idx = connects[start_idx] # segment idx of segment01
514
+ start_segments = segments[start_idx]
515
+ end_segments = segments[end_idx]
516
+
517
+ start_point = square[start_idx] # noqa
518
+ end_point = square[end_idx] # noqa
519
+
520
+ # check whether outside or inside
521
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(
522
+ start_segments, connect_idx)
523
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(
524
+ end_segments, connect_idx)
525
+
526
+ cover += dist_segments[
527
+ connect_idx] + start_cover_param * start_min + end_cover_param * end_min
528
+ perimeter += dist_segments[
529
+ connect_idx] + start_peri_param * start_min + end_peri_param * end_min
530
+
531
+ square_length.append(dist_segments[connect_idx] +
532
+ start_peri_param * start_min +
533
+ end_peri_param * end_min)
534
+
535
+ overlap_scores.append(cover / perimeter)
536
+ # DEGREE SCORES
537
+ '''
538
+ deg0 vs deg2
539
+ deg1 vs deg3
540
+ '''
541
+ deg0, deg1, deg2, deg3 = degree
542
+ deg_ratio1 = deg0 / deg2
543
+ if deg_ratio1 > 1.0:
544
+ deg_ratio1 = 1 / deg_ratio1
545
+ deg_ratio2 = deg1 / deg3
546
+ if deg_ratio2 > 1.0:
547
+ deg_ratio2 = 1 / deg_ratio2
548
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
549
+ # LENGTH SCORES
550
+ '''
551
+ len0 vs len2
552
+ len1 vs len3
553
+ '''
554
+ len0, len1, len2, len3 = square_length
555
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
556
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
557
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
558
+
559
+ ######################################
560
+
561
+ overlap_scores = np.array(overlap_scores)
562
+ overlap_scores /= np.max(overlap_scores)
563
+
564
+ degree_scores = np.array(degree_scores)
565
+ # degree_scores /= np.max(degree_scores)
566
+
567
+ length_scores = np.array(length_scores)
568
+
569
+ # AREA SCORES
570
+ area_scores = np.reshape(squares, [-1, 4, 2])
571
+ area_x = area_scores[:, :, 0]
572
+ area_y = area_scores[:, :, 1]
573
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:,
574
+ 0]
575
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(
576
+ area_y[:, :-1] * area_x[:, 1:], axis=-1)
577
+ area_scores = 0.5 * np.abs(area_scores + correction)
578
+ area_scores /= (map_size * map_size) # np.max(area_scores)
579
+
580
+ # CENTER SCORES
581
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
582
+ # squares: [n, 4, 2]
583
+ square_centers = np.mean(squares, axis=1) # [n, 2]
584
+ center2center = np.sqrt(np.sum((centers - square_centers)**2))
585
+ center_scores = center2center / (map_size / np.sqrt(2.0))
586
+ '''
587
+ score_w = [overlap, degree, area, center, length]
588
+ '''
589
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0] # noqa
590
+ score_array = (params['w_overlap'] * overlap_scores +
591
+ params['w_degree'] * degree_scores +
592
+ params['w_area'] * area_scores -
593
+ params['w_center'] * center_scores +
594
+ params['w_length'] * length_scores)
595
+
596
+ best_square = [] # noqa
597
+
598
+ sorted_idx = np.argsort(score_array)[::-1]
599
+ score_array = score_array[sorted_idx]
600
+ squares = squares[sorted_idx]
601
+
602
+ except Exception:
603
+ pass
604
+ '''return list
605
+ merged_lines, squares, scores
606
+ '''
607
+
608
+ try:
609
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[
610
+ 1] * original_shape[1]
611
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[
612
+ 0] * original_shape[0]
613
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[
614
+ 1] * original_shape[1]
615
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[
616
+ 0] * original_shape[0]
617
+ except Exception:
618
+ new_segments = []
619
+
620
+ try:
621
+ squares[:, :,
622
+ 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
623
+ squares[:, :,
624
+ 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
625
+ except Exception:
626
+ squares = []
627
+ score_array = []
628
+
629
+ try:
630
+ inter_points = np.array(inter_points)
631
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[
632
+ 1] * original_shape[1]
633
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[
634
+ 0] * original_shape[0]
635
+ except Exception:
636
+ inter_points = []
637
+
638
+ return new_segments, squares, score_array, inter_points
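Note: the final score above is a weighted sum of the overlap, degree, area, center and length terms (`params['w_overlap'] * overlap + ... - params['w_center'] * center`), and `squares` is sorted by descending score before being rescaled to the original resolution. A minimal, hypothetical sketch of consuming those outputs to visualize the top-ranked quadrilaterals (the helper name and `top_k` argument are illustrative, not part of this file):

```python
import cv2
import numpy as np


def draw_top_squares(image_bgr, squares, top_k=1):
    """Overlay the top-k quadrilaterals returned by pred_squares onto a copy of the image.

    `squares` is assumed to be the (already score-sorted) [n, 4, 2] array returned above.
    """
    canvas = image_bgr.copy()
    for quad in np.asarray(squares)[:top_k].astype(int):
        # quad is ordered corner 0 -> 1 -> 2 -> 3, matching the diagram in the code above
        cv2.polylines(canvas, [quad.reshape(-1, 1, 2)], isClosed=True,
                      color=(0, 255, 0), thickness=2)
    return canvas
```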
annotator/mlsd_op.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # MLSD Line Detection
4
+ # From https://github.com/navervision/mlsd
5
+ # Apache-2.0 license
6
+
7
+ import warnings
8
+ from abc import ABCMeta
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
16
+ from scepter.modules.annotator.mlsd.mbv2_mlsd_large import MobileV2_MLSD_Large
17
+ from scepter.modules.annotator.mlsd.utils import pred_lines
18
+ from scepter.modules.annotator.registry import ANNOTATORS
19
+ from scepter.modules.annotator.utils import resize_image, resize_image_ori
20
+ from scepter.modules.utils.config import dict_to_yaml
21
+ from scepter.modules.utils.distribute import we
22
+ from scepter.modules.utils.file_system import FS
23
+
24
+
25
+ @ANNOTATORS.register_class()
26
+ class MLSDdetector(BaseAnnotator, metaclass=ABCMeta):
27
+ def __init__(self, cfg, logger=None):
28
+ super().__init__(cfg, logger=logger)
29
+ model = MobileV2_MLSD_Large()
30
+ pretrained_model = cfg.get('PRETRAINED_MODEL', None)
31
+ if pretrained_model:
32
+ with FS.get_from(pretrained_model, wait_finish=True) as local_path:
33
+ model.load_state_dict(torch.load(local_path), strict=True)
34
+ self.model = model.eval()
35
+ self.thr_v = cfg.get('THR_V', 0.1)
36
+ self.thr_d = cfg.get('THR_D', 0.1)
37
+
38
+ @torch.no_grad()
39
+ @torch.inference_mode()
40
+ @torch.autocast('cuda', enabled=False)
41
+ def forward(self, image):
42
+ if isinstance(image, Image.Image):
43
+ image = np.array(image)
44
+ elif isinstance(image, torch.Tensor):
45
+ image = image.detach().cpu().numpy()
46
+ elif isinstance(image, np.ndarray):
47
+ image = image.copy()
48
+ else:
49
+ raise TypeError(f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.')
50
+ h, w, c = image.shape
51
+ image, k = resize_image(image, 1024 if min(h, w) > 1024 else min(h, w))
52
+ img_output = np.zeros_like(image)
53
+ try:
54
+ lines = pred_lines(image,
55
+ self.model, [image.shape[0], image.shape[1]],
56
+ self.thr_v,
57
+ self.thr_d,
58
+ device=we.device_id)
59
+ for line in lines:
60
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
61
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end),
62
+ [255, 255, 255], 1)
63
+ except Exception as e:
64
+ warnings.warn(f'{e}')
65
+ return None
66
+ img_output = resize_image_ori(h, w, img_output, k)
67
+ return img_output[:, :, 0]
68
+
69
+ @staticmethod
70
+ def get_config_template():
71
+ return dict_to_yaml('ANNOTATORS',
72
+ __class__.__name__,
73
+ MLSDdetector.para_dict,
74
+ set_name=True)
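As a usage sketch (not part of this commit), the detector above would typically be constructed through the scepter registry with a small config; the checkpoint URI below is the one listed for MLSD in `annotator/utils.py`, and the thresholds mirror the `THR_V`/`THR_D` defaults in `__init__`:

```python
import numpy as np

from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config

# Build the MLSD annotator from a plain dict config; Config(cfg_dict=..., load=False)
# follows the pattern used by AnnotatorProcessor in annotator/utils.py.
mlsd_cfg = Config(cfg_dict={
    'NAME': 'MLSDdetector',
    'PRETRAINED_MODEL': 'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
    'THR_V': 0.1,  # line score threshold (default)
    'THR_D': 0.1,  # distance threshold (default)
}, load=False)

detector = ANNOTATORS.build(mlsd_cfg)
line_map = detector.forward(np.zeros((512, 512, 3), dtype=np.uint8))  # HxW uint8 line drawing
```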
annotator/openpose.py ADDED
@@ -0,0 +1,812 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # Openpose
4
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
5
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
6
+ # The implementation is modified from 3rd Edited Version by ControlNet
7
+ import math
8
+ import os
9
+ from abc import ABCMeta
10
+ from collections import OrderedDict
11
+
12
+ import cv2
13
+ import matplotlib
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from PIL import Image
18
+ from scipy.ndimage import gaussian_filter
19
+ from skimage.measure import label
20
+
21
+ from scepter.modules.annotator.base_annotator import BaseAnnotator
22
+ from scepter.modules.annotator.registry import ANNOTATORS
23
+ from scepter.modules.utils.config import dict_to_yaml
24
+ from scepter.modules.utils.file_system import FS
25
+
26
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
27
+
28
+
29
+ def padRightDownCorner(img, stride, padValue):
30
+ h = img.shape[0]
31
+ w = img.shape[1]
32
+
33
+ pad = 4 * [None]
34
+ pad[0] = 0 # up
35
+ pad[1] = 0 # left
36
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
37
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
38
+
39
+ img_padded = img
40
+ pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
41
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
42
+ pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
43
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
44
+ pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
45
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
46
+ pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
47
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
48
+
49
+ return img_padded, pad
50
+
51
+
52
+ # transfer caffe model weights to pytorch, matching the layer names
53
+ def transfer(model, model_weights):
54
+ transfered_model_weights = {}
55
+ for weights_name in model.state_dict().keys():
56
+ transfered_model_weights[weights_name] = model_weights['.'.join(
57
+ weights_name.split('.')[1:])]
58
+ return transfered_model_weights
59
+
60
+
61
+ # draw the body keypoints and limbs
62
+ def draw_bodypose(canvas, candidate, subset):
63
+ stickwidth = 4
64
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
65
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15],
66
+ [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
67
+
68
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
69
+ [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
70
+ [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
71
+ [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
72
+ [255, 0, 170], [255, 0, 85]]
73
+ for i in range(18):
74
+ for n in range(len(subset)):
75
+ index = int(subset[n][i])
76
+ if index == -1:
77
+ continue
78
+ x, y = candidate[index][0:2]
79
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
80
+ for i in range(17):
81
+ for n in range(len(subset)):
82
+ index = subset[n][np.array(limbSeq[i]) - 1]
83
+ if -1 in index:
84
+ continue
85
+ cur_canvas = canvas.copy()
86
+ Y = candidate[index.astype(int), 0]
87
+ X = candidate[index.astype(int), 1]
88
+ mX = np.mean(X)
89
+ mY = np.mean(Y)
90
+ length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
91
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
92
+ polygon = cv2.ellipse2Poly(
93
+ (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle),
94
+ 0, 360, 1)
95
+ cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
96
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
97
+ # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
98
+ # plt.imshow(canvas[:, :, [2, 1, 0]])
99
+ return canvas
100
+
101
+
102
+ # images drawn by opencv do not look good.
103
+ def draw_handpose(canvas, all_hand_peaks, show_number=False):
104
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8],
105
+ [0, 9], [9, 10], [10, 11], [11, 12], [0, 13], [13, 14], [14, 15],
106
+ [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
107
+
108
+ for peaks in all_hand_peaks:
109
+ for ie, e in enumerate(edges):
110
+ if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
111
+ x1, y1 = peaks[e[0]]
112
+ x2, y2 = peaks[e[1]]
113
+ cv2.line(canvas, (x1, y1), (x2, y2),
114
+ matplotlib.colors.hsv_to_rgb(
115
+ [ie / float(len(edges)), 1.0, 1.0]) * 255,
116
+ thickness=2)
117
+
118
+ for i, keypoint in enumerate(peaks):
119
+ x, y = keypoint
120
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
121
+ if show_number:
122
+ cv2.putText(canvas,
123
+ str(i), (x, y),
124
+ cv2.FONT_HERSHEY_SIMPLEX,
125
+ 0.3, (0, 0, 0),
126
+ lineType=cv2.LINE_AA)
127
+ return canvas
128
+
129
+
130
+ # detect hand according to body pose keypoints
131
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/
132
+ # master/src/openpose/hand/handDetector.cpp
133
+ def handDetect(candidate, subset, oriImg):
134
+ # right hand: wrist 4, elbow 3, shoulder 2
135
+ # left hand: wrist 7, elbow 6, shoulder 5
136
+ ratioWristElbow = 0.33
137
+ detect_result = []
138
+ image_height, image_width = oriImg.shape[0:2]
139
+ for person in subset.astype(int):
140
+ # if any of three not detected
141
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
142
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
143
+ if not (has_left or has_right):
144
+ continue
145
+ hands = []
146
+ # left hand
147
+ if has_left:
148
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[
149
+ 5, 6, 7
150
+ ]]
151
+ x1, y1 = candidate[left_shoulder_index][:2]
152
+ x2, y2 = candidate[left_elbow_index][:2]
153
+ x3, y3 = candidate[left_wrist_index][:2]
154
+ hands.append([x1, y1, x2, y2, x3, y3, True])
155
+ # right hand
156
+ if has_right:
157
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[
158
+ [2, 3, 4]]
159
+ x1, y1 = candidate[right_shoulder_index][:2]
160
+ x2, y2 = candidate[right_elbow_index][:2]
161
+ x3, y3 = candidate[right_wrist_index][:2]
162
+ hands.append([x1, y1, x2, y2, x3, y3, False])
163
+
164
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
165
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
166
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
167
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
168
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
169
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
170
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
171
+ x = x3 + ratioWristElbow * (x3 - x2)
172
+ y = y3 + ratioWristElbow * (y3 - y2)
173
+ distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2)
174
+ distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2)
175
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
176
+ # x-y refers to the center --> offset to topLeft point
177
+ # handRectangle.x -= handRectangle.width / 2.f;
178
+ # handRectangle.y -= handRectangle.height / 2.f;
179
+ x -= width / 2
180
+ y -= width / 2 # width = height
181
+ # overflow the image
182
+ if x < 0:
183
+ x = 0
184
+ if y < 0:
185
+ y = 0
186
+ width1 = width
187
+ width2 = width
188
+ if x + width > image_width:
189
+ width1 = image_width - x
190
+ if y + width > image_height:
191
+ width2 = image_height - y
192
+ width = min(width1, width2)
193
+ # keep only hand boxes that are at least 20 pixels wide
194
+ if width >= 20:
195
+ detect_result.append([int(x), int(y), int(width), is_left])
196
+ '''
197
+ return value: [[x, y, w, True if left hand else False]].
198
+ width=height since the network requires a square input.
199
+ x, y is the coordinate of top left
200
+ '''
201
+ return detect_result
202
+
203
+
204
+ # get max index of 2d array
205
+ def npmax(array):
206
+ arrayindex = array.argmax(1)
207
+ arrayvalue = array.max(1)
208
+ i = arrayvalue.argmax()
209
+ j = arrayindex[i]
210
+ return i, j
211
+
212
+
213
+ def make_layers(block, no_relu_layers):
214
+ layers = []
215
+ for layer_name, v in block.items():
216
+ if 'pool' in layer_name:
217
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
218
+ layers.append((layer_name, layer))
219
+ else:
220
+ conv2d = nn.Conv2d(in_channels=v[0],
221
+ out_channels=v[1],
222
+ kernel_size=v[2],
223
+ stride=v[3],
224
+ padding=v[4])
225
+ layers.append((layer_name, conv2d))
226
+ if layer_name not in no_relu_layers:
227
+ layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
228
+
229
+ return nn.Sequential(OrderedDict(layers))
230
+
231
+
232
+ class bodypose_model(nn.Module):
233
+ def __init__(self):
234
+ super(bodypose_model, self).__init__()
235
+
236
+ # these layers have no relu layer
237
+ no_relu_layers = [
238
+ 'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
239
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
240
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
241
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'
242
+ ]
243
+ blocks = {}
244
+ block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
245
+ ('conv1_2', [64, 64, 3, 1, 1]),
246
+ ('pool1_stage1', [2, 2, 0]),
247
+ ('conv2_1', [64, 128, 3, 1, 1]),
248
+ ('conv2_2', [128, 128, 3, 1, 1]),
249
+ ('pool2_stage1', [2, 2, 0]),
250
+ ('conv3_1', [128, 256, 3, 1, 1]),
251
+ ('conv3_2', [256, 256, 3, 1, 1]),
252
+ ('conv3_3', [256, 256, 3, 1, 1]),
253
+ ('conv3_4', [256, 256, 3, 1, 1]),
254
+ ('pool3_stage1', [2, 2, 0]),
255
+ ('conv4_1', [256, 512, 3, 1, 1]),
256
+ ('conv4_2', [512, 512, 3, 1, 1]),
257
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
258
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])])
259
+
260
+ # Stage 1
261
+ block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
262
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
263
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
264
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
265
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])
266
+
267
+ block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
268
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
269
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
270
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
271
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
272
+ blocks['block1_1'] = block1_1
273
+ blocks['block1_2'] = block1_2
274
+
275
+ self.model0 = make_layers(block0, no_relu_layers)
276
+
277
+ # Stages 2 - 6
278
+ for i in range(2, 7):
279
+ blocks['block%d_1' % i] = OrderedDict([
280
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
281
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
282
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
283
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
284
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
285
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
286
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
287
+ ])
288
+
289
+ blocks['block%d_2' % i] = OrderedDict([
290
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
291
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
292
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
293
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
294
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
295
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
296
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
297
+ ])
298
+
299
+ for k in blocks.keys():
300
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
301
+
302
+ self.model1_1 = blocks['block1_1']
303
+ self.model2_1 = blocks['block2_1']
304
+ self.model3_1 = blocks['block3_1']
305
+ self.model4_1 = blocks['block4_1']
306
+ self.model5_1 = blocks['block5_1']
307
+ self.model6_1 = blocks['block6_1']
308
+
309
+ self.model1_2 = blocks['block1_2']
310
+ self.model2_2 = blocks['block2_2']
311
+ self.model3_2 = blocks['block3_2']
312
+ self.model4_2 = blocks['block4_2']
313
+ self.model5_2 = blocks['block5_2']
314
+ self.model6_2 = blocks['block6_2']
315
+
316
+ def forward(self, x):
317
+
318
+ out1 = self.model0(x)
319
+
320
+ out1_1 = self.model1_1(out1)
321
+ out1_2 = self.model1_2(out1)
322
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
323
+
324
+ out2_1 = self.model2_1(out2)
325
+ out2_2 = self.model2_2(out2)
326
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
327
+
328
+ out3_1 = self.model3_1(out3)
329
+ out3_2 = self.model3_2(out3)
330
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
331
+
332
+ out4_1 = self.model4_1(out4)
333
+ out4_2 = self.model4_2(out4)
334
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
335
+
336
+ out5_1 = self.model5_1(out5)
337
+ out5_2 = self.model5_2(out5)
338
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
339
+
340
+ out6_1 = self.model6_1(out6)
341
+ out6_2 = self.model6_2(out6)
342
+
343
+ return out6_1, out6_2
344
+
345
+
346
+ class handpose_model(nn.Module):
347
+ def __init__(self):
348
+ super(handpose_model, self).__init__()
349
+
350
+ # these layers have no relu layer
351
+ no_relu_layers = [
352
+ 'conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3', 'Mconv7_stage4',
353
+ 'Mconv7_stage5', 'Mconv7_stage6'
354
+ ]
355
+ # stage 1
356
+ block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
357
+ ('conv1_2', [64, 64, 3, 1, 1]),
358
+ ('pool1_stage1', [2, 2, 0]),
359
+ ('conv2_1', [64, 128, 3, 1, 1]),
360
+ ('conv2_2', [128, 128, 3, 1, 1]),
361
+ ('pool2_stage1', [2, 2, 0]),
362
+ ('conv3_1', [128, 256, 3, 1, 1]),
363
+ ('conv3_2', [256, 256, 3, 1, 1]),
364
+ ('conv3_3', [256, 256, 3, 1, 1]),
365
+ ('conv3_4', [256, 256, 3, 1, 1]),
366
+ ('pool3_stage1', [2, 2, 0]),
367
+ ('conv4_1', [256, 512, 3, 1, 1]),
368
+ ('conv4_2', [512, 512, 3, 1, 1]),
369
+ ('conv4_3', [512, 512, 3, 1, 1]),
370
+ ('conv4_4', [512, 512, 3, 1, 1]),
371
+ ('conv5_1', [512, 512, 3, 1, 1]),
372
+ ('conv5_2', [512, 512, 3, 1, 1]),
373
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])])
374
+
375
+ block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]),
376
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])])
377
+
378
+ blocks = {}
379
+ blocks['block1_0'] = block1_0
380
+ blocks['block1_1'] = block1_1
381
+
382
+ # stage 2-6
383
+ for i in range(2, 7):
384
+ blocks['block%d' % i] = OrderedDict([
385
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
386
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
387
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
388
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
389
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
390
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
391
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
392
+ ])
393
+
394
+ for k in blocks.keys():
395
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
396
+
397
+ self.model1_0 = blocks['block1_0']
398
+ self.model1_1 = blocks['block1_1']
399
+ self.model2 = blocks['block2']
400
+ self.model3 = blocks['block3']
401
+ self.model4 = blocks['block4']
402
+ self.model5 = blocks['block5']
403
+ self.model6 = blocks['block6']
404
+
405
+ def forward(self, x):
406
+ out1_0 = self.model1_0(x)
407
+ out1_1 = self.model1_1(out1_0)
408
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
409
+ out_stage2 = self.model2(concat_stage2)
410
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
411
+ out_stage3 = self.model3(concat_stage3)
412
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
413
+ out_stage4 = self.model4(concat_stage4)
414
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
415
+ out_stage5 = self.model5(concat_stage5)
416
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
417
+ out_stage6 = self.model6(concat_stage6)
418
+ return out_stage6
419
+
420
+
421
+ class Hand(object):
422
+ def __init__(self, model_path, device='cuda'):
423
+ self.model = handpose_model()
424
+ if torch.cuda.is_available():
425
+ self.model = self.model.to(device)
426
+ model_dict = transfer(self.model, torch.load(model_path))
427
+ self.model.load_state_dict(model_dict)
428
+ self.model.eval()
429
+ self.device = device
430
+
431
+ def __call__(self, oriImg):
432
+ scale_search = [0.5, 1.0, 1.5, 2.0]
433
+ # scale_search = [0.5]
434
+ boxsize = 368
435
+ stride = 8
436
+ padValue = 128
437
+ thre = 0.05
438
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
439
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
440
+ # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
441
+
442
+ for m in range(len(multiplier)):
443
+ scale = multiplier[m]
444
+ imageToTest = cv2.resize(oriImg, (0, 0),
445
+ fx=scale,
446
+ fy=scale,
447
+ interpolation=cv2.INTER_CUBIC)
448
+ imageToTest_padded, pad = padRightDownCorner(
449
+ imageToTest, stride, padValue)
450
+ im = np.transpose(
451
+ np.float32(imageToTest_padded[:, :, :, np.newaxis]),
452
+ (3, 2, 0, 1)) / 256 - 0.5
453
+ im = np.ascontiguousarray(im)
454
+
455
+ data = torch.from_numpy(im).float()
456
+ if torch.cuda.is_available():
457
+ data = data.to(self.device)
458
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
459
+ with torch.no_grad():
460
+ output = self.model(data).cpu().numpy()
461
+ # output = self.model(data).numpy()
462
+
463
+ # extract outputs, resize, and remove padding
464
+ heatmap = np.transpose(np.squeeze(output),
465
+ (1, 2, 0)) # output 1 is heatmaps
466
+ heatmap = cv2.resize(heatmap, (0, 0),
467
+ fx=stride,
468
+ fy=stride,
469
+ interpolation=cv2.INTER_CUBIC)
470
+ heatmap = heatmap[:imageToTest_padded.shape[0] -
471
+ pad[2], :imageToTest_padded.shape[1] - pad[3], :]
472
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]),
473
+ interpolation=cv2.INTER_CUBIC)
474
+
475
+ heatmap_avg += heatmap / len(multiplier)
476
+
477
+ all_peaks = []
478
+ for part in range(21):
479
+ map_ori = heatmap_avg[:, :, part]
480
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
481
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
482
+ # all values are below the threshold
483
+ if np.sum(binary) == 0:
484
+ all_peaks.append([0, 0])
485
+ continue
486
+ label_img, label_numbers = label(binary,
487
+ return_num=True,
488
+ connectivity=binary.ndim)
489
+ max_index = np.argmax([
490
+ np.sum(map_ori[label_img == i])
491
+ for i in range(1, label_numbers + 1)
492
+ ]) + 1
493
+ label_img[label_img != max_index] = 0
494
+ map_ori[label_img == 0] = 0
495
+
496
+ y, x = npmax(map_ori)
497
+ all_peaks.append([x, y])
498
+ return np.array(all_peaks)
499
+
500
+
501
+ class Body(object):
502
+ def __init__(self, model_path, device='cuda'):
503
+ self.model = bodypose_model()
504
+ if torch.cuda.is_available():
505
+ self.model = self.model.to(device)
506
+ model_dict = transfer(self.model, torch.load(model_path))
507
+ self.model.load_state_dict(model_dict)
508
+ self.model.eval()
509
+ self.device = device
510
+
511
+ def __call__(self, oriImg):
512
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
513
+ scale_search = [0.5]
514
+ boxsize = 368
515
+ stride = 8
516
+ padValue = 128
517
+ thre1 = 0.1
518
+ thre2 = 0.05
519
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
520
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
521
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
522
+
523
+ for m in range(len(multiplier)):
524
+ scale = multiplier[m]
525
+ imageToTest = cv2.resize(oriImg, (0, 0),
526
+ fx=scale,
527
+ fy=scale,
528
+ interpolation=cv2.INTER_CUBIC)
529
+ imageToTest_padded, pad = padRightDownCorner(
530
+ imageToTest, stride, padValue)
531
+ im = np.transpose(
532
+ np.float32(imageToTest_padded[:, :, :, np.newaxis]),
533
+ (3, 2, 0, 1)) / 256 - 0.5
534
+ im = np.ascontiguousarray(im)
535
+
536
+ data = torch.from_numpy(im).float()
537
+ if torch.cuda.is_available():
538
+ data = data.to(self.device)
539
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
540
+ with torch.no_grad():
541
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
542
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
543
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
544
+
545
+ # extract outputs, resize, and remove padding
546
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))
547
+ # output 1 is heatmaps
548
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2),
549
+ (1, 2, 0)) # output 1 is heatmaps
550
+ heatmap = cv2.resize(heatmap, (0, 0),
551
+ fx=stride,
552
+ fy=stride,
553
+ interpolation=cv2.INTER_CUBIC)
554
+ heatmap = heatmap[:imageToTest_padded.shape[0] -
555
+ pad[2], :imageToTest_padded.shape[1] - pad[3], :]
556
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]),
557
+ interpolation=cv2.INTER_CUBIC)
558
+
559
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
560
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1),
561
+ (1, 2, 0)) # output 0 is PAFs
562
+ paf = cv2.resize(paf, (0, 0),
563
+ fx=stride,
564
+ fy=stride,
565
+ interpolation=cv2.INTER_CUBIC)
566
+ paf = paf[:imageToTest_padded.shape[0] -
567
+ pad[2], :imageToTest_padded.shape[1] - pad[3], :]
568
+ paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]),
569
+ interpolation=cv2.INTER_CUBIC)
570
+
571
+ heatmap_avg += heatmap / len(multiplier)
572
+ paf_avg += paf / len(multiplier)
573
+
574
+ all_peaks = []
575
+ peak_counter = 0
576
+
577
+ for part in range(18):
578
+ map_ori = heatmap_avg[:, :, part]
579
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
580
+
581
+ map_left = np.zeros(one_heatmap.shape)
582
+ map_left[1:, :] = one_heatmap[:-1, :]
583
+ map_right = np.zeros(one_heatmap.shape)
584
+ map_right[:-1, :] = one_heatmap[1:, :]
585
+ map_up = np.zeros(one_heatmap.shape)
586
+ map_up[:, 1:] = one_heatmap[:, :-1]
587
+ map_down = np.zeros(one_heatmap.shape)
588
+ map_down[:, :-1] = one_heatmap[:, 1:]
589
+
590
+ peaks_binary = np.logical_and.reduce(
591
+ (one_heatmap >= map_left, one_heatmap >= map_right,
592
+ one_heatmap >= map_up, one_heatmap >= map_down,
593
+ one_heatmap > thre1))
594
+ peaks = list(
595
+ zip(np.nonzero(peaks_binary)[1],
596
+ np.nonzero(peaks_binary)[0])) # note reverse
597
+ peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
598
+ peak_id = range(peak_counter, peak_counter + len(peaks))
599
+ peaks_with_score_and_id = [
600
+ peaks_with_score[i] + (peak_id[i], )
601
+ for i in range(len(peak_id))
602
+ ]
603
+
604
+ all_peaks.append(peaks_with_score_and_id)
605
+ peak_counter += len(peaks)
606
+
607
+ # find connection in the specified sequence, center 29 is in the position 15
608
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9],
609
+ [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1],
610
+ [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
611
+ # the middle joints heatmap correspondence
612
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44],
613
+ [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30],
614
+ [47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38],
615
+ [45, 46]]
616
+
617
+ connection_all = []
618
+ special_k = []
619
+ mid_num = 10
620
+
621
+ for k in range(len(mapIdx)):
622
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
623
+ candA = all_peaks[limbSeq[k][0] - 1]
624
+ candB = all_peaks[limbSeq[k][1] - 1]
625
+ nA = len(candA)
626
+ nB = len(candB)
627
+ indexA, indexB = limbSeq[k]
628
+ if (nA != 0 and nB != 0):
629
+ connection_candidate = []
630
+ for i in range(nA):
631
+ for j in range(nB):
632
+ vec = np.subtract(candB[j][:2], candA[i][:2])
633
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
634
+ norm = max(0.001, norm)
635
+ vec = np.divide(vec, norm)
636
+
637
+ startend = list(
638
+ zip(
639
+ np.linspace(candA[i][0],
640
+ candB[j][0],
641
+ num=mid_num),
642
+ np.linspace(candA[i][1],
643
+ candB[j][1],
644
+ num=mid_num)))
645
+
646
+ vec_x = np.array([
647
+ score_mid[int(round(startend[ii][1])),
648
+ int(round(startend[ii][0])), 0]
649
+ for ii in range(len(startend))
650
+ ])
651
+ vec_y = np.array([
652
+ score_mid[int(round(startend[ii][1])),
653
+ int(round(startend[ii][0])), 1]
654
+ for ii in range(len(startend))
655
+ ])
656
+
657
+ score_midpts = np.multiply(
658
+ vec_x, vec[0]) + np.multiply(vec_y, vec[1])
659
+ score_with_dist_prior = sum(score_midpts) / len(
660
+ score_midpts) + min(
661
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
662
+ criterion1 = len(np.nonzero(
663
+ score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
664
+ criterion2 = score_with_dist_prior > 0
665
+ if criterion1 and criterion2:
666
+ connection_candidate.append([
667
+ i, j, score_with_dist_prior,
668
+ score_with_dist_prior + candA[i][2] +
669
+ candB[j][2]
670
+ ])
671
+
672
+ connection_candidate = sorted(connection_candidate,
673
+ key=lambda x: x[2],
674
+ reverse=True)
675
+ connection = np.zeros((0, 5))
676
+ for c in range(len(connection_candidate)):
677
+ i, j, s = connection_candidate[c][0:3]
678
+ if (i not in connection[:, 3]
679
+ and j not in connection[:, 4]):
680
+ connection = np.vstack(
681
+ [connection, [candA[i][3], candB[j][3], s, i, j]])
682
+ if (len(connection) >= min(nA, nB)):
683
+ break
684
+
685
+ connection_all.append(connection)
686
+ else:
687
+ special_k.append(k)
688
+ connection_all.append([])
689
+
690
+ # last number in each row is the total parts number of that person
691
+ # the second last number in each row is the score of the overall configuration
692
+ subset = -1 * np.ones((0, 20))
693
+ candidate = np.array(
694
+ [item for sublist in all_peaks for item in sublist])
695
+
696
+ for k in range(len(mapIdx)):
697
+ if k not in special_k:
698
+ partAs = connection_all[k][:, 0]
699
+ partBs = connection_all[k][:, 1]
700
+ indexA, indexB = np.array(limbSeq[k]) - 1
701
+
702
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
703
+ found = 0
704
+ subset_idx = [-1, -1]
705
+ for j in range(len(subset)): # 1:size(subset,1):
706
+ if subset[j][indexA] == partAs[i] or subset[j][
707
+ indexB] == partBs[i]:
708
+ subset_idx[found] = j
709
+ found += 1
710
+
711
+ if found == 1:
712
+ j = subset_idx[0]
713
+ if subset[j][indexB] != partBs[i]:
714
+ subset[j][indexB] = partBs[i]
715
+ subset[j][-1] += 1
716
+ subset[j][-2] += candidate[
717
+ partBs[i].astype(int),
718
+ 2] + connection_all[k][i][2]
719
+ elif found == 2: # if found 2 and disjoint, merge them
720
+ j1, j2 = subset_idx
721
+ membership = ((subset[j1] >= 0).astype(int) +
722
+ (subset[j2] >= 0).astype(int))[:-2]
723
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
724
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
725
+ subset[j1][-2:] += subset[j2][-2:]
726
+ subset[j1][-2] += connection_all[k][i][2]
727
+ subset = np.delete(subset, j2, 0)
728
+ else: # as like found == 1
729
+ subset[j1][indexB] = partBs[i]
730
+ subset[j1][-1] += 1
731
+ subset[j1][-2] += candidate[
732
+ partBs[i].astype(int),
733
+ 2] + connection_all[k][i][2]
734
+
735
+ # if find no partA in the subset, create a new subset
736
+ elif not found and k < 17:
737
+ row = -1 * np.ones(20)
738
+ row[indexA] = partAs[i]
739
+ row[indexB] = partBs[i]
740
+ row[-1] = 2
741
+ row[-2] = sum(
742
+ candidate[connection_all[k][i, :2].astype(int),
743
+ 2]) + connection_all[k][i][2]
744
+ subset = np.vstack([subset, row])
745
+ # delete rows of subset that have too few detected parts
746
+ deleteIdx = []
747
+ for i in range(len(subset)):
748
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
749
+ deleteIdx.append(i)
750
+ subset = np.delete(subset, deleteIdx, axis=0)
751
+
752
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
753
+ # candidate: x, y, score, id
754
+ return candidate, subset
755
+
756
+
757
+ @ANNOTATORS.register_class()
758
+ class OpenposeAnnotator(BaseAnnotator, metaclass=ABCMeta):
759
+ para_dict = {}
760
+
761
+ def __init__(self, cfg, logger=None):
762
+ super().__init__(cfg, logger=logger)
763
+ with FS.get_from(cfg.BODY_MODEL_PATH,
764
+ wait_finish=True) as body_model_path:
765
+ self.body_estimation = Body(body_model_path, device='cpu')
766
+ with FS.get_from(cfg.HAND_MODEL_PATH,
767
+ wait_finish=True) as hand_model_path:
768
+ self.hand_estimation = Hand(hand_model_path, device='cpu')
769
+ self.use_hand = cfg.get('USE_HAND', False)
770
+
771
+ def to(self, device):
772
+ self.body_estimation.model = self.body_estimation.model.to(device)
773
+ self.body_estimation.device = device
774
+ self.hand_estimation.model = self.hand_estimation.model.to(device)
775
+ self.hand_estimation.device = device
776
+ return self
777
+
778
+ @torch.no_grad()
779
+ @torch.inference_mode()
780
+ @torch.autocast('cuda', enabled=False)
781
+ def forward(self, image):
782
+ if isinstance(image, Image.Image):
783
+ image = np.array(image)
784
+ elif isinstance(image, torch.Tensor):
785
+ image = image.detach().cpu().numpy()
786
+ elif isinstance(image, np.ndarray):
787
+ image = image.copy()
788
+ else:
789
+ raise TypeError(f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.')
790
+ image = image[:, :, ::-1]
791
+ candidate, subset = self.body_estimation(image)
792
+ canvas = np.zeros_like(image)
793
+ canvas = draw_bodypose(canvas, candidate, subset)
794
+ if self.use_hand:
795
+ hands_list = handDetect(candidate, subset, image)
796
+ all_hand_peaks = []
797
+ for x, y, w, is_left in hands_list:
798
+ peaks = self.hand_estimation(image[y:y + w, x:x + w, :])
799
+ peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0],
800
+ peaks[:, 0] + x)
801
+ peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1],
802
+ peaks[:, 1] + y)
803
+ all_hand_peaks.append(peaks)
804
+ canvas = draw_handpose(canvas, all_hand_peaks)
805
+ return canvas
806
+
807
+ @staticmethod
808
+ def get_config_template():
809
+ return dict_to_yaml('ANNOTATORS',
810
+ __class__.__name__,
811
+ OpenposeAnnotator.para_dict,
812
+ set_name=True)
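For reference, a small helper sketch for unpacking `Body()`'s outputs, following the comment above (`candidate` rows are `x, y, score, id`; each `subset` row stores 18 candidate indices or -1, then the total score and part count). The helper name is illustrative only:

```python
import numpy as np


def extract_people(candidate, subset):
    """Convert (candidate, subset) into a list of {part_index: (x, y, score)} dicts, one per person."""
    people = []
    for person in np.asarray(subset):
        keypoints = {}
        for part in range(18):
            idx = int(person[part])
            if idx == -1:
                continue  # this body part was not found for this person
            x, y, score = candidate[idx][:3]
            keypoints[part] = (float(x), float(y), float(score))
        people.append(keypoints)
    return people
```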
annotator/registry.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from scepter.modules.utils.config import Config
4
+ from scepter.modules.utils.registry import Registry, build_from_config
5
+
6
+
7
+ def build_annotator(cfg, registry, logger=None, *args, **kwargs):
8
+ """ After build model, load pretrained model if exists key `pretrain`.
9
+
10
+ pretrain (str, dict): Describes how to load pretrained model.
11
+ str: treat pretrain as the model path;
12
+ dict: should contain key `path`, plus the other parameters taken by load_pretrained();
13
+ """
14
+ if not isinstance(cfg, Config):
15
+ raise TypeError(f'cfg must be of type Config, got {type(cfg)}')
16
+ if cfg.have('PRETRAINED_MODEL'):
17
+ pretrain_cfg = cfg.PRETRAINED_MODEL
18
+ if pretrain_cfg is not None and not isinstance(pretrain_cfg, (str)):
19
+ raise TypeError('Pretrain parameter must be a string')
20
+ else:
21
+ pretrain_cfg = None
22
+
23
+ model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
24
+ if pretrain_cfg is not None:
25
+ if hasattr(model, 'load_pretrained_model'):
26
+ model.load_pretrained_model(pretrain_cfg)
27
+ return model
28
+
29
+
30
+ ANNOTATORS = Registry('ANNOTATORS', build_func=build_annotator)
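A minimal sketch of the intended registration/build flow (the `NoOpAnnotator` below is illustrative only and does not exist in this repo; it assumes `build_from_config` dispatches on the `NAME` key, as the configs elsewhere in this commit do):

```python
from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config


@ANNOTATORS.register_class()
class NoOpAnnotator:
    """Toy annotator that returns its input unchanged (illustration only)."""
    def __init__(self, cfg, logger=None):
        self.cfg = cfg

    def forward(self, image):
        return image


cfg = Config(cfg_dict={'NAME': 'NoOpAnnotator'}, load=False)
annotator = ANNOTATORS.build(cfg)  # goes through build_annotator() above
```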
annotator/utils.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
+ def resize_image(input_image, resolution):
8
+ H, W, C = input_image.shape
9
+ H = float(H)
10
+ W = float(W)
11
+ k = float(resolution) / min(H, W)
12
+ H *= k
13
+ W *= k
14
+ H = int(np.round(H / 64.0)) * 64
15
+ W = int(np.round(W / 64.0)) * 64
16
+ img = cv2.resize(
17
+ input_image, (W, H),
18
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
19
+ return img, k
20
+
21
+
22
+ def resize_image_ori(h, w, image, k):
23
+ img = cv2.resize(
24
+ image, (w, h),
25
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
26
+ return img
27
+
28
+
29
+ class AnnotatorProcessor():
30
+ canny_cfg = {
31
+ 'NAME': 'CannyAnnotator',
32
+ 'LOW_THRESHOLD': 100,
33
+ 'HIGH_THRESHOLD': 200,
34
+ 'INPUT_KEYS': ['img'],
35
+ 'OUTPUT_KEYS': ['canny']
36
+ }
37
+ hed_cfg = {
38
+ 'NAME': 'HedAnnotator',
39
+ 'PRETRAINED_MODEL':
40
+ 'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth',
41
+ 'INPUT_KEYS': ['img'],
42
+ 'OUTPUT_KEYS': ['hed']
43
+ }
44
+ openpose_cfg = {
45
+ 'NAME': 'OpenposeAnnotator',
46
+ 'BODY_MODEL_PATH':
47
+ 'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth',
48
+ 'HAND_MODEL_PATH':
49
+ 'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth',
50
+ 'INPUT_KEYS': ['img'],
51
+ 'OUTPUT_KEYS': ['openpose']
52
+ }
53
+ midas_cfg = {
54
+ 'NAME': 'MidasDetector',
55
+ 'PRETRAINED_MODEL':
56
+ 'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt',
57
+ 'INPUT_KEYS': ['img'],
58
+ 'OUTPUT_KEYS': ['depth']
59
+ }
60
+ mlsd_cfg = {
61
+ 'NAME': 'MLSDdetector',
62
+ 'PRETRAINED_MODEL':
63
+ 'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
64
+ 'INPUT_KEYS': ['img'],
65
+ 'OUTPUT_KEYS': ['mlsd']
66
+ }
67
+ color_cfg = {
68
+ 'NAME': 'ColorAnnotator',
69
+ 'RATIO': 64,
70
+ 'INPUT_KEYS': ['img'],
71
+ 'OUTPUT_KEYS': ['color']
72
+ }
73
+
74
+ anno_type_map = {
75
+ 'canny': canny_cfg,
76
+ 'hed': hed_cfg,
77
+ 'pose': openpose_cfg,
78
+ 'depth': midas_cfg,
79
+ 'mlsd': mlsd_cfg,
80
+ 'color': color_cfg
81
+ }
82
+
83
+ def __init__(self, anno_type):
84
+ from scepter.modules.annotator.registry import ANNOTATORS
85
+ from scepter.modules.utils.config import Config
86
+ from scepter.modules.utils.distribute import we
87
+
88
+ if isinstance(anno_type, str):
89
+ assert anno_type in self.anno_type_map.keys()
90
+ anno_type = [anno_type]
91
+ elif isinstance(anno_type, (list, tuple)):
92
+ assert all(tp in self.anno_type_map.keys() for tp in anno_type)
93
+ else:
94
+ raise Exception(f'Error anno_type: {anno_type}')
95
+
96
+ general_dict = {
97
+ 'NAME': 'GeneralAnnotator',
98
+ 'ANNOTATORS': [self.anno_type_map[tp] for tp in anno_type]
99
+ }
100
+ general_anno = Config(cfg_dict=general_dict, load=False)
101
+ self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id)
102
+
103
+ def run(self, image, anno_type=None):
104
+ output_image = self.general_ins({'img': image})
105
+ if anno_type is not None:
106
+ if isinstance(anno_type, str) and anno_type in output_image:
107
+ return output_image[anno_type]
108
+ else:
109
+ return {
110
+ tp: output_image[tp]
111
+ for tp in anno_type if tp in output_image
112
+ }
113
+ else:
114
+ return output_image
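Usage sketch for `AnnotatorProcessor` (the dummy image below is a placeholder; the `ms://...` checkpoints in the config dicts above are fetched through scepter's `FS` on first use):

```python
import numpy as np

# Request the canny and color annotators defined in the config dicts above.
processor = AnnotatorProcessor(anno_type=['canny', 'color'])

dummy = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)
outputs = processor.run(dummy)             # dict expected to contain 'canny' and 'color'
canny_map = processor.run(dummy, 'canny')  # single annotation map
```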
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from PIL import Image
4
+ from typing import List
5
+
6
+ from pipeline import prepare_white_image, MultiViewGenerator
7
+ from util import download_file, unzip_file
8
+
9
+ download_file("https://huggingface.co/aki-0421/character-360/resolve/main/v2.ckpt", "v2.ckpt")
10
+ download_file("https://huggingface.co/hbyang/Hi3D/resolve/main/ckpts.zip", "ckpts.zip")
11
+
12
+ unzip_file("ckpts.zip", ".")
13
+
14
+ multi_view_generator = MultiViewGenerator(checkpoint_path="v2.ckpt")
15
+
16
+ @spaces.GPU(duration=120)
17
+ def generate_images(input_image: Image.Image) -> List[Image.Image]:
18
+ white_image = prepare_white_image(input_image=input_image)
19
+
20
+ return multi_view_generator.infer(white_image=white_image)
21
+
22
+
23
+ with gr.Blocks() as demo:
24
+ gr.Markdown("# GPU-accelerated Image Processing")
25
+ with gr.Row():
26
+ input_image = gr.Image(label="Input Image", type="pil") # 入力はPIL形式
27
+ output_gallery = gr.Gallery(label="Output Images (25 Variations)").style(grid=(5, 5))
28
+ submit_button = gr.Button("Generate")
29
+
30
+ submit_button.click(generate_images, inputs=input_image, outputs=output_gallery)
31
+
32
+ demo.launch()
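A hedged local smoke test for the app above (file name and output paths are placeholders; it assumes the `spaces.GPU` decorator is a no-op outside of Hugging Face Spaces, and note that importing `app` triggers the checkpoint downloads):

```python
from PIL import Image

from app import generate_images  # importing app downloads v2.ckpt and ckpts.zip

sample = Image.open("sample_input.png")  # placeholder input path
views = generate_images(sample)
for idx, view in enumerate(views):
    view.save(f"view_{idx:02d}.png")
```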
dataset/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ benchmarks/benchmarking_Random_grayscale.png
dataset/opencv_transforms/__init__.py ADDED
File without changes
dataset/opencv_transforms/functional.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from PIL import Image, ImageEnhance, ImageOps
6
+
7
+ try:
8
+ import accimage
9
+ except ImportError:
10
+ accimage = None
11
+ import collections
12
+ import numbers
13
+ import types
14
+ import warnings
15
+
16
+ import cv2
17
+ import numpy as np
18
+ from PIL import Image
19
+
20
+ _cv2_pad_to_str = {
21
+ 'constant': cv2.BORDER_CONSTANT,
22
+ 'edge': cv2.BORDER_REPLICATE,
23
+ 'reflect': cv2.BORDER_REFLECT_101,
24
+ 'symmetric': cv2.BORDER_REFLECT
25
+ }
26
+ _cv2_interpolation_to_str = {
27
+ 'nearest': cv2.INTER_NEAREST,
28
+ 'bilinear': cv2.INTER_LINEAR,
29
+ 'area': cv2.INTER_AREA,
30
+ 'bicubic': cv2.INTER_CUBIC,
31
+ 'lanczos': cv2.INTER_LANCZOS4
32
+ }
33
+ _cv2_interpolation_from_str = {v: k for k, v in _cv2_interpolation_to_str.items()}
34
+
35
+
36
+ def _is_pil_image(img):
37
+ if accimage is not None:
38
+ return isinstance(img, (Image.Image, accimage.Image))
39
+ else:
40
+ return isinstance(img, Image.Image)
41
+
42
+
43
+ def _is_tensor_image(img):
44
+ return torch.is_tensor(img) and img.ndimension() == 3
45
+
46
+
47
+ def _is_numpy_image(img):
48
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
49
+
50
+
51
+ def to_tensor(pic):
52
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
53
+ See ``ToTensor`` for more details.
54
+ Args:
55
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
56
+ Returns:
57
+ Tensor: Converted image.
58
+ """
59
+ if not (_is_numpy_image(pic)):
60
+ raise TypeError('pic should be ndarray. Got {}'.format(type(pic)))
61
+
62
+ # handle numpy array
63
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
64
+ # backward compatibility
65
+ if isinstance(img, torch.ByteTensor) or img.dtype == torch.uint8:
66
+ return img.float().div(255)
67
+ else:
68
+ return img
69
+
70
+
71
+ def normalize(tensor, mean, std):
72
+ """Normalize a tensor image with mean and standard deviation.
73
+ .. note::
74
+ This transform acts in-place, i.e., it mutates the input tensor.
75
+ See :class:`~torchvision.transforms.Normalize` for more details.
76
+ Args:
77
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
78
+ mean (sequence): Sequence of means for each channel.
79
+ std (sequence): Sequence of standard deviations for each channely.
80
+ Returns:
81
+ Tensor: Normalized Tensor image.
82
+ """
83
+ if not _is_tensor_image(tensor):
84
+ raise TypeError('tensor is not a torch image.')
85
+
86
+ # This is faster than using broadcasting, don't change without benchmarking
87
+ for t, m, s in zip(tensor, mean, std):
88
+ t.sub_(m).div_(s)
89
+ return tensor
90
+
91
+
92
+ def resize(img, size, interpolation=cv2.INTER_LINEAR):
93
+ r"""Resize the input numpy ndarray to the given size.
94
+ Args:
95
+ img (numpy ndarray): Image to be resized.
96
+ size (sequence or int): Desired output size. If size is a sequence like
97
+ (h, w), the output size will be matched to this. If size is an int,
98
+ the smaller edge of the image will be matched to this number maintaining
99
+ the aspect ratio. i.e, if height > width, then image will be rescaled to
100
+ :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
101
+ interpolation (int, optional): Desired interpolation. Default is
102
+ ``cv2.INTER_LINEAR``
103
+ Returns:
104
+ numpy ndarray: Resized image.
105
+ """
106
+ if not _is_numpy_image(img):
107
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
108
+ if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
109
+ raise TypeError('Got inappropriate size arg: {}'.format(size))
110
+ h, w = img.shape[0], img.shape[1]
111
+
112
+ if isinstance(size, int):
113
+ if (w <= h and w == size) or (h <= w and h == size):
114
+ return img
115
+ if w < h:
116
+ ow = size
117
+ oh = int(size * h / w)
118
+ else:
119
+ oh = size
120
+ ow = int(size * w / h)
121
+ else:
122
+ ow, oh = size[1], size[0]
123
+ output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
124
+ if img.shape[2] == 1:
125
+ return output[:, :, np.newaxis]
126
+ else:
127
+ return output
128
+
129
+
130
+ def scale(*args, **kwargs):
131
+ warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
132
+ return resize(*args, **kwargs)
133
+
134
+
135
+ def pad(img, padding, fill=0, padding_mode='constant'):
136
+ r"""Pad the given numpy ndarray on all sides with specified padding mode and fill value.
137
+ Args:
138
+ img (numpy ndarray): image to be padded.
139
+ padding (int or tuple): Padding on each border. If a single int is provided this
140
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
141
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
142
+ this is the padding for the left, top, right and bottom borders
143
+ respectively.
144
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
145
+ length 3, it is used to fill R, G, B channels respectively.
146
+ This value is only used when the padding_mode is constant
147
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
148
+ - constant: pads with a constant value, this value is specified with fill
149
+ - edge: pads with the last value on the edge of the image
150
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
151
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
152
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
153
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
154
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
155
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
156
+ Returns:
157
+ Numpy image: padded image.
158
+ """
159
+ if not _is_numpy_image(img):
160
+ raise TypeError('img should be numpy ndarray. Got {}'.format(type(img)))
161
+ if not isinstance(padding, (numbers.Number, tuple, list)):
162
+ raise TypeError('Got inappropriate padding arg')
163
+ if not isinstance(fill, (numbers.Number, str, tuple)):
164
+ raise TypeError('Got inappropriate fill arg')
165
+ if not isinstance(padding_mode, str):
166
+ raise TypeError('Got inappropriate padding_mode arg')
167
+ if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
168
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
169
+ "{} element tuple".format(len(padding)))
170
+
171
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
172
+ 'Padding mode should be either constant, edge, reflect or symmetric'
173
+
174
+ if isinstance(padding, int):
175
+ pad_left = pad_right = pad_top = pad_bottom = padding
176
+ if isinstance(padding, collections.abc.Sequence) and len(padding) == 2:
177
+ pad_left = pad_right = padding[0]
178
+ pad_top = pad_bottom = padding[1]
179
+ if isinstance(padding, collections.abc.Sequence) and len(padding) == 4:
180
+ pad_left = padding[0]
181
+ pad_top = padding[1]
182
+ pad_right = padding[2]
183
+ pad_bottom = padding[3]
184
+ if img.shape[2] == 1:
185
+ return cv2.copyMakeBorder(img,
186
+ top=pad_top,
187
+ bottom=pad_bottom,
188
+ left=pad_left,
189
+ right=pad_right,
190
+ borderType=_cv2_pad_to_str[padding_mode],
191
+ value=fill)[:, :, np.newaxis]
192
+ else:
193
+ return cv2.copyMakeBorder(img,
194
+ top=pad_top,
195
+ bottom=pad_bottom,
196
+ left=pad_left,
197
+ right=pad_right,
198
+ borderType=_cv2_pad_to_str[padding_mode],
199
+ value=fill)
200
+
201
+
202
+ def crop(img, i, j, h, w):
203
+ """Crop the given PIL Image.
204
+ Args:
205
+ img (numpy ndarray): Image to be cropped.
206
+ i: Upper pixel coordinate.
207
+ j: Left pixel coordinate.
208
+ h: Height of the cropped image.
209
+ w: Width of the cropped image.
210
+ Returns:
211
+ numpy ndarray: Cropped image.
212
+ """
213
+ if not _is_numpy_image(img):
214
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
215
+
216
+ return img[i:i + h, j:j + w, :]
217
+
218
+
219
+ def center_crop(img, output_size):
220
+ if isinstance(output_size, numbers.Number):
221
+ output_size = (int(output_size), int(output_size))
222
+ h, w = img.shape[0:2]
223
+ th, tw = output_size
224
+ i = int(round((h - th) / 2.))
225
+ j = int(round((w - tw) / 2.))
226
+ return crop(img, i, j, th, tw)
227
+
228
+
229
+ def resized_crop(img, i, j, h, w, size, interpolation=cv2.INTER_LINEAR):
230
+ """Crop the given numpy ndarray and resize it to desired size.
231
+ Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
232
+ Args:
233
+ img (numpy ndarray): Image to be cropped.
234
+ i: Upper pixel coordinate.
235
+ j: Left pixel coordinate.
236
+ h: Height of the cropped image.
237
+ w: Width of the cropped image.
238
+ size (sequence or int): Desired output size. Same semantics as ``scale``.
239
+ interpolation (int, optional): Desired interpolation. Default is
240
+ ``cv2.INTER_LINEAR``.
241
+ Returns:
242
+ numpy ndarray: Cropped and resized image.
243
+ """
244
+ assert _is_numpy_image(img), 'img should be numpy image'
245
+ img = crop(img, i, j, h, w)
246
+ img = resize(img, size, interpolation=interpolation)
247
+ return img
248
+
249
+
250
+ def hflip(img):
251
+ """Horizontally flip the given numpy ndarray.
252
+ Args:
253
+ img (numpy ndarray): image to be flipped.
254
+ Returns:
255
+ numpy ndarray: Horizontally flipped image.
256
+ """
257
+ if not _is_numpy_image(img):
258
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
259
+ # img[:,::-1] is much faster, but doesn't work with torch.from_numpy()!
260
+ if img.shape[2] == 1:
261
+ return cv2.flip(img, 1)[:, :, np.newaxis]
262
+ else:
263
+ return cv2.flip(img, 1)
264
+
265
+
266
+ def vflip(img):
267
+ """Vertically flip the given numpy ndarray.
268
+ Args:
269
+ img (numpy ndarray): Image to be flipped.
270
+ Returns:
271
+ numpy ndarray: Vertically flipped image.
272
+ """
273
+ if not _is_numpy_image(img):
274
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
275
+ if img.shape[2] == 1:
276
+ return cv2.flip(img, 0)[:, :, np.newaxis]
277
+ else:
278
+ return cv2.flip(img, 0)
279
+ # img[::-1] is much faster, but doesn't work with torch.from_numpy()!
280
+
281
+
282
+ def five_crop(img, size):
283
+ """Crop the given numpy ndarray into four corners and the central crop.
284
+ .. Note::
285
+ This transform returns a tuple of images and there may be a
286
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
287
+ Args:
288
+ size (sequence or int): Desired output size of the crop. If size is an
289
+ int instead of sequence like (h, w), a square crop (size, size) is
290
+ made.
291
+ Returns:
292
+ tuple: tuple (tl, tr, bl, br, center)
293
+ Corresponding top left, top right, bottom left, bottom right and center crop.
294
+ """
295
+ if isinstance(size, numbers.Number):
296
+ size = (int(size), int(size))
297
+ else:
298
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
299
+
300
+ h, w = img.shape[0:2]
301
+ crop_h, crop_w = size
302
+ if crop_w > w or crop_h > h:
303
+ raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))
304
+ tl = crop(img, 0, 0, crop_h, crop_w)
305
+ tr = crop(img, 0, w - crop_w, crop_h, crop_w)
306
+ bl = crop(img, h - crop_h, 0, crop_h, crop_w)
307
+ br = crop(img, h - crop_h, w - crop_w, crop_h, crop_w)
308
+ center = center_crop(img, (crop_h, crop_w))
309
+ return tl, tr, bl, br, center
310
+
311
+
312
+ def ten_crop(img, size, vertical_flip=False):
313
+ r"""Crop the given numpy ndarray into four corners and the central crop plus the
314
+ flipped version of these (horizontal flipping is used by default).
315
+ .. Note::
316
+ This transform returns a tuple of images and there may be a
317
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
318
+ Args:
319
+ size (sequence or int): Desired output size of the crop. If size is an
320
+ int instead of sequence like (h, w), a square crop (size, size) is
321
+ made.
322
+ vertical_flip (bool): Use vertical flipping instead of horizontal
323
+ Returns:
324
+ tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
325
+ Corresponding top left, top right, bottom left, bottom right and center crop
326
+ and same for the flipped image.
327
+ """
328
+ if isinstance(size, numbers.Number):
329
+ size = (int(size), int(size))
330
+ else:
331
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
332
+
333
+ first_five = five_crop(img, size)
334
+
335
+ if vertical_flip:
336
+ img = vflip(img)
337
+ else:
338
+ img = hflip(img)
339
+
340
+ second_five = five_crop(img, size)
341
+ return first_five + second_five
342
+
343
+
344
+ def adjust_brightness(img, brightness_factor):
345
+ """Adjust brightness of an Image.
346
+ Args:
347
+ img (numpy ndarray): numpy ndarray to be adjusted.
348
+ brightness_factor (float): How much to adjust the brightness. Can be
349
+ any non negative number. 0 gives a black image, 1 gives the
350
+ original image while 2 increases the brightness by a factor of 2.
351
+ Returns:
352
+ numpy ndarray: Brightness adjusted image.
353
+ """
354
+ if not _is_numpy_image(img):
355
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
356
+ table = np.array([i * brightness_factor for i in range(0, 256)]).clip(0, 255).astype('uint8')
357
+ # same thing but a bit slower
358
+ # cv2.convertScaleAbs(img, alpha=brightness_factor, beta=0)
359
+ if img.shape[2] == 1:
360
+ return cv2.LUT(img, table)[:, :, np.newaxis]
361
+ else:
362
+ return cv2.LUT(img, table)
363
+
364
+
365
+ def adjust_contrast(img, contrast_factor):
366
+ """Adjust contrast of an mage.
367
+ Args:
368
+ img (numpy ndarray): numpy ndarray to be adjusted.
369
+ contrast_factor (float): How much to adjust the contrast. Can be any
370
+ non negative number. 0 gives a solid gray image, 1 gives the
371
+ original image while 2 increases the contrast by a factor of 2.
372
+ Returns:
373
+ numpy ndarray: Contrast adjusted image.
374
+ """
375
+ # much faster to use the LUT construction than anything else I've tried
376
+ # it's because you have to change dtypes multiple times
377
+ if not _is_numpy_image(img):
378
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
379
+
380
+ # input is RGB
381
+ if img.ndim > 2 and img.shape[2] == 3:
382
+ mean_value = round(cv2.mean(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))[0])
383
+ elif img.ndim == 2:
384
+ # grayscale input
385
+ mean_value = round(cv2.mean(img)[0])
386
+ else:
387
+ # multichannel input
388
+ mean_value = round(np.mean(img))
389
+
390
+ table = np.array([(i - mean_value) * contrast_factor + mean_value for i in range(0, 256)]).clip(0,
391
+ 255).astype('uint8')
392
+ # enhancer = ImageEnhance.Contrast(img)
393
+ # img = enhancer.enhance(contrast_factor)
394
+ if img.ndim == 2 or img.shape[2] == 1:
395
+ return cv2.LUT(img, table)[:, :, np.newaxis]
396
+ else:
397
+ return cv2.LUT(img, table)
398
+
399
+
400
+ def adjust_saturation(img, saturation_factor):
401
+ """Adjust color saturation of an image.
402
+ Args:
403
+ img (numpy ndarray): numpy ndarray to be adjusted.
404
+ saturation_factor (float): How much to adjust the saturation. 0 will
405
+ give a black and white image, 1 will give the original image while
406
+ 2 will enhance the saturation by a factor of 2.
407
+ Returns:
408
+ numpy ndarray: Saturation adjusted image.
409
+ """
410
+ # ~10ms slower than PIL!
411
+ if not _is_numpy_image(img):
412
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
413
+ img = Image.fromarray(img)
414
+ enhancer = ImageEnhance.Color(img)
415
+ img = enhancer.enhance(saturation_factor)
416
+ return np.array(img)
417
+
418
+
419
+ def adjust_hue(img, hue_factor):
420
+ """Adjust hue of an image.
421
+ The image hue is adjusted by converting the image to HSV and
422
+ cyclically shifting the intensities in the hue channel (H).
423
+ The image is then converted back to original image mode.
424
+ `hue_factor` is the amount of shift in H channel and must be in the
425
+ interval `[-0.5, 0.5]`.
426
+ See `Hue`_ for more details.
427
+ .. _Hue: https://en.wikipedia.org/wiki/Hue
428
+ Args:
429
+ img (numpy ndarray): numpy ndarray to be adjusted.
430
+ hue_factor (float): How much to shift the hue channel. Should be in
431
+ [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
432
+ HSV space in positive and negative direction respectively.
433
+ 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
434
+ with complementary colors while 0 gives the original image.
435
+ Returns:
436
+ numpy ndarray: Hue adjusted image.
437
+ """
438
+ # After testing, found that OpenCV calculates the Hue in a call to
439
+ # cv2.cvtColor(..., cv2.COLOR_BGR2HSV) differently from PIL
440
+
441
+ # This function takes 160ms! should be avoided
442
+ if not (-0.5 <= hue_factor <= 0.5):
443
+ raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
444
+ if not _is_numpy_image(img):
445
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
446
+ img = Image.fromarray(img)
447
+ input_mode = img.mode
448
+ if input_mode in {'L', '1', 'I', 'F'}:
449
+ return np.array(img)
450
+
451
+ h, s, v = img.convert('HSV').split()
452
+
453
+ np_h = np.array(h, dtype=np.uint8)
454
+ # uint8 addition take cares of rotation across boundaries
455
+ with np.errstate(over='ignore'):
456
+ np_h += np.uint8(hue_factor * 255)
457
+ h = Image.fromarray(np_h, 'L')
458
+
459
+ img = Image.merge('HSV', (h, s, v)).convert(input_mode)
460
+ return np.array(img)
461
+
462
+
463
+ def adjust_gamma(img, gamma, gain=1):
464
+ r"""Perform gamma correction on an image.
465
+ Also known as Power Law Transform. Intensities in RGB mode are adjusted
466
+ based on the following equation:
467
+ .. math::
468
+ I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
469
+ See `Gamma Correction`_ for more details.
470
+ .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
471
+ Args:
472
+ img (numpy ndarray): numpy ndarray to be adjusted.
473
+ gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
474
+ gamma larger than 1 makes the shadows darker,
475
+ while gamma smaller than 1 makes dark regions lighter.
476
+ gain (float): The constant multiplier.
477
+ """
478
+ if not _is_numpy_image(img):
479
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
480
+
481
+ if gamma < 0:
482
+ raise ValueError('Gamma should be a non-negative real number')
483
+ # from here
484
+ # https://stackoverflow.com/questions/33322488/how-to-change-image-illumination-in-opencv-python/41061351
485
+ table = np.array([((i / 255.0)**gamma) * 255 * gain for i in np.arange(0, 256)]).astype('uint8')
486
+ if img.shape[2] == 1:
487
+ return cv2.LUT(img, table)[:, :, np.newaxis]
488
+ else:
489
+ return cv2.LUT(img, table)
490
+
491
+
492
+ def rotate(img, angle, resample=False, expand=False, center=None):
493
+ """Rotate the image by angle.
494
+ Args:
495
+ img (numpy ndarray): numpy ndarray to be rotated.
496
+ angle (float or int): Rotation angle in degrees, counter-clockwise.
497
+ resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
498
+ An optional resampling filter. See `filters`_ for more information.
499
+ If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
500
+ expand (bool, optional): Optional expansion flag.
501
+ If true, expands the output image to make it large enough to hold the entire rotated image.
502
+ If false or omitted, make the output image the same size as the input image.
503
+ Note that the expand flag assumes rotation around the center and no translation.
504
+ center (2-tuple, optional): Optional center of rotation.
505
+ Origin is the upper left corner.
506
+ Default is the center of the image.
507
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
508
+ """
509
+ if not _is_numpy_image(img):
510
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
511
+ rows, cols = img.shape[0:2]
512
+ if center is None:
513
+ center = (cols / 2, rows / 2)
514
+ M = cv2.getRotationMatrix2D(center, angle, 1)
515
+ if img.shape[2] == 1:
516
+ return cv2.warpAffine(img, M, (cols, rows))[:, :, np.newaxis]
517
+ else:
518
+ return cv2.warpAffine(img, M, (cols, rows))
519
+
520
+
521
+ def _get_affine_matrix(center, angle, translate, scale, shear):
522
+ # Helper method to compute matrix for affine transformation
523
+ # We need compute affine transformation matrix: M = T * C * RSS * C^-1
524
+ # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
525
+ # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
526
+ # RSS is rotation with scale and shear matrix
527
+ # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0]
528
+ # [ sin(a)*scale cos(a + shear)*scale 0]
529
+ # [ 0 0 1]
530
+
531
+ angle = math.radians(angle)
532
+ shear = math.radians(shear)
533
+ # scale = 1.0 / scale
534
+
535
+ T = np.array([[1, 0, translate[0]], [0, 1, translate[1]], [0, 0, 1]])
536
+ C = np.array([[1, 0, center[0]], [0, 1, center[1]], [0, 0, 1]])
537
+ RSS = np.array([[math.cos(angle) * scale, -math.sin(angle + shear) * scale, 0],
538
+ [math.sin(angle) * scale, math.cos(angle + shear) * scale, 0], [0, 0, 1]])
539
+ matrix = T @ C @ RSS @ np.linalg.inv(C)
540
+
541
+ return matrix[:2, :]
542
+
543
+
544
+ def affine(img, angle, translate, scale, shear, interpolation=cv2.INTER_LINEAR, mode=cv2.BORDER_CONSTANT, fillcolor=0):
545
+ """Apply affine transformation on the image keeping image center invariant
546
+ Args:
547
+ img (numpy ndarray): numpy ndarray to be transformed.
548
+ angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
549
+ translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
550
+ scale (float): overall scale
551
+ shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
552
+ interpolation (``cv2.INTER_NEAREST` or ``cv2.INTER_LINEAR`` or ``cv2.INTER_AREA``, ``cv2.INTER_CUBIC``):
553
+ An optional resampling filter.
554
+ See `filters`_ for more information.
555
+ If omitted, it is set to ``cv2.INTER_LINEAR``, for bilinear interpolation.
556
+ mode (``cv2.BORDER_CONSTANT`` or ``cv2.BORDER_REPLICATE`` or ``cv2.BORDER_REFLECT`` or ``cv2.BORDER_REFLECT_101``)
557
+ Method for filling in border regions.
558
+ Defaults to cv2.BORDER_CONSTANT, meaning areas outside the image are filled with a value (val, default 0)
559
+ fillcolor (int): Optional fill color for the area outside the transform in the output image. Default: 0
560
+ """
561
+ if not _is_numpy_image(img):
562
+ raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
563
+
564
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
565
+ "Argument translate should be a list or tuple of length 2"
566
+
567
+ assert scale > 0.0, "Argument scale should be positive"
568
+
569
+ output_size = img.shape[0:2]
570
+ center = (img.shape[1] * 0.5 + 0.5, img.shape[0] * 0.5 + 0.5)
571
+ matrix = _get_affine_matrix(center, angle, translate, scale, shear)
572
+
573
+ if img.shape[2] == 1:
574
+ return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode,
575
+ borderValue=fillcolor)[:, :, np.newaxis]
576
+ else:
577
+ return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode, borderValue=fillcolor)
578
+
579
+
580
+ def to_grayscale(img, num_output_channels: int = 1):
581
+ """Convert image to grayscale version of image.
582
+ Args:
583
+ img (numpy ndarray): Image to be converted to grayscale.
584
+ num_output_channels: int
585
+ if 1 : returned image is single channel
586
+ if 3 : returned image is 3 channel with r = g = b
587
+ Returns:
588
+ numpy ndarray: Grayscale version of the image.
589
+ """
590
+ if not _is_numpy_image(img):
591
+ raise TypeError('img should be numpy ndarray. Got {}'.format(type(img)))
592
+
593
+ if num_output_channels == 1:
594
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
595
+ elif num_output_channels == 3:
596
+ # much faster than doing cvtColor to go back to gray
597
+ img = np.broadcast_to(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis], img.shape)
598
+ return img
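The functional API above mirrors `torchvision.transforms.functional` but operates on numpy ndarrays via OpenCV. Below is a minimal usage sketch, not part of the commit: it assumes the module imports as `opencv_transforms.functional` and that `resize` (defined earlier in this file, outside this excerpt) accepts an `(h, w)` tuple; the input image is synthetic.

import numpy as np

from opencv_transforms import functional as F

# A fake H x W x C uint8 image, the layout every function above expects.
img = np.random.randint(0, 256, size=(64, 80, 3), dtype=np.uint8)

# 'reflect' mirrors without repeating the edge pixel; 'symmetric' repeats it.
reflected = F.pad(img, padding=4, padding_mode='reflect')
symmetric = F.pad(img, padding=4, padding_mode='symmetric')
assert reflected.shape == symmetric.shape == (72, 88, 3)

# Crop a 32x32 window at (row=10, col=20), then resize the patch to 48x48.
patch = F.resized_crop(img, i=10, j=20, h=32, w=32, size=(48, 48))
assert patch.shape[:2] == (48, 48)

# Gamma > 1 darkens shadows; gamma < 1 lifts dark regions (see adjust_gamma).
darker = F.adjust_gamma(img, gamma=2.0)
lighter = F.adjust_gamma(img, gamma=0.5)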
dataset/opencv_transforms/transforms.py ADDED
@@ -0,0 +1,1044 @@
1
+ from __future__ import division
2
+
3
+ import collections
4
+ import math
5
+ import numbers
6
+ import random
7
+ import types
8
+ import warnings
9
+
10
+ # from PIL import Image, ImageOps, ImageEnhance
11
+ try:
12
+ import accimage
13
+ except ImportError:
14
+ accimage = None
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+
19
+ from . import functional as F
20
+
21
+ __all__ = [
22
+ "Compose", "ToTensor", "Normalize", "Resize", "Scale",
23
+ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice",
24
+ "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip",
25
+ "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
26
+ "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine",
27
+ "Grayscale", "RandomGrayscale"
28
+ ]
29
+
30
+ _cv2_pad_to_str = {
31
+ 'constant': cv2.BORDER_CONSTANT,
32
+ 'edge': cv2.BORDER_REPLICATE,
33
+ 'reflect': cv2.BORDER_REFLECT_101,
34
+ 'symmetric': cv2.BORDER_REFLECT
35
+ }
36
+ _cv2_interpolation_to_str = {
37
+ 'nearest': cv2.INTER_NEAREST,
38
+ 'bilinear': cv2.INTER_LINEAR,
39
+ 'area': cv2.INTER_AREA,
40
+ 'bicubic': cv2.INTER_CUBIC,
41
+ 'lanczos': cv2.INTER_LANCZOS4
42
+ }
43
+ _cv2_interpolation_from_str = {
44
+ v: k
45
+ for k, v in _cv2_interpolation_to_str.items()
46
+ }
47
+
48
+
49
+ class Compose(object):
50
+ """Composes several transforms together.
51
+ Args:
52
+ transforms (list of ``Transform`` objects): list of transforms to compose.
53
+ Example:
54
+ >>> transforms.Compose([
55
+ >>> transforms.CenterCrop(10),
56
+ >>> transforms.ToTensor(),
57
+ >>> ])
58
+ """
59
+ def __init__(self, transforms):
60
+ self.transforms = transforms
61
+
62
+ def __call__(self, img):
63
+ for t in self.transforms:
64
+ img = t(img)
65
+ return img
66
+
67
+ def __repr__(self):
68
+ format_string = self.__class__.__name__ + '('
69
+ for t in self.transforms:
70
+ format_string += '\n'
71
+ format_string += ' {0}'.format(t)
72
+ format_string += '\n)'
73
+ return format_string
74
+
75
+
76
+ class ToTensor(object):
77
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
78
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
79
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
80
+ """
81
+ def __call__(self, pic):
82
+ """
83
+ Args:
84
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
85
+ Returns:
86
+ Tensor: Converted image.
87
+ """
88
+ return F.to_tensor(pic)
89
+
90
+ def __repr__(self):
91
+ return self.__class__.__name__ + '()'
92
+
93
+
94
+ class Normalize(object):
95
+ """Normalize a tensor image with mean and standard deviation.
96
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
97
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
98
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
99
+ .. note::
100
+ This transform acts in-place, i.e., it mutates the input tensor.
101
+ Args:
102
+ mean (sequence): Sequence of means for each channel.
103
+ std (sequence): Sequence of standard deviations for each channel.
104
+ """
105
+ def __init__(self, mean, std):
106
+ self.mean = mean
107
+ self.std = std
108
+
109
+ def __call__(self, tensor):
110
+ """
111
+ Args:
112
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
113
+ Returns:
114
+ Tensor: Normalized Tensor image.
115
+ """
116
+ return F.normalize(tensor, self.mean, self.std)
117
+
118
+ def __repr__(self):
119
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(
120
+ self.mean, self.std)
121
+
122
+
123
+ class Resize(object):
124
+ """Resize the input numpy ndarray to the given size.
125
+ Args:
126
+ size (sequence or int): Desired output size. If size is a sequence like
127
+ (h, w), output size will be matched to this. If size is an int,
128
+ smaller edge of the image will be matched to this number.
129
+ i.e, if height > width, then image will be rescaled to
130
+ (size * height / width, size)
131
+ interpolation (int, optional): Desired interpolation. Default is
132
+ ``cv2.INTER_CUBIC``, bicubic interpolation
133
+ """
134
+
135
+ def __init__(self, size, interpolation=cv2.INTER_LINEAR):
136
+ # assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
137
+ if isinstance(size, int):
138
+ self.size = size
139
+ elif isinstance(size, collections.abc.Iterable) and len(size) == 2:
140
+ if type(size) == list:
141
+ size = tuple(size)
142
+ self.size = size
143
+ else:
144
+ raise ValueError('Unknown inputs for size: {}'.format(size))
145
+ self.interpolation = interpolation
146
+
147
+ def __call__(self, img):
148
+ """
149
+ Args:
150
+ img (numpy ndarray): Image to be scaled.
151
+ Returns:
152
+ numpy ndarray: Rescaled image.
153
+ """
154
+ return F.resize(img, self.size, self.interpolation)
155
+
156
+ def __repr__(self):
157
+ interpolate_str = _cv2_interpolation_from_str[self.interpolation]
158
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
159
+ self.size, interpolate_str)
160
+
161
+
162
+ class Scale(Resize):
163
+ """
164
+ Note: This transform is deprecated in favor of Resize.
165
+ """
166
+ def __init__(self, *args, **kwargs):
167
+ warnings.warn(
168
+ "The use of the transforms.Scale transform is deprecated, " +
169
+ "please use transforms.Resize instead.")
170
+ super(Scale, self).__init__(*args, **kwargs)
171
+
172
+
173
+ class CenterCrop(object):
174
+ """Crops the given numpy ndarray at the center.
175
+ Args:
176
+ size (sequence or int): Desired output size of the crop. If size is an
177
+ int instead of sequence like (h, w), a square crop (size, size) is
178
+ made.
179
+ """
180
+ def __init__(self, size):
181
+ if isinstance(size, numbers.Number):
182
+ self.size = (int(size), int(size))
183
+ else:
184
+ self.size = size
185
+
186
+ def __call__(self, img):
187
+ """
188
+ Args:
189
+ img (numpy ndarray): Image to be cropped.
190
+ Returns:
191
+ numpy ndarray: Cropped image.
192
+ """
193
+ return F.center_crop(img, self.size)
194
+
195
+ def __repr__(self):
196
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
197
+
198
+
199
+ class Pad(object):
200
+ """Pad the given numpy ndarray on all sides with the given "pad" value.
201
+ Args:
202
+ padding (int or tuple): Padding on each border. If a single int is provided this
203
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
204
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
205
+ this is the padding for the left, top, right and bottom borders
206
+ respectively.
207
+ fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
208
+ length 3, it is used to fill R, G, B channels respectively.
209
+ This value is only used when the padding_mode is constant
210
+ padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
211
+ Default is constant.
212
+ - constant: pads with a constant value, this value is specified with fill
213
+ - edge: pads with the last value at the edge of the image
214
+ - reflect: pads with reflection of image without repeating the last value on the edge
215
+ For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
216
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
217
+ - symmetric: pads with reflection of image repeating the last value on the edge
218
+ For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
219
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
220
+ """
221
+ def __init__(self, padding, fill=0, padding_mode='constant'):
222
+ assert isinstance(padding, (numbers.Number, tuple, list))
223
+ assert isinstance(fill, (numbers.Number, str, tuple))
224
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
225
+ if isinstance(padding,
226
+ collections.Sequence) and len(padding) not in [2, 4]:
227
+ raise ValueError(
228
+ "Padding must be an int or a 2, or 4 element tuple, not a " +
229
+ "{} element tuple".format(len(padding)))
230
+
231
+ self.padding = padding
232
+ self.fill = fill
233
+ self.padding_mode = padding_mode
234
+
235
+ def __call__(self, img):
236
+ """
237
+ Args:
238
+ img (numpy ndarray): Image to be padded.
239
+ Returns:
240
+ numpy ndarray: Padded image.
241
+ """
242
+ return F.pad(img, self.padding, self.fill, self.padding_mode)
243
+
244
+ def __repr__(self):
245
+ return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
246
+ format(self.padding, self.fill, self.padding_mode)
247
+
248
+
249
+ class Lambda(object):
250
+ """Apply a user-defined lambda as a transform.
251
+ Args:
252
+ lambd (function): Lambda/function to be used for transform.
253
+ """
254
+ def __init__(self, lambd):
255
+ assert isinstance(lambd, types.LambdaType)
256
+ self.lambd = lambd
257
+
258
+ def __call__(self, img):
259
+ return self.lambd(img)
260
+
261
+ def __repr__(self):
262
+ return self.__class__.__name__ + '()'
263
+
264
+
265
+ class RandomTransforms(object):
266
+ """Base class for a list of transformations with randomness
267
+ Args:
268
+ transforms (list or tuple): list of transformations
269
+ """
270
+ def __init__(self, transforms):
271
+ assert isinstance(transforms, (list, tuple))
272
+ self.transforms = transforms
273
+
274
+ def __call__(self, *args, **kwargs):
275
+ raise NotImplementedError()
276
+
277
+ def __repr__(self):
278
+ format_string = self.__class__.__name__ + '('
279
+ for t in self.transforms:
280
+ format_string += '\n'
281
+ format_string += ' {0}'.format(t)
282
+ format_string += '\n)'
283
+ return format_string
284
+
285
+
286
+ class RandomApply(RandomTransforms):
287
+ """Apply randomly a list of transformations with a given probability
288
+ Args:
289
+ transforms (list or tuple): list of transformations
290
+ p (float): probability
291
+ """
292
+ def __init__(self, transforms, p=0.5):
293
+ super(RandomApply, self).__init__(transforms)
294
+ self.p = p
295
+
296
+ def __call__(self, img):
297
+ if self.p < random.random():
298
+ return img
299
+ for t in self.transforms:
300
+ img = t(img)
301
+ return img
302
+
303
+ def __repr__(self):
304
+ format_string = self.__class__.__name__ + '('
305
+ format_string += '\n p={}'.format(self.p)
306
+ for t in self.transforms:
307
+ format_string += '\n'
308
+ format_string += ' {0}'.format(t)
309
+ format_string += '\n)'
310
+ return format_string
311
+
312
+
313
+ class RandomOrder(RandomTransforms):
314
+ """Apply a list of transformations in a random order
315
+ """
316
+ def __call__(self, img):
317
+ order = list(range(len(self.transforms)))
318
+ random.shuffle(order)
319
+ for i in order:
320
+ img = self.transforms[i](img)
321
+ return img
322
+
323
+
324
+ class RandomChoice(RandomTransforms):
325
+ """Apply single transformation randomly picked from a list
326
+ """
327
+ def __call__(self, img):
328
+ t = random.choice(self.transforms)
329
+ return t(img)
330
+
331
+
332
+ class RandomCrop(object):
333
+ """Crop the given numpy ndarray at a random location.
334
+ Args:
335
+ size (sequence or int): Desired output size of the crop. If size is an
336
+ int instead of sequence like (h, w), a square crop (size, size) is
337
+ made.
338
+ padding (int or sequence, optional): Optional padding on each border
339
+ of the image. Default is None, i.e no padding. If a sequence of length
340
+ 4 is provided, it is used to pad left, top, right, bottom borders
341
+ respectively. If a sequence of length 2 is provided, it is used to
342
+ pad left/right, top/bottom borders, respectively.
343
+ pad_if_needed (boolean): It will pad the image if smaller than the
344
+ desired size to avoid raising an exception.
345
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
346
+ length 3, it is used to fill R, G, B channels respectively.
347
+ This value is only used when the padding_mode is constant
348
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
349
+ - constant: pads with a constant value, this value is specified with fill
350
+ - edge: pads with the last value on the edge of the image
351
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
352
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
353
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
354
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
355
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
356
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
357
+ """
358
+ def __init__(self,
359
+ size,
360
+ padding=None,
361
+ pad_if_needed=False,
362
+ fill=0,
363
+ padding_mode='constant'):
364
+ if isinstance(size, numbers.Number):
365
+ self.size = (int(size), int(size))
366
+ else:
367
+ self.size = size
368
+ self.padding = padding
369
+ self.pad_if_needed = pad_if_needed
370
+ self.fill = fill
371
+ self.padding_mode = padding_mode
372
+
373
+ @staticmethod
374
+ def get_params(img, output_size):
375
+ """Get parameters for ``crop`` for a random crop.
376
+ Args:
377
+ img (numpy ndarray): Image to be cropped.
378
+ output_size (tuple): Expected output size of the crop.
379
+ Returns:
380
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
381
+ """
382
+ h, w = img.shape[0:2]
383
+ th, tw = output_size
384
+ if w == tw and h == th:
385
+ return 0, 0, h, w
386
+
387
+ i = random.randint(0, h - th)
388
+ j = random.randint(0, w - tw)
389
+ return i, j, th, tw
390
+
391
+ def __call__(self, img):
392
+ """
393
+ Args:
394
+ img (numpy ndarray): Image to be cropped.
395
+ Returns:
396
+ numpy ndarray: Cropped image.
397
+ """
398
+ if self.padding is not None:
399
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
400
+
401
+ # pad the width if needed
402
+ if self.pad_if_needed and img.shape[1] < self.size[1]:
403
+ img = F.pad(img, (self.size[1] - img.shape[1], 0), self.fill,
404
+ self.padding_mode)
405
+ # pad the height if needed
406
+ if self.pad_if_needed and img.shape[0] < self.size[0]:
407
+ img = F.pad(img, (0, self.size[0] - img.shape[0]), self.fill,
408
+ self.padding_mode)
409
+
410
+ i, j, h, w = self.get_params(img, self.size)
411
+
412
+ return F.crop(img, i, j, h, w)
413
+
414
+ def __repr__(self):
415
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(
416
+ self.size, self.padding)
417
+
418
+
419
+ class RandomHorizontalFlip(object):
420
+ """Horizontally flip the given PIL Image randomly with a given probability.
421
+ Args:
422
+ p (float): probability of the image being flipped. Default value is 0.5
423
+ """
424
+ def __init__(self, p=0.5):
425
+ self.p = p
426
+
427
+ def __call__(self, img):
428
+ """random
429
+ Args:
430
+ img (numpy ndarray): Image to be flipped.
431
+ Returns:
432
+ numpy ndarray: Randomly flipped image.
433
+ """
434
+ # if random.random() < self.p:
435
+ # print('flip')
436
+ # return F.hflip(img)
437
+ if random.random() < self.p:
438
+ return F.hflip(img)
439
+ return img
440
+
441
+ def __repr__(self):
442
+ return self.__class__.__name__ + '(p={})'.format(self.p)
443
+
444
+
445
+ class RandomVerticalFlip(object):
446
+ """Vertically flip the given PIL Image randomly with a given probability.
447
+ Args:
448
+ p (float): probability of the image being flipped. Default value is 0.5
449
+ """
450
+ def __init__(self, p=0.5):
451
+ self.p = p
452
+
453
+ def __call__(self, img):
454
+ """
455
+ Args:
456
+ img (numpy ndarray): Image to be flipped.
457
+ Returns:
458
+ numpy ndarray: Randomly flipped image.
459
+ """
460
+ if random.random() < self.p:
461
+ return F.vflip(img)
462
+ return img
463
+
464
+ def __repr__(self):
465
+ return self.__class__.__name__ + '(p={})'.format(self.p)
466
+
467
+
468
+ class RandomResizedCrop(object):
469
+ """Crop the given numpy ndarray to random size and aspect ratio.
470
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
471
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
472
+ is finally resized to given size.
473
+ This is popularly used to train the Inception networks.
474
+ Args:
475
+ size: expected output size of each edge
476
+ scale: range of size of the origin size cropped
477
+ ratio: range of aspect ratio of the origin aspect ratio cropped
478
+ interpolation: Default: cv2.INTER_CUBIC
479
+ """
480
+ def __init__(self,
481
+ size,
482
+ scale=(0.08, 1.0),
483
+ ratio=(3. / 4., 4. / 3.),
484
+ interpolation=cv2.INTER_LINEAR):
485
+ self.size = (size, size)
486
+ self.interpolation = interpolation
487
+ self.scale = scale
488
+ self.ratio = ratio
489
+
490
+ @staticmethod
491
+ def get_params(img, scale, ratio):
492
+ """Get parameters for ``crop`` for a random sized crop.
493
+ Args:
494
+ img (numpy ndarray): Image to be cropped.
495
+ scale (tuple): range of size of the origin size cropped
496
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
497
+ Returns:
498
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
499
+ sized crop.
500
+ """
501
+ for attempt in range(10):
502
+ area = img.shape[0] * img.shape[1]
503
+ target_area = random.uniform(*scale) * area
504
+ aspect_ratio = random.uniform(*ratio)
505
+
506
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
507
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
508
+
509
+ if random.random() < 0.5:
510
+ w, h = h, w
511
+
512
+ if w <= img.shape[1] and h <= img.shape[0]:
513
+ i = random.randint(0, img.shape[0] - h)
514
+ j = random.randint(0, img.shape[1] - w)
515
+ return i, j, h, w
516
+
517
+ # Fallback
518
+ w = min(img.shape[0], img.shape[1])
519
+ i = (img.shape[0] - w) // 2
520
+ j = (img.shape[1] - w) // 2
521
+ return i, j, w, w
522
+
523
+ def __call__(self, img):
524
+ """
525
+ Args:
526
+ img (numpy ndarray): Image to be cropped and resized.
527
+ Returns:
528
+ numpy ndarray: Randomly cropped and resized image.
529
+ """
530
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
531
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
532
+
533
+ def __repr__(self):
534
+ interpolate_str = _cv2_interpolation_from_str[self.interpolation]
535
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
536
+ format_string += ', scale={0}'.format(
537
+ tuple(round(s, 4) for s in self.scale))
538
+ format_string += ', ratio={0}'.format(
539
+ tuple(round(r, 4) for r in self.ratio))
540
+ format_string += ', interpolation={0})'.format(interpolate_str)
541
+ return format_string
542
+
543
+
544
+ class RandomSizedCrop(RandomResizedCrop):
545
+ """
546
+ Note: This transform is deprecated in favor of RandomResizedCrop.
547
+ """
548
+ def __init__(self, *args, **kwargs):
549
+ warnings.warn(
550
+ "The use of the transforms.RandomSizedCrop transform is deprecated, "
551
+ + "please use transforms.RandomResizedCrop instead.")
552
+ super(RandomSizedCrop, self).__init__(*args, **kwargs)
553
+
554
+
555
+ class FiveCrop(object):
556
+ """Crop the given numpy ndarray into four corners and the central crop
557
+ .. Note::
558
+ This transform returns a tuple of images and there may be a mismatch in the number of
559
+ inputs and targets your Dataset returns. See below for an example of how to deal with
560
+ this.
561
+ Args:
562
+ size (sequence or int): Desired output size of the crop. If size is an ``int``
563
+ instead of sequence like (h, w), a square crop of size (size, size) is made.
564
+ Example:
565
+ >>> transform = Compose([
566
+ >>> FiveCrop(size), # this is a list of numpy ndarrays
567
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
568
+ >>> ])
569
+ >>> #In your test loop you can do the following:
570
+ >>> input, target = batch # input is a 5d tensor, target is 2d
571
+ >>> bs, ncrops, c, h, w = input.size()
572
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
573
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
574
+ """
575
+ def __init__(self, size):
576
+ self.size = size
577
+ if isinstance(size, numbers.Number):
578
+ self.size = (int(size), int(size))
579
+ else:
580
+ assert len(
581
+ size
582
+ ) == 2, "Please provide only two dimensions (h, w) for size."
583
+ self.size = size
584
+
585
+ def __call__(self, img):
586
+ return F.five_crop(img, self.size)
587
+
588
+ def __repr__(self):
589
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
590
+
591
+
592
+ class TenCrop(object):
593
+ """Crop the given numpy ndarray into four corners and the central crop plus the flipped version of
594
+ these (horizontal flipping is used by default)
595
+ .. Note::
596
+ This transform returns a tuple of images and there may be a mismatch in the number of
597
+ inputs and targets your Dataset returns. See below for an example of how to deal with
598
+ this.
599
+ Args:
600
+ size (sequence or int): Desired output size of the crop. If size is an
601
+ int instead of sequence like (h, w), a square crop (size, size) is
602
+ made.
603
+ vertical_flip(bool): Use vertical flipping instead of horizontal
604
+ Example:
605
+ >>> transform = Compose([
606
+ >>> TenCrop(size), # this is a list of PIL Images
607
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
608
+ >>> ])
609
+ >>> #In your test loop you can do the following:
610
+ >>> input, target = batch # input is a 5d tensor, target is 2d
611
+ >>> bs, ncrops, c, h, w = input.size()
612
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
613
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
614
+ """
615
+ def __init__(self, size, vertical_flip=False):
616
+ self.size = size
617
+ if isinstance(size, numbers.Number):
618
+ self.size = (int(size), int(size))
619
+ else:
620
+ assert len(
621
+ size
622
+ ) == 2, "Please provide only two dimensions (h, w) for size."
623
+ self.size = size
624
+ self.vertical_flip = vertical_flip
625
+
626
+ def __call__(self, img):
627
+ return F.ten_crop(img, self.size, self.vertical_flip)
628
+
629
+ def __repr__(self):
630
+ return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(
631
+ self.size, self.vertical_flip)
632
+
633
+
634
+ class LinearTransformation(object):
635
+ """Transform a tensor image with a square transformation matrix computed
636
+ offline.
637
+ Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
638
+ product with the transformation matrix and reshape the tensor to its
639
+ original shape.
640
+ Applications:
641
+ - whitening: zero-center the data, compute the data covariance matrix
642
+ [D x D] with np.dot(X.T, X), perform SVD on this matrix and
643
+ pass it as transformation_matrix.
644
+ Args:
645
+ transformation_matrix (Tensor): tensor [D x D], D = C x H x W
646
+ """
647
+ def __init__(self, transformation_matrix):
648
+ if transformation_matrix.size(0) != transformation_matrix.size(1):
649
+ raise ValueError("transformation_matrix should be square. Got " +
650
+ "[{} x {}] rectangular matrix.".format(
651
+ *transformation_matrix.size()))
652
+ self.transformation_matrix = transformation_matrix
653
+
654
+ def __call__(self, tensor):
655
+ """
656
+ Args:
657
+ tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
658
+ Returns:
659
+ Tensor: Transformed image.
660
+ """
661
+ if tensor.size(0) * tensor.size(1) * tensor.size(
662
+ 2) != self.transformation_matrix.size(0):
663
+ raise ValueError(
664
+ "tensor and transformation matrix have incompatible shape." +
665
+ "[{} x {} x {}] != ".format(*tensor.size()) +
666
+ "{}".format(self.transformation_matrix.size(0)))
667
+ flat_tensor = tensor.view(1, -1)
668
+ transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
669
+ tensor = transformed_tensor.view(tensor.size())
670
+ return tensor
671
+
672
+ def __repr__(self):
673
+ format_string = self.__class__.__name__ + '('
674
+ format_string += (str(self.transformation_matrix.numpy().tolist()) +
675
+ ')')
676
+ return format_string
677
+
678
+
679
+ class ColorJitter(object):
680
+ """Randomly change the brightness, contrast and saturation of an image.
681
+ Args:
682
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
683
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
684
+ or the given [min, max]. Should be non negative numbers.
685
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
686
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
687
+ or the given [min, max]. Should be non negative numbers.
688
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
689
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
690
+ or the given [min, max]. Should be non negative numbers.
691
+ hue (float or tuple of float (min, max)): How much to jitter hue.
692
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
693
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
694
+ """
695
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
696
+ self.brightness = self._check_input(brightness, 'brightness')
697
+ self.contrast = self._check_input(contrast, 'contrast')
698
+ self.saturation = self._check_input(saturation, 'saturation')
699
+ self.hue = self._check_input(hue,
700
+ 'hue',
701
+ center=0,
702
+ bound=(-0.5, 0.5),
703
+ clip_first_on_zero=False)
704
+ if self.saturation is not None:
705
+ warnings.warn(
706
+ 'Saturation jitter enabled. Will slow down loading immensely.')
707
+ if self.hue is not None:
708
+ warnings.warn(
709
+ 'Hue jitter enabled. Will slow down loading immensely.')
710
+
711
+ def _check_input(self,
712
+ value,
713
+ name,
714
+ center=1,
715
+ bound=(0, float('inf')),
716
+ clip_first_on_zero=True):
717
+ if isinstance(value, numbers.Number):
718
+ if value < 0:
719
+ raise ValueError(
720
+ "If {} is a single number, it must be non negative.".
721
+ format(name))
722
+ value = [center - value, center + value]
723
+ if clip_first_on_zero:
724
+ value[0] = max(value[0], 0)
725
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
726
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
727
+ raise ValueError("{} values should be between {}".format(
728
+ name, bound))
729
+ else:
730
+ raise TypeError(
731
+ "{} should be a single number or a list/tuple with length 2.".
732
+ format(name))
733
+
734
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
735
+ # or (0., 0.) for hue, do nothing
736
+ if value[0] == value[1] == center:
737
+ value = None
738
+ return value
739
+
740
+ @staticmethod
741
+ def get_params(brightness, contrast, saturation, hue):
742
+ """Get a randomized transform to be applied on image.
743
+ Arguments are same as that of __init__.
744
+ Returns:
745
+ Transform which randomly adjusts brightness, contrast and
746
+ saturation in a random order.
747
+ """
748
+ transforms = []
749
+
750
+ if brightness is not None:
751
+ brightness_factor = random.uniform(brightness[0], brightness[1])
752
+ transforms.append(
753
+ Lambda(
754
+ lambda img: F.adjust_brightness(img, brightness_factor)))
755
+
756
+ if contrast is not None:
757
+ contrast_factor = random.uniform(contrast[0], contrast[1])
758
+ transforms.append(
759
+ Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
760
+
761
+ if saturation is not None:
762
+ saturation_factor = random.uniform(saturation[0], saturation[1])
763
+ transforms.append(
764
+ Lambda(
765
+ lambda img: F.adjust_saturation(img, saturation_factor)))
766
+
767
+ if hue is not None:
768
+ hue_factor = random.uniform(hue[0], hue[1])
769
+ transforms.append(
770
+ Lambda(lambda img: F.adjust_hue(img, hue_factor)))
771
+
772
+ random.shuffle(transforms)
773
+ transform = Compose(transforms)
774
+
775
+ return transform
776
+
777
+ def __call__(self, img):
778
+ """
779
+ Args:
780
+ img (numpy ndarray): Input image.
781
+ Returns:
782
+ numpy ndarray: Color jittered image.
783
+ """
784
+ transform = self.get_params(self.brightness, self.contrast,
785
+ self.saturation, self.hue)
786
+ return transform(img)
787
+
788
+ def __repr__(self):
789
+ format_string = self.__class__.__name__ + '('
790
+ format_string += 'brightness={0}'.format(self.brightness)
791
+ format_string += ', contrast={0}'.format(self.contrast)
792
+ format_string += ', saturation={0}'.format(self.saturation)
793
+ format_string += ', hue={0})'.format(self.hue)
794
+ return format_string
795
+
796
+
797
+ class RandomRotation(object):
798
+ """Rotate the image by angle.
799
+ Args:
800
+ degrees (sequence or float or int): Range of degrees to select from.
801
+ If degrees is a number instead of sequence like (min, max), the range of degrees
802
+ will be (-degrees, +degrees).
803
+ resample ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional):
804
+ An optional resampling filter. See `filters`_ for more information.
805
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
806
+ expand (bool, optional): Optional expansion flag.
807
+ If true, expands the output to make it large enough to hold the entire rotated image.
808
+ If false or omitted, make the output image the same size as the input image.
809
+ Note that the expand flag assumes rotation around the center and no translation.
810
+ center (2-tuple, optional): Optional center of rotation.
811
+ Origin is the upper left corner.
812
+ Default is the center of the image.
813
+ """
814
+ def __init__(self, degrees, resample=False, expand=False, center=None):
815
+ if isinstance(degrees, numbers.Number):
816
+ if degrees < 0:
817
+ raise ValueError(
818
+ "If degrees is a single number, it must be positive.")
819
+ self.degrees = (-degrees, degrees)
820
+ else:
821
+ if len(degrees) != 2:
822
+ raise ValueError(
823
+ "If degrees is a sequence, it must be of len 2.")
824
+ self.degrees = degrees
825
+
826
+ self.resample = resample
827
+ self.expand = expand
828
+ self.center = center
829
+
830
+ @staticmethod
831
+ def get_params(degrees):
832
+ """Get parameters for ``rotate`` for a random rotation.
833
+ Returns:
834
+ sequence: params to be passed to ``rotate`` for random rotation.
835
+ """
836
+ angle = random.uniform(degrees[0], degrees[1])
837
+
838
+ return angle
839
+
840
+ def __call__(self, img):
841
+ """
842
+ img (numpy ndarray): Image to be rotated.
843
+ Returns:
844
+ numpy ndarray: Rotated image.
845
+ """
846
+
847
+ angle = self.get_params(self.degrees)
848
+
849
+ return F.rotate(img, angle, self.resample, self.expand, self.center)
850
+
851
+ def __repr__(self):
852
+ format_string = self.__class__.__name__ + '(degrees={0}'.format(
853
+ self.degrees)
854
+ format_string += ', resample={0}'.format(self.resample)
855
+ format_string += ', expand={0}'.format(self.expand)
856
+ if self.center is not None:
857
+ format_string += ', center={0}'.format(self.center)
858
+ format_string += ')'
859
+ return format_string
860
+
861
+
862
+ class RandomAffine(object):
863
+ """Random affine transformation of the image keeping center invariant
864
+ Args:
865
+ degrees (sequence or float or int): Range of degrees to select from.
866
+ If degrees is a number instead of sequence like (min, max), the range of degrees
867
+ will be (-degrees, +degrees). Set to 0 to deactivate rotations.
868
+ translate (tuple, optional): tuple of maximum absolute fraction for horizontal
869
+ and vertical translations. For example translate=(a, b), then horizontal shift
870
+ is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
871
+ randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
872
+ scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
873
+ randomly sampled from the range a <= scale <= b. Will keep original scale by default.
874
+ shear (sequence or float or int, optional): Range of degrees to select from.
875
+ If degrees is a number instead of sequence like (min, max), the range of degrees
876
+ will be (-degrees, +degrees). Will not apply shear by default
877
+ resample ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional):
878
+ An optional resampling filter. See `filters`_ for more information.
879
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
880
+ fillcolor (int): Optional fill color for the area outside the transform in the output image.
881
+ """
882
+ def __init__(self,
883
+ degrees,
884
+ translate=None,
885
+ scale=None,
886
+ shear=None,
887
+ interpolation=cv2.INTER_LINEAR,
888
+ fillcolor=0):
889
+ if isinstance(degrees, numbers.Number):
890
+ if degrees < 0:
891
+ raise ValueError(
892
+ "If degrees is a single number, it must be positive.")
893
+ self.degrees = (-degrees, degrees)
894
+ else:
895
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
896
+ "degrees should be a list or tuple and it must be of length 2."
897
+ self.degrees = degrees
898
+
899
+ if translate is not None:
900
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
901
+ "translate should be a list or tuple and it must be of length 2."
902
+ for t in translate:
903
+ if not (0.0 <= t <= 1.0):
904
+ raise ValueError(
905
+ "translation values should be between 0 and 1")
906
+ self.translate = translate
907
+
908
+ if scale is not None:
909
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
910
+ "scale should be a list or tuple and it must be of length 2."
911
+ for s in scale:
912
+ if s <= 0:
913
+ raise ValueError("scale values should be positive")
914
+ self.scale = scale
915
+
916
+ if shear is not None:
917
+ if isinstance(shear, numbers.Number):
918
+ if shear < 0:
919
+ raise ValueError(
920
+ "If shear is a single number, it must be positive.")
921
+ self.shear = (-shear, shear)
922
+ else:
923
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
924
+ "shear should be a list or tuple and it must be of length 2."
925
+ self.shear = shear
926
+ else:
927
+ self.shear = shear
928
+
929
+ # self.resample = resample
930
+ self.interpolation = interpolation
931
+ self.fillcolor = fillcolor
932
+
933
+ @staticmethod
934
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
935
+ """Get parameters for affine transformation
936
+ Returns:
937
+ sequence: params to be passed to the affine transformation
938
+ """
939
+ angle = random.uniform(degrees[0], degrees[1])
940
+ if translate is not None:
941
+ max_dx = translate[0] * img_size[0]
942
+ max_dy = translate[1] * img_size[1]
943
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
944
+ np.round(random.uniform(-max_dy, max_dy)))
945
+ else:
946
+ translations = (0, 0)
947
+
948
+ if scale_ranges is not None:
949
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
950
+ else:
951
+ scale = 1.0
952
+
953
+ if shears is not None:
954
+ shear = random.uniform(shears[0], shears[1])
955
+ else:
956
+ shear = 0.0
957
+
958
+ return angle, translations, scale, shear
959
+
960
+ def __call__(self, img):
961
+ """
962
+ img (numpy ndarray): Image to be transformed.
963
+ Returns:
964
+ numpy ndarray: Affine transformed image.
965
+ """
966
+ ret = self.get_params(self.degrees, self.translate, self.scale,
967
+ self.shear, (img.shape[1], img.shape[0]))
968
+ return F.affine(img,
969
+ *ret,
970
+ interpolation=self.interpolation,
971
+ fillcolor=self.fillcolor)
972
+
973
+ def __repr__(self):
974
+ s = '{name}(degrees={degrees}'
975
+ if self.translate is not None:
976
+ s += ', translate={translate}'
977
+ if self.scale is not None:
978
+ s += ', scale={scale}'
979
+ if self.shear is not None:
980
+ s += ', shear={shear}'
981
+ if self.resample > 0:
982
+ s += ', resample={resample}'
983
+ if self.fillcolor != 0:
984
+ s += ', fillcolor={fillcolor}'
985
+ s += ')'
986
+ d = dict(self.__dict__)
987
+ d['resample'] = _cv2_interpolation_to_str[d['resample']]
988
+ return s.format(name=self.__class__.__name__, **d)
989
+
990
+
991
+ class Grayscale(object):
992
+ """Convert image to grayscale.
993
+ Args:
994
+ num_output_channels (int): (1 or 3) number of channels desired for output image
995
+ Returns:
996
+ numpy ndarray: Grayscale version of the input.
997
+ - If num_output_channels == 1 : returned image is single channel
998
+ - If num_output_channels == 3 : returned image is 3 channel with r == g == b
999
+ """
1000
+ def __init__(self, num_output_channels=1):
1001
+ self.num_output_channels = num_output_channels
1002
+
1003
+ def __call__(self, img):
1004
+ """
1005
+ Args:
1006
+ img (numpy ndarray): Image to be converted to grayscale.
1007
+ Returns:
1008
+ numpy ndarray: Randomly grayscaled image.
1009
+ """
1010
+ return F.to_grayscale(img,
1011
+ num_output_channels=self.num_output_channels)
1012
+
1013
+ def __repr__(self):
1014
+ return self.__class__.__name__ + '(num_output_channels={0})'.format(
1015
+ self.num_output_channels)
1016
+
1017
+
1018
+ class RandomGrayscale(object):
1019
+ """Randomly convert image to grayscale with a probability of p (default 0.1).
1020
+ Args:
1021
+ p (float): probability that image should be converted to grayscale.
1022
+ Returns:
1023
+ numpy ndarray: Grayscale version of the input image with probability p and unchanged
1024
+ with probability (1-p).
1025
+ - If input image is 1 channel: grayscale version is 1 channel
1026
+ - If input image is 3 channel: grayscale version is 3 channel with r == g == b
1027
+ """
1028
+ def __init__(self, p=0.1):
1029
+ self.p = p
1030
+
1031
+ def __call__(self, img):
1032
+ """
1033
+ Args:
1034
+ img (numpy ndarray): Image to be converted to grayscale.
1035
+ Returns:
1036
+ numpy ndarray: Randomly grayscaled image.
1037
+ """
1038
+ num_output_channels = 3
1039
+ if random.random() < self.p:
1040
+ return F.to_grayscale(img, num_output_channels=num_output_channels)
1041
+ return img
1042
+
1043
+ def __repr__(self):
1044
+ return self.__class__.__name__ + '(p={0})'.format(self.p)
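Taken together, these classes are intended as a drop-in for a torchvision transform pipeline, only operating on numpy ndarrays. A hedged end-to-end sketch follows (not part of the commit): `ToTensor` and `Normalize` rely on `to_tensor`/`normalize` from functional.py, which sit outside this excerpt, and the ImageNet mean/std values are only illustrative.

import numpy as np

from opencv_transforms import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),        # random crop, resized to 224x224
    transforms.RandomHorizontalFlip(p=0.5),   # flip half of the samples
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),                    # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = np.random.randint(0, 256, size=(256, 320, 3), dtype=np.uint8)
tensor = train_transform(img)
print(tensor.shape)  # expected: torch.Size([3, 224, 224])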
dataset/setup.py ADDED
@@ -0,0 +1,23 @@
1
+ import setuptools
2
+
3
+ with open('README.md', 'r') as fh:
4
+ long_description = fh.read()
5
+
6
+ setuptools.setup(
7
+ name='opencv_transforms',
8
+ version='0.0.6',
9
+ author='Jim Bohnslav',
10
+ author_email='[email protected]',
11
+ description='A drop-in replacement for Torchvision Transforms using OpenCV',
12
+ keywords='pytorch image augmentations',
13
+ long_description=long_description,
14
+ long_description_content_type='text/markdown',
15
+ url='https://github.com/jbohnslav/opencv_transforms',
16
+ packages=setuptools.find_packages(),
17
+ classifiers=[
18
+ "Programming Language :: Python :: 3",
19
+ "License :: OSI Approved :: MIT License",
20
+ "Operating System :: OS Independent",
21
+ ],
22
+ python_requires='>=3.6',
23
+ )
dataset/tests/compare_to_pil_for_testing.ipynb ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import glob\n",
10
+ "import numpy as np\n",
11
+ "import random\n",
12
+ "\n",
13
+ "import cv2\n",
14
+ "import matplotlib.pyplot as plt\n",
15
+ "from PIL import Image\n",
16
+ "\n",
17
+ "from torchvision import transforms as pil_transforms\n",
18
+ "from torchvision.transforms import functional as F_pil\n",
19
+ "\n",
20
+ "import sys\n",
21
+ "sys.path.insert(0, '..')\n",
22
+ "from opencv_transforms import transforms\n",
23
+ "from opencv_transforms import functional as F\n",
24
+ "\n",
25
+ "from setup_testing_directory import get_testing_directory"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "datadir = get_testing_directory()\n",
35
+ "print(datadir)"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)\n",
45
+ "train_images.sort()\n",
46
+ "print('Number of training images: {:,}'.format(len(train_images)))"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "random.seed(1)\n",
56
+ "imfile = random.choice(train_images)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "def plot_pil_and_opencv(pil_image, opencv_image, orientation='row'):\n",
66
+ " if orientation == 'row':\n",
67
+ " rows, cols = 1,3\n",
68
+ " size = (8, 4)\n",
69
+ " else: \n",
70
+ " rows, cols = 3,1\n",
71
+ " size = (12, 6)\n",
72
+ " fig, axes = plt.subplots(rows, cols,figsize=size)\n",
73
+ " ax = axes[0]\n",
74
+ " ax.imshow(pil_image)\n",
75
+ " ax.set_title('PIL')\n",
76
+ "\n",
77
+ " ax = axes[1]\n",
78
+ " ax.imshow(opencv_image)\n",
79
+ " ax.set_title('opencv')\n",
80
+ "\n",
81
+ " ax = axes[2]\n",
82
+ " l1 = np.abs(pil_image - opencv_image).mean(axis=2)\n",
83
+ " ax.imshow(l1)\n",
84
+ " ax.set_title('| PIL - opencv|\\nMAE:{:.4f}'.format(l1.mean()))\n",
85
+ " plt.tight_layout()\n",
86
+ " plt.show()"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "pil_image = Image.open(imfile)\n",
96
+ "image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "plot_pil_and_opencv(pil_image, image)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "pil_resized = pil_transforms.Resize((224, 224))(pil_image)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "resized = transforms.Resize(224)(image)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "plot_pil_and_opencv(pil_resized, resized)\n",
133
+ "plt.show()"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "def L1(pil: Image, image: np.ndarray) -> float:\n",
143
+ " return np.mean(np.abs(np.asarray(pil) - image))"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "TOL = 1e-4\n",
153
+ "\n",
154
+ "l1 = L1(pil_resized, resized)\n",
155
+ "assert l1 - 88.9559 < TOL"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "random.seed(1)\n",
165
+ "pil = pil_transforms.RandomRotation(10)(pil_image)\n",
166
+ "random.seed(1)\n",
167
+ "np_img = transforms.RandomRotation(10)(image)"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "plot_pil_and_opencv(pil, np_img)"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "pil = pil_transforms.FiveCrop((224, 224))(pil_image)\n",
186
+ "cv = transforms.FiveCrop((224,224))(image)"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "pil_stacked = np.hstack([np.asarray(i) for i in pil])\n",
196
+ "cv_stacked = np.hstack(cv)\n",
197
+ "\n",
198
+ "plot_pil_and_opencv(pil_stacked, cv_stacked, orientation='col')"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "pil_stacked.shape"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "l1"
217
+ ]
218
+ }
219
+ ],
220
+ "metadata": {
221
+ "kernelspec": {
222
+ "display_name": "opencv_transforms",
223
+ "language": "python",
224
+ "name": "opencv_transforms"
225
+ },
226
+ "language_info": {
227
+ "codemirror_mode": {
228
+ "name": "ipython",
229
+ "version": 3
230
+ },
231
+ "file_extension": ".py",
232
+ "mimetype": "text/x-python",
233
+ "name": "python",
234
+ "nbconvert_exporter": "python",
235
+ "pygments_lexer": "ipython3",
236
+ "version": "3.7.9"
237
+ }
238
+ },
239
+ "nbformat": 4,
240
+ "nbformat_minor": 4
241
+ }
dataset/tests/setup_testing_directory.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import Union
4
+ import warnings
5
+
6
+
7
+ def get_testing_directory() -> str:
8
+ directory_file = 'testing_directory.txt'
9
+ directory_files = [directory_file, os.path.join('tests', directory_file)]
10
+
11
+ for directory_file in directory_files:
12
+ if os.path.isfile(directory_file):
13
+ with open(directory_file, 'r') as f:
14
+ testing_directory = f.read()
15
+ return testing_directory
16
+ raise ValueError('please run setup_testing_directory.py before attempting to run unit tests')
17
+
18
+
19
+ def setup_testing_directory(datadir: Union[str, os.PathLike], overwrite: bool = False) -> str:
20
+ testing_path_file = 'testing_directory.txt'
21
+
22
+ should_setup = True
23
+ if os.path.isfile(testing_path_file):
24
+ with open(testing_path_file, 'r') as f:
25
+ testing_directory = f.read()
26
+ if not os.path.isdir(testing_directory):
27
+ # saved path is stale: warn, keep should_setup True and redo the setup below
28
+ warnings.warn(
29
+ 'Saved testing directory {} does not exist, re-running setup'.format(testing_directory))
30
+ else:
31
+ should_setup = False
32
+ if not should_setup:
33
+ return testing_directory
34
+
35
+ testing_directory = datadir
36
+ assert os.path.isdir(testing_directory)
37
+ assert os.path.isdir(os.path.join(testing_directory, 'train'))
38
+ assert os.path.isdir(os.path.join(testing_directory, 'val'))
39
+ with open('testing_directory.txt', 'w') as f:
40
+ f.write(testing_directory)
41
+ return testing_directory
42
+
43
+
44
+ if __name__ == '__main__':
45
+ parser = argparse.ArgumentParser('Setting up image directory for opencv transforms testing')
46
+ parser.add_argument('-d', '--datadir', default=os.getcwd(), help='Imagenet directory')
47
+
48
+ args = parser.parse_args()
49
+
50
+ setup_testing_directory(args.datadir)
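A short sketch of how the two helpers above are meant to be used together; the directory path is a placeholder and must contain `train/` and `val/` sub-folders of JPEG images:

```python
# run from dataset/tests so testing_directory.txt lands where the tests look for it
from setup_testing_directory import setup_testing_directory, get_testing_directory

setup_testing_directory('/data/imagenet_subset')   # writes testing_directory.txt
print(get_testing_directory())                     # -> /data/imagenet_subset
```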
dataset/tests/test_color.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import numpy as np
3
+ import random
4
+ from typing import Union
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ from PIL.Image import Image as PIL_image # for typing
10
+ import pytest
11
+ from torchvision import transforms as pil_transforms
12
+ from torchvision.transforms import functional as F_pil
13
+
14
+ from opencv_transforms import transforms
15
+ from opencv_transforms import functional as F
16
+ from setup_testing_directory import get_testing_directory
17
+
18
+ TOL = 1e-4
19
+
20
+ datadir = get_testing_directory()
21
+ train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
22
+ train_images.sort()
23
+ print('Number of training images: {:,}'.format(len(train_images)))
24
+
25
+ random.seed(1)
26
+ imfile = random.choice(train_images)
27
+ pil_image = Image.open(imfile)
28
+ image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
29
+
30
+
31
+ class TestContrast:
32
+ @pytest.mark.parametrize('random_seed', [1, 2, 3, 4])
33
+ @pytest.mark.parametrize('contrast_factor', [0.0, 0.5, 1.0, 2.0])
34
+ def test_contrast(self, contrast_factor, random_seed):
35
+ random.seed(random_seed)
36
+ imfile = random.choice(train_images)
37
+ pil_image = Image.open(imfile)
38
+ image = np.array(pil_image).copy()
39
+
40
+ pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
41
+ np_enhanced = F.adjust_contrast(image, contrast_factor)
42
+ assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
43
+
44
+ @pytest.mark.parametrize('n_images', [1, 11])
45
+ def test_multichannel_contrast(self, n_images, contrast_factor=0.1):
46
+ imfile = random.choice(train_images)
47
+
48
+ pil_image = Image.open(imfile)
49
+ image = np.array(pil_image).copy()
50
+
51
+ multichannel_image = np.concatenate([image for _ in range(n_images)], axis=-1)
52
+ # this will raise an exception in version 0.0.5
53
+ np_enhanced = F.adjust_contrast(multichannel_image, contrast_factor)
54
+
55
+ @pytest.mark.parametrize('contrast_factor', [0, 0.5, 1.0])
56
+ def test_grayscale_contrast(self, contrast_factor):
57
+ imfile = random.choice(train_images)
58
+
59
+ pil_image = Image.open(imfile)
60
+ image = np.array(pil_image).copy()
61
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
62
+
63
+ # make sure grayscale images work
64
+ pil_image = pil_image.convert('L')
65
+
66
+ pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
67
+ np_enhanced = F.adjust_contrast(image, contrast_factor)
68
+ assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
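These tests assume `testing_directory.txt` has already been written by the setup script above. A hedged sketch of invoking just this file programmatically, equivalent to running `pytest test_color.py -q` from `dataset/tests`:

```python
import pytest

# run only the color tests (working directory: dataset/tests)
pytest.main(['test_color.py', '-q'])
```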
dataset/tests/test_spatial.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import numpy as np
3
+ import random
4
+ from typing import Union
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ from PIL.Image import Image as PIL_image # for typing
10
+
11
+ from torchvision import transforms as pil_transforms
12
+ from torchvision.transforms import functional as F_pil
13
+ from opencv_transforms import transforms
14
+ from opencv_transforms import functional as F
15
+
16
+ from setup_testing_directory import get_testing_directory
17
+ from utils import L1
18
+
19
+ TOL = 1e-4
20
+
21
+ datadir = get_testing_directory()
22
+ train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
23
+ train_images.sort()
24
+ print('Number of training images: {:,}'.format(len(train_images)))
25
+
26
+ random.seed(1)
27
+ imfile = random.choice(train_images)
28
+ pil_image = Image.open(imfile)
29
+ image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
30
+
31
+
32
+ def test_resize():
33
+ pil_resized = pil_transforms.Resize((224, 224))(pil_image)
34
+ resized = transforms.Resize((224, 224))(image)
35
+ l1 = L1(pil_resized, resized)
36
+ assert l1 - 88.9559 < TOL
37
+
38
+ def test_rotation():
39
+ random.seed(1)
40
+ pil = pil_transforms.RandomRotation(10)(pil_image)
41
+ random.seed(1)
42
+ np_img = transforms.RandomRotation(10)(image)
43
+ l1 = L1(pil, np_img)
44
+ assert l1 - 86.7955 < TOL
45
+
46
+ def test_five_crop():
47
+ pil = pil_transforms.FiveCrop((224, 224))(pil_image)
48
+ cv = transforms.FiveCrop((224, 224))(image)
49
+ pil_stacked = np.hstack([np.asarray(i) for i in pil])
50
+ cv_stacked = np.hstack(cv)
51
+ l1 = L1(pil_stacked, cv_stacked)
52
+ assert l1 - 22.0444 < TOL
dataset/tests/utils.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ from PIL.Image import Image as PIL_image # for typing
5
+
6
+
7
+ def L1(pil: Union[PIL_image, np.ndarray], np_image: np.ndarray) -> float:
8
+ return np.abs(np.asarray(pil) - np_image).mean()
inference.yaml ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: vtdm.vtdm_gen_v01.VideoLDM
3
+ base_learning_rate: 1.0e-05
4
+ params:
5
+ input_key: video
6
+ scale_factor: 0.18215
7
+ log_keys: caption
8
+ num_samples: 25 #frame_rate
9
+ trained_param_keys:
10
+ - diffusion_model.label_emb.0.0.weight
11
+ - .emb_layers.
12
+ - .time_stack.
13
+ en_and_decode_n_samples_a_time: 25 #frame_rate
14
+ disable_first_stage_autocast: true
15
+ denoiser_config:
16
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
17
+ params:
18
+ scaling_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
20
+ network_config:
21
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
22
+ params:
23
+ adm_in_channels: 768
24
+ num_classes: sequential
25
+ use_checkpoint: true
26
+ in_channels: 8
27
+ out_channels: 4
28
+ model_channels: 320
29
+ attention_resolutions:
30
+ - 4
31
+ - 2
32
+ - 1
33
+ num_res_blocks: 2
34
+ channel_mult:
35
+ - 1
36
+ - 2
37
+ - 4
38
+ - 4
39
+ num_head_channels: 64
40
+ use_linear_in_transformer: true
41
+ transformer_depth: 1
42
+ context_dim: 1024
43
+ spatial_transformer_attn_type: softmax-xformers
44
+ extra_ff_mix_layer: true
45
+ use_spatial_context: true
46
+ merge_strategy: learned_with_images
47
+ video_kernel_size:
48
+ - 3
49
+ - 1
50
+ - 1
51
+ conditioner_config:
52
+ target: sgm.modules.GeneralConditioner
53
+ params:
54
+ emb_models:
55
+ - is_trainable: false
56
+ input_key: cond_frames_without_noise
57
+ ucg_rate: 0.1
58
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
59
+ params:
60
+ n_cond_frames: 1
61
+ n_copies: 1
62
+ open_clip_embedding_config:
63
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
64
+ params:
65
+ version: ckpts/open_clip_pytorch_model.bin
66
+ freeze: true
67
+ - is_trainable: false
68
+ input_key: video
69
+ ucg_rate: 0.0
70
+ target: vtdm.encoders.AesEmbedder
71
+ - is_trainable: false
72
+ input_key: elevation
73
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
74
+ params:
75
+ outdim: 256
76
+ - input_key: cond_frames
77
+ is_trainable: false
78
+ ucg_rate: 0.1
79
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
80
+ params:
81
+ disable_encoder_autocast: true
82
+ n_cond_frames: 1
83
+ n_copies: 25 #frame_rate
84
+ is_ae: true
85
+ encoder_config:
86
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
87
+ params:
88
+ embed_dim: 4
89
+ monitor: val/rec_loss
90
+ ddconfig:
91
+ attn_type: vanilla-xformers
92
+ double_z: true
93
+ z_channels: 4
94
+ resolution: 256
95
+ in_channels: 3
96
+ out_ch: 3
97
+ ch: 128
98
+ ch_mult:
99
+ - 1
100
+ - 2
101
+ - 4
102
+ - 4
103
+ num_res_blocks: 2
104
+ attn_resolutions: []
105
+ dropout: 0.0
106
+ lossconfig:
107
+ target: torch.nn.Identity
108
+ - input_key: cond_aug
109
+ is_trainable: false
110
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
111
+ params:
112
+ outdim: 256
113
+ first_stage_config:
114
+ target: sgm.models.autoencoder.AutoencoderKL
115
+ params:
116
+ embed_dim: 4
117
+ monitor: val/rec_loss
118
+ ddconfig:
119
+ attn_type: vanilla-xformers
120
+ double_z: true
121
+ z_channels: 4
122
+ resolution: 256
123
+ in_channels: 3
124
+ out_ch: 3
125
+ ch: 128
126
+ ch_mult:
127
+ - 1
128
+ - 2
129
+ - 4
130
+ - 4
131
+ num_res_blocks: 2
132
+ attn_resolutions: []
133
+ dropout: 0.0
134
+ lossconfig:
135
+ target: torch.nn.Identity
136
+ loss_fn_config:
137
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
138
+ params:
139
+ num_frames: 25 #frame_rate
140
+ batch2model_keys:
141
+ - num_video_frames
142
+ - image_only_indicator
143
+ sigma_sampler_config:
144
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
145
+ params:
146
+ p_mean: 1.0
147
+ p_std: 1.6
148
+ loss_weighting_config:
149
+ target: sgm.modules.diffusionmodules.loss_weighting.VWeighting
150
+ sampler_config:
151
+ target: sgm.modules.diffusionmodules.sampling.LinearMultistepSampler
152
+ params:
153
+ num_steps: 50
154
+ verbose: True
155
+
156
+ discretization_config:
157
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
158
+ params:
159
+ sigma_max: 700.0
160
+
161
+ guider_config:
162
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
163
+ params:
164
+ num_frames: 25 #frame_rate
165
+ max_scale: 2.5
166
+ min_scale: 1.0
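This config is consumed by `pipeline.py` further below. A hedged sketch of how it becomes a model, mirroring `MultiViewGenerator.__init__`; `create_model` comes from this repo's `vtdm` package and the checkpoint path is a placeholder:

```python
from vtdm.model import create_model

# instantiate the VideoLDM described by inference.yaml, then load weights onto the GPU
model = create_model('inference.yaml').cpu()
model.init_from_ckpt('ckpts/model.ckpt')   # placeholder checkpoint path
model = model.cuda().half()
```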
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libopencv-dev
2
+ build-essential
pipeline.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from PIL import Image
3
+ import numpy as np
4
+ import math
5
+ import random
6
+ import cv2
7
+ from typing import List
8
+
9
+ import torch
10
+ import einops
11
+ from pytorch_lightning import seed_everything
12
+ from transparent_background import Remover
13
+
14
+ from dataset.opencv_transforms.functional import to_tensor, center_crop
15
+ from vtdm.model import create_model
16
+ from vtdm.util import tensor2vid
17
+
18
+ remover = Remover(jit=False)
19
+
20
+ def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
21
+ return Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
22
+
23
+
24
+ def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
25
+ cv_image = np.array(pil_image)
26
+ cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
27
+ return cv_image
28
+
29
+ def prepare_white_image(input_image: Image.Image) -> Image.Image:
30
+ # remove bg
31
+ output = remover.process(input_image, type='rgba')
32
+
33
+ # expand image
34
+ width, height = output.size
35
+ max_side = max(width, height)
36
+ white_image = Image.new('RGBA', (max_side, max_side), (0, 0, 0, 0))
37
+ x_offset = (max_side - width) // 2
38
+ y_offset = (max_side - height) // 2
39
+ white_image.paste(output, (x_offset, y_offset))
40
+
41
+ return white_image
42
+
43
+
44
+ class MultiViewGenerator:
45
+ def __init__(self, checkpoint_path, config_path="inference.yaml"):
46
+ self.models = {}
47
+ denoising_model = create_model(config_path).cpu()
48
+ denoising_model.init_from_ckpt(checkpoint_path)
49
+ denoising_model = denoising_model.cuda().half()
50
+ self.models["denoising_model"] = denoising_model
51
+
52
+ def denoising(self, frames, args):
53
+ with torch.no_grad():
54
+ C, T, H, W = frames.shape
55
+ batch = {"video": frames.unsqueeze(0)}
56
+ batch["elevation"] = (
57
+ torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
58
+ )
59
+ batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
60
+ batch["motion_bucket_id"] = (
61
+ torch.Tensor([127]).to(torch.int64).to(frames.device)
62
+ )
63
+ batch = self.models["denoising_model"].add_custom_cond(batch, infer=True)
64
+
65
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
66
+ c, uc = self.models[
67
+ "denoising_model"
68
+ ].conditioner.get_unconditional_conditioning(
69
+ batch,
70
+ force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
71
+ )
72
+
73
+ additional_model_inputs = {
74
+ "image_only_indicator": torch.zeros(2, T).to(
75
+ self.models["denoising_model"].device
76
+ ),
77
+ "num_video_frames": batch["num_video_frames"],
78
+ }
79
+
80
+ def denoiser(input, sigma, c):
81
+ return self.models["denoising_model"].denoiser(
82
+ self.models["denoising_model"].model,
83
+ input,
84
+ sigma,
85
+ c,
86
+ **additional_model_inputs
87
+ )
88
+
89
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
90
+ randn = torch.randn(
91
+ [T, 4, H // 8, W // 8], device=self.models["denoising_model"].device
92
+ )
93
+ samples = self.models["denoising_model"].sampler(denoiser, randn, cond=c, uc=uc)
94
+
95
+ samples = self.models["denoising_model"].decode_first_stage(samples.half())
96
+ samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)
97
+
98
+ return tensor2vid(samples)
99
+
100
+ def video_pipeline(self, frames, args) -> List[Image.Image]:
101
+ num_iter = args["num_iter"]
102
+ out_list = []
103
+
104
+ for _ in range(num_iter):
105
+ with torch.no_grad():
106
+ results = self.denoising(frames, args)
107
+
108
+ if len(out_list) == 0:
109
+ out_list = out_list + results
110
+ else:
111
+ out_list = out_list + results[1:]
112
+
113
+ img = out_list[-1]
114
+ img = to_tensor(img)
115
+ img = (img - 0.5) * 2.0
116
+ frames[:, 0] = img
117
+
118
+ result = []
119
+
120
+ for i, frame in enumerate(out_list):
121
+ input_image = cv2_to_pil(frame)
122
+ output_image = remover.process(input_image, type='rgba')
123
+ result.append(output_image)
124
+
125
+ return result
126
+
127
+ def process(self, white_image: Image.Image, args) -> List[Image.Image]:
128
+ img = pil_to_cv2(white_image)
129
+ frame_list = [img] * args["clip_size"]
130
+
131
+ h, w = frame_list[0].shape[0:2]
132
+ rate = max(
133
+ args["input_resolution"][0] * 1.0 / h, args["input_resolution"][1] * 1.0 / w
134
+ )
135
+ frame_list = [
136
+ cv2.resize(f, [math.ceil(w * rate), math.ceil(h * rate)]) for f in frame_list
137
+ ]
138
+ frame_list = [
139
+ center_crop(f, [args["input_resolution"][0], args["input_resolution"][1]])
140
+ for f in frame_list
141
+ ]
142
+ frame_list = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frame_list]
143
+
144
+ frame_list = [to_tensor(f) for f in frame_list]
145
+ frame_list = [(f - 0.5) * 2.0 for f in frame_list]
146
+ frames = torch.stack(frame_list, 1)
147
+ frames = frames.cuda()
148
+
149
+ self.models["denoising_model"].num_samples = args["clip_size"]
150
+ self.models["denoising_model"].image_size = args["input_resolution"]
151
+
152
+ return self.video_pipeline(frames, args)
153
+
154
+ def infer(self, white_image: Image.Image) -> List[Image.Image]:
155
+ seed = random.randint(0, 65535)
156
+ seed_everything(seed)
157
+
158
+ params = {
159
+ "clip_size": 25,
160
+ "input_resolution": [512, 512],
161
+ "num_iter": 1,
162
+ "aes": 6.0,
163
+ "mv": [0.0, 0.0, 0.0, 10.0],
164
+ "elevation": 0,
165
+ }
166
+
167
+ return self.process(white_image, params)
168
+
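A hedged end-to-end sketch of the classes above; the checkpoint and image paths are placeholders, a CUDA GPU is required, and `infer()` returns the generated views as RGBA PIL images:

```python
from PIL import Image
from pipeline import MultiViewGenerator, prepare_white_image

generator = MultiViewGenerator(checkpoint_path='ckpts/vtdm.ckpt')   # placeholder checkpoint
white_image = prepare_white_image(Image.open('input.png'))          # background removed, padded to square
views = generator.infer(white_image)                                # list of PIL.Image frames
views[0].save('view_000.png')
```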
requirements.txt ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ av
2
+ black==23.7.0
3
+ chardet==5.1.0
4
+ clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
5
+ cupy-cuda113
6
+ einops>=0.6.1
7
+ fairscale>=0.4.13
8
+ fire>=0.5.0
9
+ fsspec>=2023.6.0
10
+ invisible-watermark>=0.2.0
11
+ kornia==0.6.9
12
+ matplotlib>=3.7.2
13
+ natsort>=8.4.0
14
+ ninja>=1.11.1
15
+ numpy==1.26.4
16
+ omegaconf>=2.3.0
17
+ open-clip-torch>=2.20.0
18
+ opencv-python==4.6.0.66
19
+ pandas>=2.0.3
20
+ pillow>=9.5.0
21
+ pudb>=2022.1.3
22
+ pytorch-lightning==1.9
23
+ pyyaml>=6.0.1
24
+ transparent_background
25
+ scipy>=1.10.1
26
+ streamlit>=0.73.1
27
+ tensorboardx==2.6
28
+ timm>=0.9.2
29
+ tokenizers==0.12.1
30
+ torch>=2.1.0
31
+ torchaudio>=2.1.0
32
+ torchdata>=0.6.1
33
+ torchmetrics>=1.0.1
34
+ torchvision>=0.16.0
35
+ tqdm>=4.65.0
36
+ transformers==4.19.1
37
+ triton>=2.0.0
38
+ urllib3<1.27,>=1.25.4
39
+ wandb>=0.15.6
40
+ webdataset>=0.2.33
41
+ wheel>=0.41.0
42
+ xformers>=0.0.20
43
+ gradio
44
+ streamlit-keyup==0.2.0
45
+ deepspeed==0.14.5
46
+ test-tube
47
+ -e git+https://github.com/Stability-AI/datapipelines.git@8bce77d147033b3a5285b6d45ee85f33866964fc#egg=sdata
48
+ basicsr
49
+ pillow-heif
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
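The `target:`/`params:` blocks used throughout `inference.yaml` are resolved by `instantiate_from_config`, re-exported here. A hedged sketch; the choice of block is illustrative only:

```python
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

cfg = OmegaConf.load('inference.yaml')
# build the sampler object described by model.params.sampler_config
sampler = instantiate_from_config(cfg.model.params.sampler_config)
```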
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dataset import StableDataModuleFromConfig
sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
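A hedged sketch of the config shape this data module expects; the keys mirror the asserts above, while the `datapipeline`/`loader` contents depend on the optional `sdata` package and are placeholders:

```python
from omegaconf import OmegaConf
from sgm.data.dataset import StableDataModuleFromConfig

train_cfg = OmegaConf.create({
    'datapipeline': {'urls': ['data/shard-{000000..000009}.tar']},  # placeholder sdata settings
    'loader': {'batch_size': 4, 'num_workers': 2},
})
datamodule = StableDataModuleFromConfig(train=train_cfg, dummy=True)
# datamodule.setup('fit') would then build the (dummy) datapipelines via sdata
```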
sgm/data/video_dataset.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import numpy as np
3
+ import torch
4
+ import PIL
5
+ import os
6
+ import random
7
+ from skimage.io import imread
8
+ import webdataset as wds
9
+ import PIL.Image as Image
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from pathlib import Path
13
+
14
+ # from ldm.base_utils import read_pickle, pose_inverse
15
+ import torchvision.transforms as transforms
16
+ import torchvision
17
+ from einops import rearrange
18
+
19
+ def add_margin(pil_img, color=0, size=256):
20
+ width, height = pil_img.size
21
+ result = Image.new(pil_img.mode, (size, size), color)
22
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
23
+ return result
24
+
25
+ def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
26
+ image_input = Image.open(image_path)
27
+
28
+ if crop_size!=-1:
29
+ alpha_np = np.asarray(image_input)[:, :, 3]
30
+ coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
31
+ min_x, min_y = np.min(coords, 0)
32
+ max_x, max_y = np.max(coords, 0)
33
+ ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
34
+ h, w = ref_img_.height, ref_img_.width
35
+ scale = crop_size / max(h, w)
36
+ h_, w_ = int(scale * h), int(scale * w)
37
+ ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
38
+ image_input = add_margin(ref_img_, size=image_size)
39
+ else:
40
+ image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
41
+ image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
42
+
43
+ image_input = np.asarray(image_input)
44
+ image_input = image_input.astype(np.float32) / 255.0
45
+ ref_mask = image_input[:, :, 3:]
46
+ image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
47
+ image_input = image_input[:, :, :3] * 2.0 - 1.0
48
+ image_input = torch.from_numpy(image_input.astype(np.float32))
49
+ elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
50
+ return {"input_image": image_input, "input_elevation": elevation_input}
51
+
52
+
53
+ class VideoTrainDataset(Dataset):
54
+ def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', width=1024, height=576, sample_frames=25):
55
+ """
56
+ Args:
57
+ base_folder (str): directory of per-object frame folders; width / height (int): output frame size.
58
+ sample_frames (int): number of frames per sampled clip; channels is fixed to 3 (RGB).
59
+ """
60
+ # Define the path to the folder containing video frames
61
+ self.base_folder = base_folder
62
+ self.folders = os.listdir(self.base_folder)
63
+ self.num_samples = len(self.folders)
64
+ self.channels = 3
65
+ self.width = width
66
+ self.height = height
67
+ self.sample_frames = sample_frames
68
+ self.elevations = [-10, 0, 10, 20, 30, 40]
69
+
70
+ def __len__(self):
71
+ return self.num_samples
72
+
73
+ def load_im(self, path):
74
+ img = imread(path)
75
+ img = img.astype(np.float32) / 255.0
76
+ mask = img[:,:,3:]
77
+ img[:,:,:3] = img[:,:,:3] * mask + 1 - mask # white background
78
+ img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
79
+ return img, mask
80
+
81
+ def __getitem__(self, idx):
82
+ """
83
+ Args:
84
+ idx (int): Index of the sample to return.
85
+
86
+ Returns:
87
+ dict: the 'video' tensor of shape (channels, sample_frames, height, width) plus 'elevation', 'caption', 'fps_id' and 'motion_bucket_id'.
88
+ """
89
+ # Randomly select a folder (representing a video) from the base folder
90
+ chosen_folder = random.choice(self.folders)
91
+ folder_path = os.path.join(self.base_folder, chosen_folder)
92
+ frames = os.listdir(folder_path)
93
+ # Sort the frames by name
94
+ frames.sort()
95
+
96
+ # Ensure the selected folder has at least `sample_frames` frames
97
+ if len(frames) < self.sample_frames:
98
+ raise ValueError(
99
+ f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")
100
+
101
+ # Randomly select a start index for frame sequence. Fixed elevation
102
+ start_idx = random.randint(0, len(frames) - 1)
103
+ range_id = int(start_idx / 16) # 0, 1, 2, 3, 4, 5
104
+ elevation = self.elevations[range_id]
105
+ selected_frames = []
106
+
107
+ for frame_idx in range(start_idx, (range_id + 1) * 16):
108
+ selected_frames.append(frames[frame_idx])
109
+ for frame_idx in range((range_id) * 16, start_idx):
110
+ selected_frames.append(frames[frame_idx])
111
+
112
+ # Initialize a tensor to store the pixel values
113
+ pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
114
+
115
+ # Load and process each frame
116
+ for i, frame_name in enumerate(selected_frames):
117
+ frame_path = os.path.join(folder_path, frame_name)
118
+ img, mask = self.load_im(frame_path)
119
+ # Resize the image and convert it to a tensor
120
+ img_resized = img.resize((self.width, self.height))
121
+ img_tensor = torch.from_numpy(np.array(img_resized)).float()
122
+
123
+ # Normalize the image by scaling pixel values to [-1, 1]
124
+ img_normalized = img_tensor / 127.5 - 1
125
+
126
+ # Rearrange channels if necessary
127
+ if self.channels == 3:
128
+ img_normalized = img_normalized.permute(
129
+ 2, 0, 1) # For RGB images
130
+ elif self.channels == 1:
131
+ img_normalized = img_normalized.mean(
132
+ dim=2, keepdim=True) # For grayscale images
133
+
134
+ pixel_values[i] = img_normalized
135
+
136
+ pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
137
+
138
+ caption = chosen_folder + "_" + str(start_idx)
139
+
140
+ return {'video': pixel_values, 'elevation': elevation, 'caption': caption, "fps_id": 7, "motion_bucket_id": 127}
141
+
142
+ class SyncDreamerEvalData(Dataset):
143
+ def __init__(self, image_dir):
144
+ self.image_size = 512
145
+ self.image_dir = Path(image_dir)
146
+ self.crop_size = 20
147
+
148
+ self.fns = []
149
+ for fn in Path(image_dir).iterdir():
150
+ if fn.suffix=='.png':
151
+ self.fns.append(fn)
152
+ print('============= length of dataset %d =============' % len(self.fns))
153
+
154
+ def __len__(self):
155
+ return len(self.fns)
156
+
157
+ def get_data_for_index(self, index):
158
+ input_img_fn = self.fns[index]
159
+ elevation = 0
160
+ return prepare_inputs(input_img_fn, elevation, 512)
161
+
162
+ def __getitem__(self, index):
163
+ return self.get_data_for_index(index)
164
+
165
+ class VideoDataset(pl.LightningDataModule):
166
+ def __init__(self, base_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
167
+ super().__init__()
168
+ self.base_folder = base_folder
169
+ self.eval_folder = eval_folder
170
+ self.width = width
171
+ self.height = height
172
+ self.sample_frames = sample_frames
173
+ self.batch_size = batch_size
174
+ self.num_workers = num_workers
175
+ self.seed = seed
176
+ self.additional_args = kwargs
177
+
178
+ def setup(self):
179
+ self.train_dataset = VideoTrainDataset(self.base_folder, self.width, self.height, self.sample_frames)
180
+ self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)
181
+
182
+ def train_dataloader(self):
183
+ sampler = DistributedSampler(self.train_dataset, seed=self.seed)
184
+ return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
185
+
186
+ def val_dataloader(self):
187
+ loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
188
+ return loader
189
+
190
+ def test_dataloader(self):
191
+ return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
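A hedged sketch of pulling a single training sample from the dataset above; folder paths are placeholders (one sub-folder of at least 25 RGBA frames per object), importing `sgm.data` requires the `sdata` dependency noted in `dataset.py`, and the dataloaders additionally need `torch.distributed` to be initialised because of the `DistributedSampler`:

```python
from sgm.data.video_dataset import VideoDataset

dm = VideoDataset(base_folder='/data/objaverse/images',  # placeholder paths
                  eval_folder='/data/eval_pngs',
                  width=512, height=512, sample_frames=25, batch_size=1)
dm.setup()
sample = dm.train_dataset[0]
print(sample['video'].shape, sample['elevation'], sample['caption'])
# torch.Size([3, 25, 512, 512])  <elevation in degrees>  <folder>_<start_idx>
```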
sgm/data/video_dataset_stage2_degradeImages.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import numpy as np
3
+ import torch
4
+ import PIL
5
+ import os
6
+ import random
7
+ from skimage.io import imread
8
+ import webdataset as wds
9
+ import PIL.Image as Image
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from pathlib import Path
13
+
14
+ # from ldm.base_utils import read_pickle, pose_inverse
15
+ import torchvision.transforms as transforms
16
+ import torchvision
17
+ from einops import rearrange
18
+
19
+ # for the degraded images
20
+ import yaml
21
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
22
+ import math
23
+
24
+ def add_margin(pil_img, color=0, size=256):
25
+ width, height = pil_img.size
26
+ result = Image.new(pil_img.mode, (size, size), color)
27
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
28
+ return result
29
+
30
+ def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
31
+ image_input = Image.open(image_path)
32
+
33
+ if crop_size!=-1:
34
+ alpha_np = np.asarray(image_input)[:, :, 3]
35
+ coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
36
+ min_x, min_y = np.min(coords, 0)
37
+ max_x, max_y = np.max(coords, 0)
38
+ ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
39
+ h, w = ref_img_.height, ref_img_.width
40
+ scale = crop_size / max(h, w)
41
+ h_, w_ = int(scale * h), int(scale * w)
42
+ ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
43
+ image_input = add_margin(ref_img_, size=image_size)
44
+ else:
45
+ image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
46
+ image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
47
+
48
+ image_input = np.asarray(image_input)
49
+ image_input = image_input.astype(np.float32) / 255.0
50
+ ref_mask = image_input[:, :, 3:]
51
+ image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
52
+ image_input = image_input[:, :, :3] * 2.0 - 1.0
53
+ image_input = torch.from_numpy(image_input.astype(np.float32))
54
+ elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
55
+ return {"input_image": image_input, "input_elevation": elevation_input}
56
+
57
+
58
+ class VideoTrainDataset(Dataset):
59
+ def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', depth_folder="/mnt/drive2/3d/OBJAVERSE-DEPTH/depth256", width=1024, height=576, sample_frames=25):
60
+ """
61
+ Args:
62
+ base_folder / depth_folder (str): directories of frame and depth folders; width / height (int): output frame size.
63
+ sample_frames (int): number of frames per sampled clip; channels is fixed to 3 (RGB).
64
+ """
65
+ # Define the path to the folder containing video frames
66
+ self.base_folder = base_folder
67
+ self.depth_folder = depth_folder
68
+ # self.folders1 = os.listdir(self.base_folder)
69
+ # self.folders2 = os.listdir(self.depth_folder)
70
+ # self.folders = list(set(self.folders1).intersection(set(self.folders2)))
71
+ self.folders = os.listdir(self.base_folder)
72
+ self.num_samples = len(self.folders)
73
+ self.channels = 3
74
+ self.width = width
75
+ self.height = height
76
+ self.sample_frames = sample_frames
77
+ self.elevations = [-10, 0, 10, 20, 30, 40]
78
+
79
+ # for degraded images
80
+ with open('configs/train_realesrnet_x4plus.yml', mode='r') as f:
81
+ opt = yaml.load(f, Loader=yaml.FullLoader)
82
+ self.opt = opt
83
+ # blur settings for the first degradation
84
+ self.blur_kernel_size = opt['blur_kernel_size']
85
+ self.kernel_list = opt['kernel_list']
86
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
87
+ self.blur_sigma = opt['blur_sigma']
88
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
89
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
90
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
91
+
92
+ # blur settings for the second degradation
93
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
94
+ self.kernel_list2 = opt['kernel_list2']
95
+ self.kernel_prob2 = opt['kernel_prob2']
96
+ self.blur_sigma2 = opt['blur_sigma2']
97
+ self.betag_range2 = opt['betag_range2']
98
+ self.betap_range2 = opt['betap_range2']
99
+ self.sinc_prob2 = opt['sinc_prob2']
100
+
101
+ # a final sinc filter
102
+ self.final_sinc_prob = opt['final_sinc_prob']
103
+
104
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
105
+ # TODO: kernel range is now hard-coded, should be in the configure file
106
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
107
+ self.pulse_tensor[10, 10] = 1
108
+
109
+
110
+ def __len__(self):
111
+ return self.num_samples
112
+
113
+ def load_im(self, path):
114
+ img = imread(path)
115
+ img = img.astype(np.float32) / 255.0
116
+ mask = img[:,:,3:]
117
+ img[:,:,:3] = img[:,:,:3] * mask + 1 - mask # white background
118
+ img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
119
+ return img, mask
120
+
121
+ def __getitem__(self, idx):
122
+ """
123
+ Args:
124
+ idx (int): Index of the sample to return.
125
+
126
+ Returns:
127
+ dict: the 'video' tensor of shape (channels, sample_frames, height, width) plus 'masks', 'elevation', 'caption' and the degradation kernels.
128
+ """
129
+ # Randomly select a folder (representing a video) from the base folder
130
+ chosen_folder = random.choice(self.folders)
131
+ folder_path = os.path.join(self.base_folder, chosen_folder)
132
+ frames = os.listdir(folder_path)
133
+ # Sort the frames by name
134
+ frames.sort()
135
+
136
+ # Ensure the selected folder has at least `sample_frames` frames
137
+ if len(frames) < self.sample_frames:
138
+ raise ValueError(
139
+ f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")
140
+
141
+ # Randomly select a start index for frame sequence. Fixed elevation
142
+ start_idx = random.randint(0, len(frames) - 1)
143
+ # start_idx = random.choice([0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92])
144
+ range_id = int(start_idx / 16) # 0, 1, 2, 3, 4, 5
145
+ elevation = self.elevations[range_id]
146
+ selected_frames = []
147
+
148
+ for frame_idx in range(start_idx, (range_id + 1) * 16):
149
+ selected_frames.append(frames[frame_idx])
150
+ for frame_idx in range((range_id) * 16, start_idx):
151
+ selected_frames.append(frames[frame_idx])
152
+
153
+ # Initialize a tensor to store the pixel values
154
+ pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
155
+ masks = []
156
+
157
+ # Load and process each frame
158
+ for i, frame_name in enumerate(selected_frames):
159
+ frame_path = os.path.join(folder_path, frame_name)
160
+ img, mask = self.load_im(frame_path)
161
+ mask = mask.squeeze(-1)
162
+ masks.append(mask)
163
+ # Resize the image and convert it to a tensor
164
+ img_resized = img.resize((self.width, self.height))
165
+ img_tensor = torch.from_numpy(np.array(img_resized)).float()
166
+
167
+ # Normalize the image by scaling pixel values to [-1, 1]
168
+ img_normalized = img_tensor / 127.5 - 1
169
+
170
+ # Rearrange channels if necessary
171
+ if self.channels == 3:
172
+ img_normalized = img_normalized.permute(
173
+ 2, 0, 1) # For RGB images
174
+ elif self.channels == 1:
175
+ img_normalized = img_normalized.mean(
176
+ dim=2, keepdim=True) # For grayscale images
177
+
178
+ pixel_values[i] = img_normalized
179
+
180
+ pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
181
+ masks = torch.from_numpy(np.array(masks))
182
+ caption = chosen_folder
183
+
184
+ # Generate the kernels used to degrade the images (first/second blur kernels and a final sinc filter)
185
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
186
+ kernels = []
187
+ kernel2s = []
188
+ sinc_kernels = []
189
+ for i in range(16):
190
+ kernel_size = random.choice(self.kernel_range)
191
+ if np.random.uniform() < self.opt['sinc_prob']:
192
+ # this sinc filter setting is for kernels ranging from [7, 21]
193
+ if kernel_size < 13:
194
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
195
+ else:
196
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
197
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
198
+ else:
199
+ kernel = random_mixed_kernels(
200
+ self.kernel_list,
201
+ self.kernel_prob,
202
+ kernel_size,
203
+ self.blur_sigma,
204
+ self.blur_sigma, [-math.pi, math.pi],
205
+ self.betag_range,
206
+ self.betap_range,
207
+ noise_range=None)
208
+ # pad kernel
209
+ pad_size = (21 - kernel_size) // 2
210
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
211
+ kernels.append(kernel)
212
+
213
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
214
+ kernel_size = random.choice(self.kernel_range)
215
+ if np.random.uniform() < self.opt['sinc_prob2']:
216
+ if kernel_size < 13:
217
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
218
+ else:
219
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
220
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
221
+ else:
222
+ kernel2 = random_mixed_kernels(
223
+ self.kernel_list2,
224
+ self.kernel_prob2,
225
+ kernel_size,
226
+ self.blur_sigma2,
227
+ self.blur_sigma2, [-math.pi, math.pi],
228
+ self.betag_range2,
229
+ self.betap_range2,
230
+ noise_range=None)
231
+
232
+ # pad kernel
233
+ pad_size = (21 - kernel_size) // 2
234
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
235
+ kernel2s.append(kernel2)
236
+
237
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
238
+ if np.random.uniform() < self.opt['final_sinc_prob']:
239
+ kernel_size = random.choice(self.kernel_range)
240
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
241
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
242
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
243
+ else:
244
+ sinc_kernel = self.pulse_tensor
245
+ sinc_kernels.append(sinc_kernel)
246
+ kernels = np.array(kernels)
247
+ kernel2s = np.array(kernel2s)
248
+ sinc_kernels = torch.stack(sinc_kernels, 0)
249
+ kernels = torch.FloatTensor(kernels)
250
+ kernel2s = torch.FloatTensor(kernel2s)
251
+ return {'video': pixel_values, 'masks': masks, 'elevation': elevation, 'caption': caption, 'kernel1s': kernels, 'kernel2s': kernel2s, 'sinc_kernels': sinc_kernels} # (16, 3, 512, 512)-> (3, 16, 512, 512)
252
+
253
+ class SyncDreamerEvalData(Dataset):
254
+ def __init__(self, image_dir):
255
+ self.image_size = 512
256
+ self.image_dir = Path(image_dir)
257
+ self.crop_size = 20
258
+
259
+ self.fns = []
260
+ for fn in Path(image_dir).iterdir():
261
+ if fn.suffix=='.png':
262
+ self.fns.append(fn)
263
+ print('============= length of dataset %d =============' % len(self.fns))
264
+
265
+ def __len__(self):
266
+ return len(self.fns)
267
+
268
+ def get_data_for_index(self, index):
269
+ input_img_fn = self.fns[index]
270
+ elevation = 0
271
+ return prepare_inputs(input_img_fn, elevation, 512)
272
+
273
+ def __getitem__(self, index):
274
+ return self.get_data_for_index(index)
275
+
276
+ class VideoDataset(pl.LightningDataModule):
277
+ def __init__(self, base_folder, depth_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
278
+ super().__init__()
279
+ self.base_folder = base_folder
280
+ self.depth_folder = depth_folder
281
+ self.eval_folder = eval_folder
282
+ self.width = width
283
+ self.height = height
284
+ self.sample_frames = sample_frames
285
+ self.batch_size = batch_size
286
+ self.num_workers = num_workers
287
+ self.seed = seed
288
+ self.additional_args = kwargs
289
+
290
+ def setup(self):
291
+ self.train_dataset = VideoTrainDataset(self.base_folder, self.depth_folder, self.width, self.height, self.sample_frames)
292
+ self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)
293
+
294
+ def train_dataloader(self):
295
+ sampler = DistributedSampler(self.train_dataset, seed=self.seed)
296
+ return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
297
+
298
+ def val_dataloader(self):
299
+ loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
300
+ return loader
301
+
302
+ def test_dataloader(self):
303
+ return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
sgm/inference/api.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from dataclasses import asdict, dataclass
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from omegaconf import OmegaConf
7
+
8
+ from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
9
+ do_sample)
10
+ from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
11
+ DPMPP2SAncestralSampler,
12
+ EulerAncestralSampler,
13
+ EulerEDMSampler,
14
+ HeunEDMSampler,
15
+ LinearMultistepSampler)
16
+ from sgm.util import load_model_from_config
17
+
18
+
19
+ class ModelArchitecture(str, Enum):
20
+ SD_2_1 = "stable-diffusion-v2-1"
21
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
22
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
23
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
24
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
25
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
26
+
27
+
28
+ class Sampler(str, Enum):
29
+ EULER_EDM = "EulerEDMSampler"
30
+ HEUN_EDM = "HeunEDMSampler"
31
+ EULER_ANCESTRAL = "EulerAncestralSampler"
32
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
33
+ DPMPP2M = "DPMPP2MSampler"
34
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
35
+
36
+
37
+ class Discretization(str, Enum):
38
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
39
+ EDM = "EDMDiscretization"
40
+
41
+
42
+ class Guider(str, Enum):
43
+ VANILLA = "VanillaCFG"
44
+ IDENTITY = "IdentityGuider"
45
+
46
+
47
+ class Thresholder(str, Enum):
48
+ NONE = "None"
49
+
50
+
51
+ @dataclass
52
+ class SamplingParams:
53
+ width: int = 1024
54
+ height: int = 1024
55
+ steps: int = 50
56
+ sampler: Sampler = Sampler.DPMPP2M
57
+ discretization: Discretization = Discretization.LEGACY_DDPM
58
+ guider: Guider = Guider.VANILLA
59
+ thresholder: Thresholder = Thresholder.NONE
60
+ scale: float = 6.0
61
+ aesthetic_score: float = 5.0
62
+ negative_aesthetic_score: float = 5.0
63
+ img2img_strength: float = 1.0
64
+ orig_width: int = 1024
65
+ orig_height: int = 1024
66
+ crop_coords_top: int = 0
67
+ crop_coords_left: int = 0
68
+ sigma_min: float = 0.0292
69
+ sigma_max: float = 14.6146
70
+ rho: float = 3.0
71
+ s_churn: float = 0.0
72
+ s_tmin: float = 0.0
73
+ s_tmax: float = 999.0
74
+ s_noise: float = 1.0
75
+ eta: float = 1.0
76
+ order: int = 4
77
+
78
+
79
+ @dataclass
80
+ class SamplingSpec:
81
+ width: int
82
+ height: int
83
+ channels: int
84
+ factor: int
85
+ is_legacy: bool
86
+ config: str
87
+ ckpt: str
88
+ is_guided: bool
89
+
90
+
91
+ model_specs = {
92
+ ModelArchitecture.SD_2_1: SamplingSpec(
93
+ height=512,
94
+ width=512,
95
+ channels=4,
96
+ factor=8,
97
+ is_legacy=True,
98
+ config="sd_2_1.yaml",
99
+ ckpt="v2-1_512-ema-pruned.safetensors",
100
+ is_guided=True,
101
+ ),
102
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
103
+ height=768,
104
+ width=768,
105
+ channels=4,
106
+ factor=8,
107
+ is_legacy=True,
108
+ config="sd_2_1_768.yaml",
109
+ ckpt="v2-1_768-ema-pruned.safetensors",
110
+ is_guided=True,
111
+ ),
112
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
113
+ height=1024,
114
+ width=1024,
115
+ channels=4,
116
+ factor=8,
117
+ is_legacy=False,
118
+ config="sd_xl_base.yaml",
119
+ ckpt="sd_xl_base_0.9.safetensors",
120
+ is_guided=True,
121
+ ),
122
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
123
+ height=1024,
124
+ width=1024,
125
+ channels=4,
126
+ factor=8,
127
+ is_legacy=True,
128
+ config="sd_xl_refiner.yaml",
129
+ ckpt="sd_xl_refiner_0.9.safetensors",
130
+ is_guided=True,
131
+ ),
132
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
133
+ height=1024,
134
+ width=1024,
135
+ channels=4,
136
+ factor=8,
137
+ is_legacy=False,
138
+ config="sd_xl_base.yaml",
139
+ ckpt="sd_xl_base_1.0.safetensors",
140
+ is_guided=True,
141
+ ),
142
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
143
+ height=1024,
144
+ width=1024,
145
+ channels=4,
146
+ factor=8,
147
+ is_legacy=True,
148
+ config="sd_xl_refiner.yaml",
149
+ ckpt="sd_xl_refiner_1.0.safetensors",
150
+ is_guided=True,
151
+ ),
152
+ }
153
+
154
+
155
+ class SamplingPipeline:
156
+ def __init__(
157
+ self,
158
+ model_id: ModelArchitecture,
159
+ model_path="checkpoints",
160
+ config_path="configs/inference",
161
+ device="cuda",
162
+ use_fp16=True,
163
+ ) -> None:
164
+ if model_id not in model_specs:
165
+ raise ValueError(f"Model {model_id} not supported")
166
+ self.model_id = model_id
167
+ self.specs = model_specs[self.model_id]
168
+ self.config = str(pathlib.Path(config_path, self.specs.config))
169
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
170
+ self.device = device
171
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
172
+
173
+ def _load_model(self, device="cuda", use_fp16=True):
174
+ config = OmegaConf.load(self.config)
175
+ model = load_model_from_config(config, self.ckpt)
176
+ if model is None:
177
+ raise ValueError(f"Model {self.model_id} could not be loaded")
178
+ model.to(device)
179
+ if use_fp16:
180
+ model.conditioner.half()
181
+ model.model.half()
182
+ return model
183
+
184
+ def text_to_image(
185
+ self,
186
+ params: SamplingParams,
187
+ prompt: str,
188
+ negative_prompt: str = "",
189
+ samples: int = 1,
190
+ return_latents: bool = False,
191
+ ):
192
+ sampler = get_sampler_config(params)
193
+ value_dict = asdict(params)
194
+ value_dict["prompt"] = prompt
195
+ value_dict["negative_prompt"] = negative_prompt
196
+ value_dict["target_width"] = params.width
197
+ value_dict["target_height"] = params.height
198
+ return do_sample(
199
+ self.model,
200
+ sampler,
201
+ value_dict,
202
+ samples,
203
+ params.height,
204
+ params.width,
205
+ self.specs.channels,
206
+ self.specs.factor,
207
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
208
+ return_latents=return_latents,
209
+ filter=None,
210
+ )
211
+
212
+ def image_to_image(
213
+ self,
214
+ params: SamplingParams,
215
+ image,
216
+ prompt: str,
217
+ negative_prompt: str = "",
218
+ samples: int = 1,
219
+ return_latents: bool = False,
220
+ ):
221
+ sampler = get_sampler_config(params)
222
+
223
+ if params.img2img_strength < 1.0:
224
+ sampler.discretization = Img2ImgDiscretizationWrapper(
225
+ sampler.discretization,
226
+ strength=params.img2img_strength,
227
+ )
228
+ height, width = image.shape[2], image.shape[3]
229
+ value_dict = asdict(params)
230
+ value_dict["prompt"] = prompt
231
+ value_dict["negative_prompt"] = negative_prompt
232
+ value_dict["target_width"] = width
233
+ value_dict["target_height"] = height
234
+ return do_img2img(
235
+ image,
236
+ self.model,
237
+ sampler,
238
+ value_dict,
239
+ samples,
240
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
241
+ return_latents=return_latents,
242
+ filter=None,
243
+ )
244
+
245
+ def refiner(
246
+ self,
247
+ params: SamplingParams,
248
+ image,
249
+ prompt: str,
250
+ negative_prompt: Optional[str] = None,
251
+ samples: int = 1,
252
+ return_latents: bool = False,
253
+ ):
254
+ sampler = get_sampler_config(params)
255
+ value_dict = {
256
+ "orig_width": image.shape[3] * 8,
257
+ "orig_height": image.shape[2] * 8,
258
+ "target_width": image.shape[3] * 8,
259
+ "target_height": image.shape[2] * 8,
260
+ "prompt": prompt,
261
+ "negative_prompt": negative_prompt,
262
+ "crop_coords_top": 0,
263
+ "crop_coords_left": 0,
264
+ "aesthetic_score": 6.0,
265
+ "negative_aesthetic_score": 2.5,
266
+ }
267
+
268
+ return do_img2img(
269
+ image,
270
+ self.model,
271
+ sampler,
272
+ value_dict,
273
+ samples,
274
+ skip_encode=True,
275
+ return_latents=return_latents,
276
+ filter=None,
277
+ )
278
+
279
+
280
+ def get_guider_config(params: SamplingParams):
281
+ if params.guider == Guider.IDENTITY:
282
+ guider_config = {
283
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
284
+ }
285
+ elif params.guider == Guider.VANILLA:
286
+ scale = params.scale
287
+
288
+ thresholder = params.thresholder
289
+
290
+ if thresholder == Thresholder.NONE:
291
+ dyn_thresh_config = {
292
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
293
+ }
294
+ else:
295
+ raise NotImplementedError
296
+
297
+ guider_config = {
298
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
299
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
300
+ }
301
+ else:
302
+ raise NotImplementedError
303
+ return guider_config
304
+
305
+
306
+ def get_discretization_config(params: SamplingParams):
307
+ if params.discretization == Discretization.LEGACY_DDPM:
308
+ discretization_config = {
309
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
310
+ }
311
+ elif params.discretization == Discretization.EDM:
312
+ discretization_config = {
313
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
314
+ "params": {
315
+ "sigma_min": params.sigma_min,
316
+ "sigma_max": params.sigma_max,
317
+ "rho": params.rho,
318
+ },
319
+ }
320
+ else:
321
+ raise ValueError(f"unknown discretization {params.discretization}")
322
+ return discretization_config
323
+
324
+
325
+ def get_sampler_config(params: SamplingParams):
326
+ discretization_config = get_discretization_config(params)
327
+ guider_config = get_guider_config(params)
329
+ if params.sampler == Sampler.EULER_EDM:
330
+ return EulerEDMSampler(
331
+ num_steps=params.steps,
332
+ discretization_config=discretization_config,
333
+ guider_config=guider_config,
334
+ s_churn=params.s_churn,
335
+ s_tmin=params.s_tmin,
336
+ s_tmax=params.s_tmax,
337
+ s_noise=params.s_noise,
338
+ verbose=True,
339
+ )
340
+ if params.sampler == Sampler.HEUN_EDM:
341
+ return HeunEDMSampler(
342
+ num_steps=params.steps,
343
+ discretization_config=discretization_config,
344
+ guider_config=guider_config,
345
+ s_churn=params.s_churn,
346
+ s_tmin=params.s_tmin,
347
+ s_tmax=params.s_tmax,
348
+ s_noise=params.s_noise,
349
+ verbose=True,
350
+ )
351
+ if params.sampler == Sampler.EULER_ANCESTRAL:
352
+ return EulerAncestralSampler(
353
+ num_steps=params.steps,
354
+ discretization_config=discretization_config,
355
+ guider_config=guider_config,
356
+ eta=params.eta,
357
+ s_noise=params.s_noise,
358
+ verbose=True,
359
+ )
360
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
361
+ return DPMPP2SAncestralSampler(
362
+ num_steps=params.steps,
363
+ discretization_config=discretization_config,
364
+ guider_config=guider_config,
365
+ eta=params.eta,
366
+ s_noise=params.s_noise,
367
+ verbose=True,
368
+ )
369
+ if params.sampler == Sampler.DPMPP2M:
370
+ return DPMPP2MSampler(
371
+ num_steps=params.steps,
372
+ discretization_config=discretization_config,
373
+ guider_config=guider_config,
374
+ verbose=True,
375
+ )
376
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
377
+ return LinearMultistepSampler(
378
+ num_steps=params.steps,
379
+ discretization_config=discretization_config,
380
+ guider_config=guider_config,
381
+ order=params.order,
382
+ verbose=True,
383
+ )
384
+
385
+ raise ValueError(f"unknown sampler {params.sampler}!")
sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
1
+ import math
2
+ import os
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from imwatermark import WatermarkEncoder
9
+ from omegaconf import ListConfig
10
+ from PIL import Image
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, RGB, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ squeeze = len(image.shape) == 4
34
+ if squeeze:
35
+ image = image[None, ...]
36
+ n = image.shape[0]
37
+ image_np = rearrange(
38
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
+ ).numpy()[:, :, :, ::-1]
40
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
+ # the watermarking library expects input in cv2 BGR format
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was chosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ Wraps a discretization and prunes the sigmas according to the img2img strength.
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start at the largest value and decrease monotonically
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ prune_index = max(int(self.strength * len(sigmas)), 1)
+ sigmas = sigmas[:prune_index]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print(f"sigmas after pruning: ", sigmas)
98
+ return sigmas
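A small self-contained check of the pruning arithmetic above; the 40-sigma linear schedule is a stand-in for illustration, not what the discretizations in this repository actually return.

import torch

sigmas = torch.linspace(14.6146, 0.0292, steps=40)  # stand-in schedule, largest sigma first
strength = 0.75
keep = max(int(strength * len(sigmas)), 1)           # 30 of the 40 sigmas survive
pruned = torch.flip(torch.flip(sigmas, (0,))[:keep], (0,))  # same flip-slice-flip as __call__
assert len(pruned) == keep
assert pruned[0] < sigmas[0]  # img2img starts below the maximum noise level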
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if k != "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
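To close the loop, a minimal img2img sketch combining the helpers in this file with the SamplingPipeline from sgm/inference/api.py. It assumes a local input.png, the default SDXL checkpoint paths, and that SamplingParams exposes the img2img_strength field that image_to_image reads; the prompt is illustrative.

from PIL import Image

from sgm.inference.api import ModelArchitecture, SamplingParams, SamplingPipeline
from sgm.inference.helpers import get_input_image_tensor

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)

# Resizes to a multiple of 64, rescales to [-1, 1], and moves the tensor to the default device.
image = get_input_image_tensor(Image.open("input.png"))

# strength=0.75 keeps roughly three quarters of the sigma schedule (see Img2ImgDiscretizationWrapper).
params = SamplingParams(img2img_strength=0.75)

samples = pipeline.image_to_image(
    params,
    image,
    prompt="the same scene as an oil painting",
)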