Spaces:
Running
on
Zero
Running
on
Zero
aki-0421
commited on
Commit
·
a3a3ae4
unverified
·
0
Parent(s):
F: add
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +14 -0
- annotator/base_annotator.py +57 -0
- annotator/canny.py +63 -0
- annotator/color.py +59 -0
- annotator/hed.py +155 -0
- annotator/identity.py +25 -0
- annotator/invert.py +25 -0
- annotator/midas/__init__.py +0 -0
- annotator/midas/api.py +165 -0
- annotator/midas/base_model.py +17 -0
- annotator/midas/blocks.py +390 -0
- annotator/midas/dpt_depth.py +106 -0
- annotator/midas/midas_net.py +79 -0
- annotator/midas/midas_net_custom.py +166 -0
- annotator/midas/transforms.py +230 -0
- annotator/midas/utils.py +192 -0
- annotator/midas/vit.py +509 -0
- annotator/midas_op.py +79 -0
- annotator/mlsd/__init__.py +0 -0
- annotator/mlsd/mbv2_mlsd_large.py +303 -0
- annotator/mlsd/mbv2_mlsd_tiny.py +287 -0
- annotator/mlsd/utils.py +638 -0
- annotator/mlsd_op.py +74 -0
- annotator/openpose.py +812 -0
- annotator/registry.py +30 -0
- annotator/utils.py +114 -0
- app.py +32 -0
- dataset/.gitignore +1 -0
- dataset/opencv_transforms/__init__.py +0 -0
- dataset/opencv_transforms/functional.py +598 -0
- dataset/opencv_transforms/transforms.py +1044 -0
- dataset/setup.py +23 -0
- dataset/tests/compare_to_pil_for_testing.ipynb +241 -0
- dataset/tests/setup_testing_directory.py +50 -0
- dataset/tests/test_color.py +68 -0
- dataset/tests/test_spatial.py +52 -0
- dataset/tests/utils.py +8 -0
- inference.yaml +166 -0
- packages.txt +2 -0
- pipeline.py +168 -0
- requirements.txt +49 -0
- sgm/__init__.py +4 -0
- sgm/data/__init__.py +1 -0
- sgm/data/dataset.py +80 -0
- sgm/data/video_dataset.py +191 -0
- sgm/data/video_dataset_stage2_degradeImages.py +303 -0
- sgm/inference/api.py +385 -0
- sgm/inference/helpers.py +305 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.idea
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Character 360
|
3 |
+
emoji: 🏆
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.9.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: unknown
|
11 |
+
short_description: Would you like to see your character in 360°?
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
annotator/base_annotator.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from abc import ABCMeta
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
9 |
+
from scepter.modules.model.base_model import BaseModel
|
10 |
+
from scepter.modules.utils.config import dict_to_yaml
|
11 |
+
|
12 |
+
|
13 |
+
@ANNOTATORS.register_class()
|
14 |
+
class BaseAnnotator(BaseModel, metaclass=ABCMeta):
|
15 |
+
para_dict = {}
|
16 |
+
|
17 |
+
def __init__(self, cfg, logger=None):
|
18 |
+
super().__init__(cfg, logger=logger)
|
19 |
+
|
20 |
+
@torch.no_grad()
|
21 |
+
@torch.inference_mode
|
22 |
+
def forward(self, *args, **kwargs):
|
23 |
+
raise NotImplementedError
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def get_config_template():
|
27 |
+
return dict_to_yaml('ANNOTATORS',
|
28 |
+
__class__.__name__,
|
29 |
+
BaseAnnotator.para_dict,
|
30 |
+
set_name=True)
|
31 |
+
|
32 |
+
|
33 |
+
@ANNOTATORS.register_class()
|
34 |
+
class GeneralAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
35 |
+
def __init__(self, cfg, logger=None):
|
36 |
+
super().__init__(cfg, logger=logger)
|
37 |
+
anno_models = cfg.get('ANNOTATORS', [])
|
38 |
+
self.annotators = nn.ModuleList()
|
39 |
+
for n, anno_config in enumerate(anno_models):
|
40 |
+
annotator = ANNOTATORS.build(anno_config, logger=logger)
|
41 |
+
annotator.input_keys = anno_config.get('INPUT_KEYS', [])
|
42 |
+
if isinstance(annotator.input_keys, str):
|
43 |
+
annotator.input_keys = [annotator.input_keys]
|
44 |
+
annotator.output_keys = anno_config.get('OUTPUT_KEYS', [])
|
45 |
+
if isinstance(annotator.output_keys, str):
|
46 |
+
annotator.output_keys = [annotator.output_keys]
|
47 |
+
assert len(annotator.input_keys) == len(annotator.output_keys)
|
48 |
+
self.annotators.append(annotator)
|
49 |
+
|
50 |
+
def forward(self, input_dict):
|
51 |
+
output_dict = {}
|
52 |
+
for annotator in self.annotators:
|
53 |
+
for idx, in_key in enumerate(annotator.input_keys):
|
54 |
+
if in_key in input_dict:
|
55 |
+
image = annotator(input_dict[in_key])
|
56 |
+
output_dict[annotator.output_keys[idx]] = image
|
57 |
+
return output_dict
|
annotator/canny.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from abc import ABCMeta
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
11 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
12 |
+
from scepter.modules.utils.config import dict_to_yaml
|
13 |
+
|
14 |
+
|
15 |
+
@ANNOTATORS.register_class()
|
16 |
+
class CannyAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
17 |
+
para_dict = {}
|
18 |
+
|
19 |
+
def __init__(self, cfg, logger=None):
|
20 |
+
super().__init__(cfg, logger=logger)
|
21 |
+
self.low_threshold = cfg.get('LOW_THRESHOLD', 100)
|
22 |
+
self.high_threshold = cfg.get('HIGH_THRESHOLD', 200)
|
23 |
+
self.random_cfg = cfg.get('RANDOM_CFG', None)
|
24 |
+
|
25 |
+
def forward(self, image):
|
26 |
+
if isinstance(image, Image.Image):
|
27 |
+
image = np.array(image)
|
28 |
+
elif isinstance(image, torch.Tensor):
|
29 |
+
image = image.detach().cpu().numpy()
|
30 |
+
elif isinstance(image, np.ndarray):
|
31 |
+
image = image.copy()
|
32 |
+
else:
|
33 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
34 |
+
assert len(image.shape) < 4
|
35 |
+
|
36 |
+
if self.random_cfg is None:
|
37 |
+
image = cv2.Canny(image, self.low_threshold, self.high_threshold)
|
38 |
+
else:
|
39 |
+
proba = self.random_cfg.get('PROBA', 1.0)
|
40 |
+
if np.random.random() < proba:
|
41 |
+
min_low_threshold = self.random_cfg.get(
|
42 |
+
'MIN_LOW_THRESHOLD', 50)
|
43 |
+
max_low_threshold = self.random_cfg.get(
|
44 |
+
'MAX_LOW_THRESHOLD', 100)
|
45 |
+
min_high_threshold = self.random_cfg.get(
|
46 |
+
'MIN_HIGH_THRESHOLD', 200)
|
47 |
+
max_high_threshold = self.random_cfg.get(
|
48 |
+
'MAX_HIGH_THRESHOLD', 350)
|
49 |
+
low_th = np.random.randint(min_low_threshold,
|
50 |
+
max_low_threshold)
|
51 |
+
high_th = np.random.randint(min_high_threshold,
|
52 |
+
max_high_threshold)
|
53 |
+
else:
|
54 |
+
low_th, high_th = self.low_threshold, self.high_threshold
|
55 |
+
image = cv2.Canny(image, low_th, high_th)
|
56 |
+
return image[..., None].repeat(3, 2)
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def get_config_template():
|
60 |
+
return dict_to_yaml('ANNOTATORS',
|
61 |
+
__class__.__name__,
|
62 |
+
CannyAnnotator.para_dict,
|
63 |
+
set_name=True)
|
annotator/color.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from abc import ABCMeta
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
11 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
12 |
+
from scepter.modules.utils.config import dict_to_yaml
|
13 |
+
|
14 |
+
|
15 |
+
@ANNOTATORS.register_class()
|
16 |
+
class ColorAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
17 |
+
para_dict = {}
|
18 |
+
|
19 |
+
def __init__(self, cfg, logger=None):
|
20 |
+
super().__init__(cfg, logger=logger)
|
21 |
+
self.ratio = cfg.get('RATIO', 64)
|
22 |
+
self.random_cfg = cfg.get('RANDOM_CFG', None)
|
23 |
+
|
24 |
+
def forward(self, image):
|
25 |
+
if isinstance(image, Image.Image):
|
26 |
+
image = np.array(image)
|
27 |
+
elif isinstance(image, torch.Tensor):
|
28 |
+
image = image.detach().cpu().numpy()
|
29 |
+
elif isinstance(image, np.ndarray):
|
30 |
+
image = image.copy()
|
31 |
+
else:
|
32 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
33 |
+
h, w = image.shape[:2]
|
34 |
+
|
35 |
+
if self.random_cfg is None:
|
36 |
+
ratio = self.ratio
|
37 |
+
else:
|
38 |
+
proba = self.random_cfg.get('PROBA', 1.0)
|
39 |
+
if np.random.random() < proba:
|
40 |
+
if 'CHOICE_RATIO' in self.random_cfg:
|
41 |
+
ratio = np.random.choice(self.random_cfg['CHOICE_RATIO'])
|
42 |
+
else:
|
43 |
+
min_ratio = self.random_cfg.get('MIN_RATIO', 48)
|
44 |
+
max_ratio = self.random_cfg.get('MAX_RATIO', 96)
|
45 |
+
ratio = np.random.randint(min_ratio, max_ratio)
|
46 |
+
else:
|
47 |
+
ratio = self.ratio
|
48 |
+
image = cv2.resize(image, (int(w // ratio), int(h // ratio)),
|
49 |
+
interpolation=cv2.INTER_CUBIC)
|
50 |
+
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_NEAREST)
|
51 |
+
assert len(image.shape) < 4
|
52 |
+
return image
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def get_config_template():
|
56 |
+
return dict_to_yaml('ANNOTATORS',
|
57 |
+
__class__.__name__,
|
58 |
+
ColorAnnotator.para_dict,
|
59 |
+
set_name=True)
|
annotator/hed.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
# Please use this implementation in your products
|
4 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
5 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
6 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
7 |
+
# and in this way it works better for gradio's RGB protocol
|
8 |
+
|
9 |
+
from abc import ABCMeta
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
17 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
18 |
+
from scepter.modules.utils.config import dict_to_yaml
|
19 |
+
from scepter.modules.utils.distribute import we
|
20 |
+
from scepter.modules.utils.file_system import FS
|
21 |
+
|
22 |
+
|
23 |
+
def nms(x, t, s):
|
24 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
25 |
+
|
26 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
27 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
28 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
29 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
30 |
+
|
31 |
+
y = np.zeros_like(x)
|
32 |
+
|
33 |
+
for f in [f1, f2, f3, f4]:
|
34 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
35 |
+
|
36 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
37 |
+
z[y > t] = 255
|
38 |
+
return z
|
39 |
+
|
40 |
+
|
41 |
+
class DoubleConvBlock(torch.nn.Module):
|
42 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
43 |
+
super().__init__()
|
44 |
+
self.convs = torch.nn.Sequential()
|
45 |
+
self.convs.append(
|
46 |
+
torch.nn.Conv2d(in_channels=input_channel,
|
47 |
+
out_channels=output_channel,
|
48 |
+
kernel_size=(3, 3),
|
49 |
+
stride=(1, 1),
|
50 |
+
padding=1))
|
51 |
+
for i in range(1, layer_number):
|
52 |
+
self.convs.append(
|
53 |
+
torch.nn.Conv2d(in_channels=output_channel,
|
54 |
+
out_channels=output_channel,
|
55 |
+
kernel_size=(3, 3),
|
56 |
+
stride=(1, 1),
|
57 |
+
padding=1))
|
58 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel,
|
59 |
+
out_channels=1,
|
60 |
+
kernel_size=(1, 1),
|
61 |
+
stride=(1, 1),
|
62 |
+
padding=0)
|
63 |
+
|
64 |
+
def __call__(self, x, down_sampling=False):
|
65 |
+
h = x
|
66 |
+
if down_sampling:
|
67 |
+
h = torch.nn.functional.max_pool2d(h,
|
68 |
+
kernel_size=(2, 2),
|
69 |
+
stride=(2, 2))
|
70 |
+
for conv in self.convs:
|
71 |
+
h = conv(h)
|
72 |
+
h = torch.nn.functional.relu(h)
|
73 |
+
return h, self.projection(h)
|
74 |
+
|
75 |
+
|
76 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
77 |
+
def __init__(self):
|
78 |
+
super().__init__()
|
79 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
80 |
+
self.block1 = DoubleConvBlock(input_channel=3,
|
81 |
+
output_channel=64,
|
82 |
+
layer_number=2)
|
83 |
+
self.block2 = DoubleConvBlock(input_channel=64,
|
84 |
+
output_channel=128,
|
85 |
+
layer_number=2)
|
86 |
+
self.block3 = DoubleConvBlock(input_channel=128,
|
87 |
+
output_channel=256,
|
88 |
+
layer_number=3)
|
89 |
+
self.block4 = DoubleConvBlock(input_channel=256,
|
90 |
+
output_channel=512,
|
91 |
+
layer_number=3)
|
92 |
+
self.block5 = DoubleConvBlock(input_channel=512,
|
93 |
+
output_channel=512,
|
94 |
+
layer_number=3)
|
95 |
+
|
96 |
+
def __call__(self, x):
|
97 |
+
h = x - self.norm
|
98 |
+
h, projection1 = self.block1(h)
|
99 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
100 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
101 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
102 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
103 |
+
return projection1, projection2, projection3, projection4, projection5
|
104 |
+
|
105 |
+
|
106 |
+
@ANNOTATORS.register_class()
|
107 |
+
class HedAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
108 |
+
para_dict = {}
|
109 |
+
|
110 |
+
def __init__(self, cfg, logger=None):
|
111 |
+
super().__init__(cfg, logger=logger)
|
112 |
+
self.netNetwork = ControlNetHED_Apache2().float().eval()
|
113 |
+
pretrained_model = cfg.get('PRETRAINED_MODEL', None)
|
114 |
+
if pretrained_model:
|
115 |
+
with FS.get_from(pretrained_model, wait_finish=True) as local_path:
|
116 |
+
self.netNetwork.load_state_dict(torch.load(local_path))
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
@torch.inference_mode()
|
120 |
+
@torch.autocast('cuda', enabled=False)
|
121 |
+
def forward(self, image):
|
122 |
+
if isinstance(image, torch.Tensor):
|
123 |
+
if len(image.shape) == 3:
|
124 |
+
image = rearrange(image, 'h w c -> 1 c h w')
|
125 |
+
B, C, H, W = image.shape
|
126 |
+
else:
|
127 |
+
raise "Unsurpport input image's shape"
|
128 |
+
elif isinstance(image, np.ndarray):
|
129 |
+
image = torch.from_numpy(image.copy()).float()
|
130 |
+
if len(image.shape) == 3:
|
131 |
+
image = rearrange(image, 'h w c -> 1 c h w')
|
132 |
+
B, C, H, W = image.shape
|
133 |
+
else:
|
134 |
+
raise "Unsurpport input image's shape"
|
135 |
+
else:
|
136 |
+
raise "Unsurpport input image's type"
|
137 |
+
edges = self.netNetwork(image.to(we.device_id))
|
138 |
+
edges = [
|
139 |
+
e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges
|
140 |
+
]
|
141 |
+
edges = [
|
142 |
+
cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR)
|
143 |
+
for e in edges
|
144 |
+
]
|
145 |
+
edges = np.stack(edges, axis=2)
|
146 |
+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
147 |
+
edge = 255 - (edge * 255.0).clip(0, 255).astype(np.uint8)
|
148 |
+
return edge[..., None].repeat(3, 2)
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def get_config_template():
|
152 |
+
return dict_to_yaml('ANNOTATORS',
|
153 |
+
__class__.__name__,
|
154 |
+
HedAnnotator.para_dict,
|
155 |
+
set_name=True)
|
annotator/identity.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from abc import ABCMeta
|
4 |
+
|
5 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
6 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
7 |
+
from scepter.modules.utils.config import dict_to_yaml
|
8 |
+
|
9 |
+
|
10 |
+
@ANNOTATORS.register_class()
|
11 |
+
class IdentityAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
12 |
+
para_dict = {}
|
13 |
+
|
14 |
+
def __init__(self, cfg, logger=None):
|
15 |
+
super().__init__(cfg, logger=logger)
|
16 |
+
|
17 |
+
def forward(self, image):
|
18 |
+
return image
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def get_config_template():
|
22 |
+
return dict_to_yaml('ANNOTATORS',
|
23 |
+
__class__.__name__,
|
24 |
+
IdentityAnnotator.para_dict,
|
25 |
+
set_name=True)
|
annotator/invert.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from abc import ABCMeta
|
4 |
+
|
5 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
6 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
7 |
+
from scepter.modules.utils.config import dict_to_yaml
|
8 |
+
|
9 |
+
|
10 |
+
@ANNOTATORS.register_class()
|
11 |
+
class InvertAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
12 |
+
para_dict = {}
|
13 |
+
|
14 |
+
def __init__(self, cfg, logger=None):
|
15 |
+
super().__init__(cfg, logger=logger)
|
16 |
+
|
17 |
+
def forward(self, image):
|
18 |
+
return 255 - image
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def get_config_template():
|
22 |
+
return dict_to_yaml('ANNOTATORS',
|
23 |
+
__class__.__name__,
|
24 |
+
InvertAnnotator.para_dict,
|
25 |
+
set_name=True)
|
annotator/midas/__init__.py
ADDED
File without changes
|
annotator/midas/api.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# based on https://github.com/isl-org/MiDaS
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.transforms import Compose
|
8 |
+
|
9 |
+
from .dpt_depth import DPTDepthModel
|
10 |
+
from .midas_net import MidasNet
|
11 |
+
from .midas_net_custom import MidasNet_small
|
12 |
+
from .transforms import NormalizeImage, PrepareForNet, Resize
|
13 |
+
|
14 |
+
# ISL_PATHS = {
|
15 |
+
# "dpt_large": "dpt_large-midas-2f21e586.pt",
|
16 |
+
# "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
17 |
+
# "midas_v21": "",
|
18 |
+
# "midas_v21_small": "",
|
19 |
+
# }
|
20 |
+
|
21 |
+
# remote_model_path =
|
22 |
+
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
|
23 |
+
|
24 |
+
|
25 |
+
def disabled_train(self, mode=True):
|
26 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
27 |
+
does not change anymore."""
|
28 |
+
return self
|
29 |
+
|
30 |
+
|
31 |
+
def load_midas_transform(model_type):
|
32 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
33 |
+
# load transform only
|
34 |
+
if model_type == 'dpt_large': # DPT-Large
|
35 |
+
net_w, net_h = 384, 384
|
36 |
+
resize_mode = 'minimal'
|
37 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
38 |
+
std=[0.5, 0.5, 0.5])
|
39 |
+
|
40 |
+
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
41 |
+
net_w, net_h = 384, 384
|
42 |
+
resize_mode = 'minimal'
|
43 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
44 |
+
std=[0.5, 0.5, 0.5])
|
45 |
+
|
46 |
+
elif model_type == 'midas_v21':
|
47 |
+
net_w, net_h = 384, 384
|
48 |
+
resize_mode = 'upper_bound'
|
49 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
50 |
+
std=[0.229, 0.224, 0.225])
|
51 |
+
|
52 |
+
elif model_type == 'midas_v21_small':
|
53 |
+
net_w, net_h = 256, 256
|
54 |
+
resize_mode = 'upper_bound'
|
55 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
56 |
+
std=[0.229, 0.224, 0.225])
|
57 |
+
|
58 |
+
else:
|
59 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
60 |
+
|
61 |
+
transform = Compose([
|
62 |
+
Resize(
|
63 |
+
net_w,
|
64 |
+
net_h,
|
65 |
+
resize_target=None,
|
66 |
+
keep_aspect_ratio=True,
|
67 |
+
ensure_multiple_of=32,
|
68 |
+
resize_method=resize_mode,
|
69 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
70 |
+
),
|
71 |
+
normalization,
|
72 |
+
PrepareForNet(),
|
73 |
+
])
|
74 |
+
|
75 |
+
return transform
|
76 |
+
|
77 |
+
|
78 |
+
def load_model(model_type, model_path):
|
79 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
80 |
+
# load network
|
81 |
+
# model_path = ISL_PATHS[model_type]
|
82 |
+
if model_type == 'dpt_large': # DPT-Large
|
83 |
+
model = DPTDepthModel(
|
84 |
+
path=model_path,
|
85 |
+
backbone='vitl16_384',
|
86 |
+
non_negative=True,
|
87 |
+
)
|
88 |
+
net_w, net_h = 384, 384
|
89 |
+
resize_mode = 'minimal'
|
90 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
91 |
+
std=[0.5, 0.5, 0.5])
|
92 |
+
|
93 |
+
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
94 |
+
model = DPTDepthModel(
|
95 |
+
path=model_path,
|
96 |
+
backbone='vitb_rn50_384',
|
97 |
+
non_negative=True,
|
98 |
+
)
|
99 |
+
net_w, net_h = 384, 384
|
100 |
+
resize_mode = 'minimal'
|
101 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
102 |
+
std=[0.5, 0.5, 0.5])
|
103 |
+
|
104 |
+
elif model_type == 'midas_v21':
|
105 |
+
model = MidasNet(model_path, non_negative=True)
|
106 |
+
net_w, net_h = 384, 384
|
107 |
+
resize_mode = 'upper_bound'
|
108 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
109 |
+
std=[0.229, 0.224, 0.225])
|
110 |
+
|
111 |
+
elif model_type == 'midas_v21_small':
|
112 |
+
model = MidasNet_small(model_path,
|
113 |
+
features=64,
|
114 |
+
backbone='efficientnet_lite3',
|
115 |
+
exportable=True,
|
116 |
+
non_negative=True,
|
117 |
+
blocks={'expand': True})
|
118 |
+
net_w, net_h = 256, 256
|
119 |
+
resize_mode = 'upper_bound'
|
120 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
121 |
+
std=[0.229, 0.224, 0.225])
|
122 |
+
|
123 |
+
else:
|
124 |
+
print(
|
125 |
+
f"model_type '{model_type}' not implemented, use: --model_type large"
|
126 |
+
)
|
127 |
+
assert False
|
128 |
+
|
129 |
+
transform = Compose([
|
130 |
+
Resize(
|
131 |
+
net_w,
|
132 |
+
net_h,
|
133 |
+
resize_target=None,
|
134 |
+
keep_aspect_ratio=True,
|
135 |
+
ensure_multiple_of=32,
|
136 |
+
resize_method=resize_mode,
|
137 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
138 |
+
),
|
139 |
+
normalization,
|
140 |
+
PrepareForNet(),
|
141 |
+
])
|
142 |
+
|
143 |
+
return model.eval(), transform
|
144 |
+
|
145 |
+
|
146 |
+
class MiDaSInference(nn.Module):
|
147 |
+
MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
|
148 |
+
MODEL_TYPES_ISL = [
|
149 |
+
'dpt_large',
|
150 |
+
'dpt_hybrid',
|
151 |
+
'midas_v21',
|
152 |
+
'midas_v21_small',
|
153 |
+
]
|
154 |
+
|
155 |
+
def __init__(self, model_type, model_path):
|
156 |
+
super().__init__()
|
157 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
158 |
+
model, _ = load_model(model_type, model_path)
|
159 |
+
self.model = model
|
160 |
+
self.model.train = disabled_train
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
with torch.no_grad():
|
164 |
+
prediction = self.model(x)
|
165 |
+
return prediction
|
annotator/midas/base_model.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class BaseModel(torch.nn.Module):
|
6 |
+
def load(self, path):
|
7 |
+
"""Load model from file.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
path (str): file path
|
11 |
+
"""
|
12 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
13 |
+
|
14 |
+
if 'optimizer' in parameters:
|
15 |
+
parameters = parameters['model']
|
16 |
+
|
17 |
+
self.load_state_dict(parameters)
|
annotator/midas/blocks.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384)
|
7 |
+
|
8 |
+
|
9 |
+
def _make_encoder(
|
10 |
+
backbone,
|
11 |
+
features,
|
12 |
+
use_pretrained,
|
13 |
+
groups=1,
|
14 |
+
expand=False,
|
15 |
+
exportable=True,
|
16 |
+
hooks=None,
|
17 |
+
use_vit_only=False,
|
18 |
+
use_readout='ignore',
|
19 |
+
):
|
20 |
+
if backbone == 'vitl16_384':
|
21 |
+
pretrained = _make_pretrained_vitl16_384(use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_readout=use_readout)
|
24 |
+
scratch = _make_scratch(
|
25 |
+
[256, 512, 1024, 1024], features, groups=groups,
|
26 |
+
expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
|
27 |
+
elif backbone == 'vitb_rn50_384':
|
28 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
29 |
+
use_pretrained,
|
30 |
+
hooks=hooks,
|
31 |
+
use_vit_only=use_vit_only,
|
32 |
+
use_readout=use_readout,
|
33 |
+
)
|
34 |
+
scratch = _make_scratch(
|
35 |
+
[256, 512, 768, 768], features, groups=groups,
|
36 |
+
expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
|
37 |
+
elif backbone == 'vitb16_384':
|
38 |
+
pretrained = _make_pretrained_vitb16_384(use_pretrained,
|
39 |
+
hooks=hooks,
|
40 |
+
use_readout=use_readout)
|
41 |
+
scratch = _make_scratch(
|
42 |
+
[96, 192, 384, 768], features, groups=groups,
|
43 |
+
expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
|
44 |
+
elif backbone == 'resnext101_wsl':
|
45 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
46 |
+
scratch = _make_scratch([256, 512, 1024, 2048],
|
47 |
+
features,
|
48 |
+
groups=groups,
|
49 |
+
expand=expand) # efficientnet_lite3
|
50 |
+
elif backbone == 'efficientnet_lite3':
|
51 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
|
52 |
+
exportable=exportable)
|
53 |
+
scratch = _make_scratch([32, 48, 136, 384],
|
54 |
+
features,
|
55 |
+
groups=groups,
|
56 |
+
expand=expand) # efficientnet_lite3
|
57 |
+
else:
|
58 |
+
print(f"Backbone '{backbone}' not implemented")
|
59 |
+
assert False
|
60 |
+
|
61 |
+
return pretrained, scratch
|
62 |
+
|
63 |
+
|
64 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
65 |
+
scratch = nn.Module()
|
66 |
+
|
67 |
+
out_shape1 = out_shape
|
68 |
+
out_shape2 = out_shape
|
69 |
+
out_shape3 = out_shape
|
70 |
+
out_shape4 = out_shape
|
71 |
+
if expand is True:
|
72 |
+
out_shape1 = out_shape
|
73 |
+
out_shape2 = out_shape * 2
|
74 |
+
out_shape3 = out_shape * 4
|
75 |
+
out_shape4 = out_shape * 8
|
76 |
+
|
77 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0],
|
78 |
+
out_shape1,
|
79 |
+
kernel_size=3,
|
80 |
+
stride=1,
|
81 |
+
padding=1,
|
82 |
+
bias=False,
|
83 |
+
groups=groups)
|
84 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1],
|
85 |
+
out_shape2,
|
86 |
+
kernel_size=3,
|
87 |
+
stride=1,
|
88 |
+
padding=1,
|
89 |
+
bias=False,
|
90 |
+
groups=groups)
|
91 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2],
|
92 |
+
out_shape3,
|
93 |
+
kernel_size=3,
|
94 |
+
stride=1,
|
95 |
+
padding=1,
|
96 |
+
bias=False,
|
97 |
+
groups=groups)
|
98 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3],
|
99 |
+
out_shape4,
|
100 |
+
kernel_size=3,
|
101 |
+
stride=1,
|
102 |
+
padding=1,
|
103 |
+
bias=False,
|
104 |
+
groups=groups)
|
105 |
+
|
106 |
+
return scratch
|
107 |
+
|
108 |
+
|
109 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
110 |
+
efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
|
111 |
+
'tf_efficientnet_lite3',
|
112 |
+
pretrained=use_pretrained,
|
113 |
+
exportable=exportable)
|
114 |
+
return _make_efficientnet_backbone(efficientnet)
|
115 |
+
|
116 |
+
|
117 |
+
def _make_efficientnet_backbone(effnet):
|
118 |
+
pretrained = nn.Module()
|
119 |
+
|
120 |
+
pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
|
121 |
+
effnet.act1, *effnet.blocks[0:2])
|
122 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
123 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
124 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
125 |
+
|
126 |
+
return pretrained
|
127 |
+
|
128 |
+
|
129 |
+
def _make_resnet_backbone(resnet):
|
130 |
+
pretrained = nn.Module()
|
131 |
+
pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
|
132 |
+
resnet.maxpool, resnet.layer1)
|
133 |
+
|
134 |
+
pretrained.layer2 = resnet.layer2
|
135 |
+
pretrained.layer3 = resnet.layer3
|
136 |
+
pretrained.layer4 = resnet.layer4
|
137 |
+
|
138 |
+
return pretrained
|
139 |
+
|
140 |
+
|
141 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
142 |
+
resnet = torch.hub.load('facebookresearch/WSL-Images',
|
143 |
+
'resnext101_32x8d_wsl')
|
144 |
+
return _make_resnet_backbone(resnet)
|
145 |
+
|
146 |
+
|
147 |
+
class Interpolate(nn.Module):
|
148 |
+
"""Interpolation module.
|
149 |
+
"""
|
150 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
151 |
+
"""Init.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scale_factor (float): scaling
|
155 |
+
mode (str): interpolation mode
|
156 |
+
"""
|
157 |
+
super(Interpolate, self).__init__()
|
158 |
+
|
159 |
+
self.interp = nn.functional.interpolate
|
160 |
+
self.scale_factor = scale_factor
|
161 |
+
self.mode = mode
|
162 |
+
self.align_corners = align_corners
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
"""Forward pass.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
x (tensor): input
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
tensor: interpolated data
|
172 |
+
"""
|
173 |
+
|
174 |
+
x = self.interp(x,
|
175 |
+
scale_factor=self.scale_factor,
|
176 |
+
mode=self.mode,
|
177 |
+
align_corners=self.align_corners)
|
178 |
+
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class ResidualConvUnit(nn.Module):
|
183 |
+
"""Residual convolution module.
|
184 |
+
"""
|
185 |
+
def __init__(self, features):
|
186 |
+
"""Init.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
features (int): number of features
|
190 |
+
"""
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
self.conv1 = nn.Conv2d(features,
|
194 |
+
features,
|
195 |
+
kernel_size=3,
|
196 |
+
stride=1,
|
197 |
+
padding=1,
|
198 |
+
bias=True)
|
199 |
+
|
200 |
+
self.conv2 = nn.Conv2d(features,
|
201 |
+
features,
|
202 |
+
kernel_size=3,
|
203 |
+
stride=1,
|
204 |
+
padding=1,
|
205 |
+
bias=True)
|
206 |
+
|
207 |
+
self.relu = nn.ReLU(inplace=True)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
x (tensor): input
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
tensor: output
|
217 |
+
"""
|
218 |
+
out = self.relu(x)
|
219 |
+
out = self.conv1(out)
|
220 |
+
out = self.relu(out)
|
221 |
+
out = self.conv2(out)
|
222 |
+
|
223 |
+
return out + x
|
224 |
+
|
225 |
+
|
226 |
+
class FeatureFusionBlock(nn.Module):
|
227 |
+
"""Feature fusion block.
|
228 |
+
"""
|
229 |
+
def __init__(self, features):
|
230 |
+
"""Init.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
features (int): number of features
|
234 |
+
"""
|
235 |
+
super(FeatureFusionBlock, self).__init__()
|
236 |
+
|
237 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
238 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
239 |
+
|
240 |
+
def forward(self, *xs):
|
241 |
+
"""Forward pass.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
tensor: output
|
245 |
+
"""
|
246 |
+
output = xs[0]
|
247 |
+
|
248 |
+
if len(xs) == 2:
|
249 |
+
output += self.resConfUnit1(xs[1])
|
250 |
+
|
251 |
+
output = self.resConfUnit2(output)
|
252 |
+
|
253 |
+
output = nn.functional.interpolate(output,
|
254 |
+
scale_factor=2,
|
255 |
+
mode='bilinear',
|
256 |
+
align_corners=True)
|
257 |
+
|
258 |
+
return output
|
259 |
+
|
260 |
+
|
261 |
+
class ResidualConvUnit_custom(nn.Module):
|
262 |
+
"""Residual convolution module.
|
263 |
+
"""
|
264 |
+
def __init__(self, features, activation, bn):
|
265 |
+
"""Init.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
features (int): number of features
|
269 |
+
"""
|
270 |
+
super().__init__()
|
271 |
+
|
272 |
+
self.bn = bn
|
273 |
+
|
274 |
+
self.groups = 1
|
275 |
+
|
276 |
+
self.conv1 = nn.Conv2d(features,
|
277 |
+
features,
|
278 |
+
kernel_size=3,
|
279 |
+
stride=1,
|
280 |
+
padding=1,
|
281 |
+
bias=True,
|
282 |
+
groups=self.groups)
|
283 |
+
|
284 |
+
self.conv2 = nn.Conv2d(features,
|
285 |
+
features,
|
286 |
+
kernel_size=3,
|
287 |
+
stride=1,
|
288 |
+
padding=1,
|
289 |
+
bias=True,
|
290 |
+
groups=self.groups)
|
291 |
+
|
292 |
+
if self.bn is True:
|
293 |
+
self.bn1 = nn.BatchNorm2d(features)
|
294 |
+
self.bn2 = nn.BatchNorm2d(features)
|
295 |
+
|
296 |
+
self.activation = activation
|
297 |
+
|
298 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
"""Forward pass.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
x (tensor): input
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
tensor: output
|
308 |
+
"""
|
309 |
+
|
310 |
+
out = self.activation(x)
|
311 |
+
out = self.conv1(out)
|
312 |
+
if self.bn is True:
|
313 |
+
out = self.bn1(out)
|
314 |
+
|
315 |
+
out = self.activation(out)
|
316 |
+
out = self.conv2(out)
|
317 |
+
if self.bn is True:
|
318 |
+
out = self.bn2(out)
|
319 |
+
|
320 |
+
if self.groups > 1:
|
321 |
+
out = self.conv_merge(out)
|
322 |
+
|
323 |
+
return self.skip_add.add(out, x)
|
324 |
+
|
325 |
+
# return out + x
|
326 |
+
|
327 |
+
|
328 |
+
class FeatureFusionBlock_custom(nn.Module):
|
329 |
+
"""Feature fusion block.
|
330 |
+
"""
|
331 |
+
def __init__(self,
|
332 |
+
features,
|
333 |
+
activation,
|
334 |
+
deconv=False,
|
335 |
+
bn=False,
|
336 |
+
expand=False,
|
337 |
+
align_corners=True):
|
338 |
+
"""Init.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
features (int): number of features
|
342 |
+
"""
|
343 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
344 |
+
|
345 |
+
self.deconv = deconv
|
346 |
+
self.align_corners = align_corners
|
347 |
+
|
348 |
+
self.groups = 1
|
349 |
+
|
350 |
+
self.expand = expand
|
351 |
+
out_features = features
|
352 |
+
if self.expand is True:
|
353 |
+
out_features = features // 2
|
354 |
+
|
355 |
+
self.out_conv = nn.Conv2d(features,
|
356 |
+
out_features,
|
357 |
+
kernel_size=1,
|
358 |
+
stride=1,
|
359 |
+
padding=0,
|
360 |
+
bias=True,
|
361 |
+
groups=1)
|
362 |
+
|
363 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
364 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
365 |
+
|
366 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
367 |
+
|
368 |
+
def forward(self, *xs):
|
369 |
+
"""Forward pass.
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
tensor: output
|
373 |
+
"""
|
374 |
+
output = xs[0]
|
375 |
+
|
376 |
+
if len(xs) == 2:
|
377 |
+
res = self.resConfUnit1(xs[1])
|
378 |
+
output = self.skip_add.add(output, res)
|
379 |
+
# output += res
|
380 |
+
|
381 |
+
output = self.resConfUnit2(output)
|
382 |
+
|
383 |
+
output = nn.functional.interpolate(output,
|
384 |
+
scale_factor=2,
|
385 |
+
mode='bilinear',
|
386 |
+
align_corners=self.align_corners)
|
387 |
+
|
388 |
+
output = self.out_conv(output)
|
389 |
+
|
390 |
+
return output
|
annotator/midas/dpt_depth.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
|
7 |
+
from .vit import forward_vit
|
8 |
+
|
9 |
+
|
10 |
+
def _make_fusion_block(features, use_bn):
|
11 |
+
return FeatureFusionBlock_custom(
|
12 |
+
features,
|
13 |
+
nn.ReLU(False),
|
14 |
+
deconv=False,
|
15 |
+
bn=use_bn,
|
16 |
+
expand=False,
|
17 |
+
align_corners=True,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class DPT(BaseModel):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
head,
|
25 |
+
features=256,
|
26 |
+
backbone='vitb_rn50_384',
|
27 |
+
readout='project',
|
28 |
+
channels_last=False,
|
29 |
+
use_bn=False,
|
30 |
+
):
|
31 |
+
|
32 |
+
super(DPT, self).__init__()
|
33 |
+
|
34 |
+
self.channels_last = channels_last
|
35 |
+
|
36 |
+
hooks = {
|
37 |
+
'vitb_rn50_384': [0, 1, 8, 11],
|
38 |
+
'vitb16_384': [2, 5, 8, 11],
|
39 |
+
'vitl16_384': [5, 11, 17, 23],
|
40 |
+
}
|
41 |
+
|
42 |
+
# Instantiate backbone and reassemble blocks
|
43 |
+
self.pretrained, self.scratch = _make_encoder(
|
44 |
+
backbone,
|
45 |
+
features,
|
46 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
47 |
+
groups=1,
|
48 |
+
expand=False,
|
49 |
+
exportable=False,
|
50 |
+
hooks=hooks[backbone],
|
51 |
+
use_readout=readout,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
55 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
56 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
57 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
58 |
+
|
59 |
+
self.scratch.output_conv = head
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
if self.channels_last is True:
|
63 |
+
x.contiguous(memory_format=torch.channels_last)
|
64 |
+
|
65 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
66 |
+
|
67 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
68 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
69 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
70 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
71 |
+
|
72 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
73 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
74 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
75 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
76 |
+
|
77 |
+
out = self.scratch.output_conv(path_1)
|
78 |
+
|
79 |
+
return out
|
80 |
+
|
81 |
+
|
82 |
+
class DPTDepthModel(DPT):
|
83 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
84 |
+
features = kwargs['features'] if 'features' in kwargs else 256
|
85 |
+
|
86 |
+
head = nn.Sequential(
|
87 |
+
nn.Conv2d(features,
|
88 |
+
features // 2,
|
89 |
+
kernel_size=3,
|
90 |
+
stride=1,
|
91 |
+
padding=1),
|
92 |
+
Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
|
93 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
94 |
+
nn.ReLU(True),
|
95 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
96 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
97 |
+
nn.Identity(),
|
98 |
+
)
|
99 |
+
|
100 |
+
super().__init__(head, **kwargs)
|
101 |
+
|
102 |
+
if path is not None:
|
103 |
+
self.load(path)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
return super().forward(x).squeeze(dim=1)
|
annotator/midas/midas_net.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
3 |
+
This file contains code that is adapted from
|
4 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .base_model import BaseModel
|
10 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
11 |
+
|
12 |
+
|
13 |
+
class MidasNet(BaseModel):
|
14 |
+
"""Network for monocular depth estimation.
|
15 |
+
"""
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print('Loading weights: ', path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(
|
31 |
+
backbone='resnext101_wsl',
|
32 |
+
features=features,
|
33 |
+
use_pretrained=use_pretrained)
|
34 |
+
|
35 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
36 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
37 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
38 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
39 |
+
|
40 |
+
self.scratch.output_conv = nn.Sequential(
|
41 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
42 |
+
Interpolate(scale_factor=2, mode='bilinear'),
|
43 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
44 |
+
nn.ReLU(True),
|
45 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
46 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
47 |
+
)
|
48 |
+
|
49 |
+
if path:
|
50 |
+
self.load(path)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
"""Forward pass.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (tensor): input data (image)
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
tensor: depth
|
60 |
+
"""
|
61 |
+
|
62 |
+
layer_1 = self.pretrained.layer1(x)
|
63 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
64 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
65 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
66 |
+
|
67 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
68 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
69 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
70 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
71 |
+
|
72 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
73 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
74 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
75 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
76 |
+
|
77 |
+
out = self.scratch.output_conv(path_1)
|
78 |
+
|
79 |
+
return torch.squeeze(out, dim=1)
|
annotator/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
3 |
+
This file contains code that is adapted from
|
4 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .base_model import BaseModel
|
10 |
+
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
|
11 |
+
|
12 |
+
|
13 |
+
class MidasNet_small(BaseModel):
|
14 |
+
"""Network for monocular depth estimation.
|
15 |
+
"""
|
16 |
+
def __init__(self,
|
17 |
+
path=None,
|
18 |
+
features=64,
|
19 |
+
backbone='efficientnet_lite3',
|
20 |
+
non_negative=True,
|
21 |
+
exportable=True,
|
22 |
+
channels_last=False,
|
23 |
+
align_corners=True,
|
24 |
+
blocks={'expand': True}):
|
25 |
+
"""Init.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
path (str, optional): Path to saved model. Defaults to None.
|
29 |
+
features (int, optional): Number of features. Defaults to 256.
|
30 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
31 |
+
"""
|
32 |
+
print('Loading weights: ', path)
|
33 |
+
|
34 |
+
super(MidasNet_small, self).__init__()
|
35 |
+
|
36 |
+
use_pretrained = False if path else True
|
37 |
+
|
38 |
+
self.channels_last = channels_last
|
39 |
+
self.blocks = blocks
|
40 |
+
self.backbone = backbone
|
41 |
+
|
42 |
+
self.groups = 1
|
43 |
+
|
44 |
+
features1 = features
|
45 |
+
features2 = features
|
46 |
+
features3 = features
|
47 |
+
features4 = features
|
48 |
+
self.expand = False
|
49 |
+
if 'expand' in self.blocks and self.blocks['expand'] is True:
|
50 |
+
self.expand = True
|
51 |
+
features1 = features
|
52 |
+
features2 = features * 2
|
53 |
+
features3 = features * 4
|
54 |
+
features4 = features * 8
|
55 |
+
|
56 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone,
|
57 |
+
features,
|
58 |
+
use_pretrained,
|
59 |
+
groups=self.groups,
|
60 |
+
expand=self.expand,
|
61 |
+
exportable=exportable)
|
62 |
+
|
63 |
+
self.scratch.activation = nn.ReLU(False)
|
64 |
+
|
65 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(
|
66 |
+
features4,
|
67 |
+
self.scratch.activation,
|
68 |
+
deconv=False,
|
69 |
+
bn=False,
|
70 |
+
expand=self.expand,
|
71 |
+
align_corners=align_corners)
|
72 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(
|
73 |
+
features3,
|
74 |
+
self.scratch.activation,
|
75 |
+
deconv=False,
|
76 |
+
bn=False,
|
77 |
+
expand=self.expand,
|
78 |
+
align_corners=align_corners)
|
79 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(
|
80 |
+
features2,
|
81 |
+
self.scratch.activation,
|
82 |
+
deconv=False,
|
83 |
+
bn=False,
|
84 |
+
expand=self.expand,
|
85 |
+
align_corners=align_corners)
|
86 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(
|
87 |
+
features1,
|
88 |
+
self.scratch.activation,
|
89 |
+
deconv=False,
|
90 |
+
bn=False,
|
91 |
+
align_corners=align_corners)
|
92 |
+
|
93 |
+
self.scratch.output_conv = nn.Sequential(
|
94 |
+
nn.Conv2d(features,
|
95 |
+
features // 2,
|
96 |
+
kernel_size=3,
|
97 |
+
stride=1,
|
98 |
+
padding=1,
|
99 |
+
groups=self.groups),
|
100 |
+
Interpolate(scale_factor=2, mode='bilinear'),
|
101 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
102 |
+
self.scratch.activation,
|
103 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
104 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
105 |
+
nn.Identity(),
|
106 |
+
)
|
107 |
+
|
108 |
+
if path:
|
109 |
+
self.load(path)
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
"""Forward pass.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
x (tensor): input data (image)
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
tensor: depth
|
119 |
+
"""
|
120 |
+
if self.channels_last is True:
|
121 |
+
print('self.channels_last = ', self.channels_last)
|
122 |
+
x.contiguous(memory_format=torch.channels_last)
|
123 |
+
|
124 |
+
layer_1 = self.pretrained.layer1(x)
|
125 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
126 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
127 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
128 |
+
|
129 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
130 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
131 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
132 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
133 |
+
|
134 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
135 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
136 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
137 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
138 |
+
|
139 |
+
out = self.scratch.output_conv(path_1)
|
140 |
+
|
141 |
+
return torch.squeeze(out, dim=1)
|
142 |
+
|
143 |
+
|
144 |
+
def fuse_model(m):
|
145 |
+
prev_previous_type = nn.Identity()
|
146 |
+
prev_previous_name = ''
|
147 |
+
previous_type = nn.Identity()
|
148 |
+
previous_name = ''
|
149 |
+
for name, module in m.named_modules():
|
150 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
|
151 |
+
module) == nn.ReLU:
|
152 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
153 |
+
torch.quantization.fuse_modules(
|
154 |
+
m, [prev_previous_name, previous_name, name], inplace=True)
|
155 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
156 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
157 |
+
torch.quantization.fuse_modules(
|
158 |
+
m, [prev_previous_name, previous_name], inplace=True)
|
159 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
160 |
+
# print("FUSED ", previous_name, name)
|
161 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
162 |
+
|
163 |
+
prev_previous_type = previous_type
|
164 |
+
prev_previous_name = previous_name
|
165 |
+
previous_type = type(module)
|
166 |
+
previous_name = name
|
annotator/midas/transforms.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import math
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
9 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
sample (dict): sample
|
13 |
+
size (tuple): image size
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: new size
|
17 |
+
"""
|
18 |
+
shape = list(sample['disparity'].shape)
|
19 |
+
|
20 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
21 |
+
return sample
|
22 |
+
|
23 |
+
scale = [0, 0]
|
24 |
+
scale[0] = size[0] / shape[0]
|
25 |
+
scale[1] = size[1] / shape[1]
|
26 |
+
|
27 |
+
scale = max(scale)
|
28 |
+
|
29 |
+
shape[0] = math.ceil(scale * shape[0])
|
30 |
+
shape[1] = math.ceil(scale * shape[1])
|
31 |
+
|
32 |
+
# resize
|
33 |
+
sample['image'] = cv2.resize(sample['image'],
|
34 |
+
tuple(shape[::-1]),
|
35 |
+
interpolation=image_interpolation_method)
|
36 |
+
|
37 |
+
sample['disparity'] = cv2.resize(sample['disparity'],
|
38 |
+
tuple(shape[::-1]),
|
39 |
+
interpolation=cv2.INTER_NEAREST)
|
40 |
+
sample['mask'] = cv2.resize(
|
41 |
+
sample['mask'].astype(np.float32),
|
42 |
+
tuple(shape[::-1]),
|
43 |
+
interpolation=cv2.INTER_NEAREST,
|
44 |
+
)
|
45 |
+
sample['mask'] = sample['mask'].astype(bool)
|
46 |
+
|
47 |
+
return tuple(shape)
|
48 |
+
|
49 |
+
|
50 |
+
class Resize(object):
|
51 |
+
"""Resize sample to given size (width, height).
|
52 |
+
"""
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
width,
|
56 |
+
height,
|
57 |
+
resize_target=True,
|
58 |
+
keep_aspect_ratio=False,
|
59 |
+
ensure_multiple_of=1,
|
60 |
+
resize_method='lower_bound',
|
61 |
+
image_interpolation_method=cv2.INTER_AREA,
|
62 |
+
):
|
63 |
+
"""Init.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
width (int): desired output width
|
67 |
+
height (int): desired output height
|
68 |
+
resize_target (bool, optional):
|
69 |
+
True: Resize the full sample (image, mask, target).
|
70 |
+
False: Resize image only.
|
71 |
+
Defaults to True.
|
72 |
+
keep_aspect_ratio (bool, optional):
|
73 |
+
True: Keep the aspect ratio of the input sample.
|
74 |
+
Output sample might not have the given width and height, and
|
75 |
+
resize behaviour depends on the parameter 'resize_method'.
|
76 |
+
Defaults to False.
|
77 |
+
ensure_multiple_of (int, optional):
|
78 |
+
Output width and height is constrained to be multiple of this parameter.
|
79 |
+
Defaults to 1.
|
80 |
+
resize_method (str, optional):
|
81 |
+
"lower_bound": Output will be at least as large as the given size.
|
82 |
+
"upper_bound": Output will be at max as large as the given size. "
|
83 |
+
"(Output size might be smaller than given size.)"
|
84 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
85 |
+
Defaults to "lower_bound".
|
86 |
+
"""
|
87 |
+
self.__width = width
|
88 |
+
self.__height = height
|
89 |
+
|
90 |
+
self.__resize_target = resize_target
|
91 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
92 |
+
self.__multiple_of = ensure_multiple_of
|
93 |
+
self.__resize_method = resize_method
|
94 |
+
self.__image_interpolation_method = image_interpolation_method
|
95 |
+
|
96 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
97 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
98 |
+
|
99 |
+
if max_val is not None and y > max_val:
|
100 |
+
y = (np.floor(x / self.__multiple_of) *
|
101 |
+
self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
if y < min_val:
|
104 |
+
y = (np.ceil(x / self.__multiple_of) *
|
105 |
+
self.__multiple_of).astype(int)
|
106 |
+
|
107 |
+
return y
|
108 |
+
|
109 |
+
def get_size(self, width, height):
|
110 |
+
# determine new height and width
|
111 |
+
scale_height = self.__height / height
|
112 |
+
scale_width = self.__width / width
|
113 |
+
|
114 |
+
if self.__keep_aspect_ratio:
|
115 |
+
if self.__resize_method == 'lower_bound':
|
116 |
+
# scale such that output size is lower bound
|
117 |
+
if scale_width > scale_height:
|
118 |
+
# fit width
|
119 |
+
scale_height = scale_width
|
120 |
+
else:
|
121 |
+
# fit height
|
122 |
+
scale_width = scale_height
|
123 |
+
elif self.__resize_method == 'upper_bound':
|
124 |
+
# scale such that output size is upper bound
|
125 |
+
if scale_width < scale_height:
|
126 |
+
# fit width
|
127 |
+
scale_height = scale_width
|
128 |
+
else:
|
129 |
+
# fit height
|
130 |
+
scale_width = scale_height
|
131 |
+
elif self.__resize_method == 'minimal':
|
132 |
+
# scale as least as possbile
|
133 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
134 |
+
# fit width
|
135 |
+
scale_height = scale_width
|
136 |
+
else:
|
137 |
+
# fit height
|
138 |
+
scale_width = scale_height
|
139 |
+
else:
|
140 |
+
raise ValueError(
|
141 |
+
f'resize_method {self.__resize_method} not implemented')
|
142 |
+
|
143 |
+
if self.__resize_method == 'lower_bound':
|
144 |
+
new_height = self.constrain_to_multiple_of(scale_height * height,
|
145 |
+
min_val=self.__height)
|
146 |
+
new_width = self.constrain_to_multiple_of(scale_width * width,
|
147 |
+
min_val=self.__width)
|
148 |
+
elif self.__resize_method == 'upper_bound':
|
149 |
+
new_height = self.constrain_to_multiple_of(scale_height * height,
|
150 |
+
max_val=self.__height)
|
151 |
+
new_width = self.constrain_to_multiple_of(scale_width * width,
|
152 |
+
max_val=self.__width)
|
153 |
+
elif self.__resize_method == 'minimal':
|
154 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
155 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
156 |
+
else:
|
157 |
+
raise ValueError(
|
158 |
+
f'resize_method {self.__resize_method} not implemented')
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(sample['image'].shape[1],
|
164 |
+
sample['image'].shape[0])
|
165 |
+
|
166 |
+
# resize sample
|
167 |
+
sample['image'] = cv2.resize(
|
168 |
+
sample['image'],
|
169 |
+
(width, height),
|
170 |
+
interpolation=self.__image_interpolation_method,
|
171 |
+
)
|
172 |
+
|
173 |
+
if self.__resize_target:
|
174 |
+
if 'disparity' in sample:
|
175 |
+
sample['disparity'] = cv2.resize(
|
176 |
+
sample['disparity'],
|
177 |
+
(width, height),
|
178 |
+
interpolation=cv2.INTER_NEAREST,
|
179 |
+
)
|
180 |
+
|
181 |
+
if 'depth' in sample:
|
182 |
+
sample['depth'] = cv2.resize(sample['depth'], (width, height),
|
183 |
+
interpolation=cv2.INTER_NEAREST)
|
184 |
+
|
185 |
+
sample['mask'] = cv2.resize(
|
186 |
+
sample['mask'].astype(np.float32),
|
187 |
+
(width, height),
|
188 |
+
interpolation=cv2.INTER_NEAREST,
|
189 |
+
)
|
190 |
+
sample['mask'] = sample['mask'].astype(bool)
|
191 |
+
|
192 |
+
return sample
|
193 |
+
|
194 |
+
|
195 |
+
class NormalizeImage(object):
|
196 |
+
"""Normlize image by given mean and std.
|
197 |
+
"""
|
198 |
+
def __init__(self, mean, std):
|
199 |
+
self.__mean = mean
|
200 |
+
self.__std = std
|
201 |
+
|
202 |
+
def __call__(self, sample):
|
203 |
+
sample['image'] = (sample['image'] - self.__mean) / self.__std
|
204 |
+
|
205 |
+
return sample
|
206 |
+
|
207 |
+
|
208 |
+
class PrepareForNet(object):
|
209 |
+
"""Prepare sample for usage as network input.
|
210 |
+
"""
|
211 |
+
def __init__(self):
|
212 |
+
pass
|
213 |
+
|
214 |
+
def __call__(self, sample):
|
215 |
+
image = np.transpose(sample['image'], (2, 0, 1))
|
216 |
+
sample['image'] = np.ascontiguousarray(image).astype(np.float32)
|
217 |
+
|
218 |
+
if 'mask' in sample:
|
219 |
+
sample['mask'] = sample['mask'].astype(np.float32)
|
220 |
+
sample['mask'] = np.ascontiguousarray(sample['mask'])
|
221 |
+
|
222 |
+
if 'disparity' in sample:
|
223 |
+
disparity = sample['disparity'].astype(np.float32)
|
224 |
+
sample['disparity'] = np.ascontiguousarray(disparity)
|
225 |
+
|
226 |
+
if 'depth' in sample:
|
227 |
+
depth = sample['depth'].astype(np.float32)
|
228 |
+
sample['depth'] = np.ascontiguousarray(depth)
|
229 |
+
|
230 |
+
return sample
|
annotator/midas/utils.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""Utils for monoDepth."""
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def read_pfm(path):
|
12 |
+
"""Read pfm file.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
path (str): path to file
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
tuple: (data, scale)
|
19 |
+
"""
|
20 |
+
with open(path, 'rb') as file:
|
21 |
+
|
22 |
+
color = None
|
23 |
+
width = None
|
24 |
+
height = None
|
25 |
+
scale = None
|
26 |
+
endian = None
|
27 |
+
|
28 |
+
header = file.readline().rstrip()
|
29 |
+
if header.decode('ascii') == 'PF':
|
30 |
+
color = True
|
31 |
+
elif header.decode('ascii') == 'Pf':
|
32 |
+
color = False
|
33 |
+
else:
|
34 |
+
raise Exception('Not a PFM file: ' + path)
|
35 |
+
|
36 |
+
dim_match = re.match(r'^(\d+)\s(\d+)\s$',
|
37 |
+
file.readline().decode('ascii'))
|
38 |
+
if dim_match:
|
39 |
+
width, height = list(map(int, dim_match.groups()))
|
40 |
+
else:
|
41 |
+
raise Exception('Malformed PFM header.')
|
42 |
+
|
43 |
+
scale = float(file.readline().decode('ascii').rstrip())
|
44 |
+
if scale < 0:
|
45 |
+
# little-endian
|
46 |
+
endian = '<'
|
47 |
+
scale = -scale
|
48 |
+
else:
|
49 |
+
# big-endian
|
50 |
+
endian = '>'
|
51 |
+
|
52 |
+
data = np.fromfile(file, endian + 'f')
|
53 |
+
shape = (height, width, 3) if color else (height, width)
|
54 |
+
|
55 |
+
data = np.reshape(data, shape)
|
56 |
+
data = np.flipud(data)
|
57 |
+
|
58 |
+
return data, scale
|
59 |
+
|
60 |
+
|
61 |
+
def write_pfm(path, image, scale=1):
|
62 |
+
"""Write pfm file.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
path (str): pathto file
|
66 |
+
image (array): data
|
67 |
+
scale (int, optional): Scale. Defaults to 1.
|
68 |
+
"""
|
69 |
+
|
70 |
+
with open(path, 'wb') as file:
|
71 |
+
color = None
|
72 |
+
|
73 |
+
if image.dtype.name != 'float32':
|
74 |
+
raise Exception('Image dtype must be float32.')
|
75 |
+
|
76 |
+
image = np.flipud(image)
|
77 |
+
|
78 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
79 |
+
color = True
|
80 |
+
elif (len(image.shape) == 2
|
81 |
+
or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
|
82 |
+
color = False
|
83 |
+
else:
|
84 |
+
raise Exception(
|
85 |
+
'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
|
86 |
+
|
87 |
+
file.write('PF\n' if color else 'Pf\n'.encode())
|
88 |
+
file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
|
89 |
+
|
90 |
+
endian = image.dtype.byteorder
|
91 |
+
|
92 |
+
if endian == '<' or endian == '=' and sys.byteorder == 'little':
|
93 |
+
scale = -scale
|
94 |
+
|
95 |
+
file.write('%f\n'.encode() % scale)
|
96 |
+
|
97 |
+
image.tofile(file)
|
98 |
+
|
99 |
+
|
100 |
+
def read_image(path):
|
101 |
+
"""Read image and output RGB image (0-1).
|
102 |
+
|
103 |
+
Args:
|
104 |
+
path (str): path to file
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
array: RGB image (0-1)
|
108 |
+
"""
|
109 |
+
img = cv2.imread(path)
|
110 |
+
|
111 |
+
if img.ndim == 2:
|
112 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
113 |
+
|
114 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
115 |
+
|
116 |
+
return img
|
117 |
+
|
118 |
+
|
119 |
+
def resize_image(img):
|
120 |
+
"""Resize image and make it fit for network.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
img (array): image
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
tensor: data ready for network
|
127 |
+
"""
|
128 |
+
height_orig = img.shape[0]
|
129 |
+
width_orig = img.shape[1]
|
130 |
+
|
131 |
+
if width_orig > height_orig:
|
132 |
+
scale = width_orig / 384
|
133 |
+
else:
|
134 |
+
scale = height_orig / 384
|
135 |
+
|
136 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
137 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
138 |
+
|
139 |
+
img_resized = cv2.resize(img, (width, height),
|
140 |
+
interpolation=cv2.INTER_AREA)
|
141 |
+
|
142 |
+
img_resized = (torch.from_numpy(np.transpose(
|
143 |
+
img_resized, (2, 0, 1))).contiguous().float())
|
144 |
+
img_resized = img_resized.unsqueeze(0)
|
145 |
+
|
146 |
+
return img_resized
|
147 |
+
|
148 |
+
|
149 |
+
def resize_depth(depth, width, height):
|
150 |
+
"""Resize depth map and bring to CPU (numpy).
|
151 |
+
|
152 |
+
Args:
|
153 |
+
depth (tensor): depth
|
154 |
+
width (int): image width
|
155 |
+
height (int): image height
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
array: processed depth
|
159 |
+
"""
|
160 |
+
depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
|
161 |
+
|
162 |
+
depth_resized = cv2.resize(depth.numpy(), (width, height),
|
163 |
+
interpolation=cv2.INTER_CUBIC)
|
164 |
+
|
165 |
+
return depth_resized
|
166 |
+
|
167 |
+
|
168 |
+
def write_depth(path, depth, bits=1):
|
169 |
+
"""Write depth map to pfm and png file.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
path (str): filepath without extension
|
173 |
+
depth (array): depth
|
174 |
+
"""
|
175 |
+
write_pfm(path + '.pfm', depth.astype(np.float32))
|
176 |
+
|
177 |
+
depth_min = depth.min()
|
178 |
+
depth_max = depth.max()
|
179 |
+
|
180 |
+
max_val = (2**(8 * bits)) - 1
|
181 |
+
|
182 |
+
if depth_max - depth_min > np.finfo('float').eps:
|
183 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
184 |
+
else:
|
185 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
186 |
+
|
187 |
+
if bits == 1:
|
188 |
+
cv2.imwrite(path + '.png', out.astype('uint8'))
|
189 |
+
elif bits == 2:
|
190 |
+
cv2.imwrite(path + '.png', out.astype('uint16'))
|
191 |
+
|
192 |
+
return
|
annotator/midas/vit.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import math
|
3 |
+
import types
|
4 |
+
|
5 |
+
import timm
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class Slice(nn.Module):
|
12 |
+
def __init__(self, start_index=1):
|
13 |
+
super(Slice, self).__init__()
|
14 |
+
self.start_index = start_index
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
return x[:, self.start_index:]
|
18 |
+
|
19 |
+
|
20 |
+
class AddReadout(nn.Module):
|
21 |
+
def __init__(self, start_index=1):
|
22 |
+
super(AddReadout, self).__init__()
|
23 |
+
self.start_index = start_index
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
if self.start_index == 2:
|
27 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
28 |
+
else:
|
29 |
+
readout = x[:, 0]
|
30 |
+
return x[:, self.start_index:] + readout.unsqueeze(1)
|
31 |
+
|
32 |
+
|
33 |
+
class ProjectReadout(nn.Module):
|
34 |
+
def __init__(self, in_features, start_index=1):
|
35 |
+
super(ProjectReadout, self).__init__()
|
36 |
+
self.start_index = start_index
|
37 |
+
|
38 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
|
39 |
+
nn.GELU())
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
|
43 |
+
features = torch.cat((x[:, self.start_index:], readout), -1)
|
44 |
+
|
45 |
+
return self.project(features)
|
46 |
+
|
47 |
+
|
48 |
+
class Transpose(nn.Module):
|
49 |
+
def __init__(self, dim0, dim1):
|
50 |
+
super(Transpose, self).__init__()
|
51 |
+
self.dim0 = dim0
|
52 |
+
self.dim1 = dim1
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = x.transpose(self.dim0, self.dim1)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def forward_vit(pretrained, x):
|
60 |
+
b, c, h, w = x.shape
|
61 |
+
|
62 |
+
_ = pretrained.model.forward_flex(x)
|
63 |
+
|
64 |
+
layer_1 = pretrained.activations['1']
|
65 |
+
layer_2 = pretrained.activations['2']
|
66 |
+
layer_3 = pretrained.activations['3']
|
67 |
+
layer_4 = pretrained.activations['4']
|
68 |
+
|
69 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
70 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
71 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
72 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
73 |
+
|
74 |
+
unflatten = nn.Sequential(
|
75 |
+
nn.Unflatten(
|
76 |
+
2,
|
77 |
+
torch.Size([
|
78 |
+
h // pretrained.model.patch_size[1],
|
79 |
+
w // pretrained.model.patch_size[0],
|
80 |
+
]),
|
81 |
+
))
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
|
93 |
+
layer_1)
|
94 |
+
layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
|
95 |
+
layer_2)
|
96 |
+
layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
|
97 |
+
layer_3)
|
98 |
+
layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
|
99 |
+
layer_4)
|
100 |
+
|
101 |
+
return layer_1, layer_2, layer_3, layer_4
|
102 |
+
|
103 |
+
|
104 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
105 |
+
posemb_tok, posemb_grid = (
|
106 |
+
posemb[:, :self.start_index],
|
107 |
+
posemb[0, self.start_index:],
|
108 |
+
)
|
109 |
+
|
110 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
111 |
+
|
112 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
113 |
+
-1).permute(0, 3, 1, 2)
|
114 |
+
posemb_grid = F.interpolate(posemb_grid,
|
115 |
+
size=(gs_h, gs_w),
|
116 |
+
mode='bilinear')
|
117 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
118 |
+
|
119 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
120 |
+
|
121 |
+
return posemb
|
122 |
+
|
123 |
+
|
124 |
+
def forward_flex(self, x):
|
125 |
+
b, c, h, w = x.shape
|
126 |
+
|
127 |
+
pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
|
128 |
+
w // self.patch_size[0])
|
129 |
+
|
130 |
+
B = x.shape[0]
|
131 |
+
|
132 |
+
if hasattr(self.patch_embed, 'backbone'):
|
133 |
+
x = self.patch_embed.backbone(x)
|
134 |
+
if isinstance(x, (list, tuple)):
|
135 |
+
x = x[
|
136 |
+
-1] # last feature if backbone outputs list/tuple of features
|
137 |
+
|
138 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
139 |
+
|
140 |
+
if getattr(self, 'dist_token', None) is not None:
|
141 |
+
cls_tokens = self.cls_token.expand(
|
142 |
+
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
144 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
145 |
+
else:
|
146 |
+
cls_tokens = self.cls_token.expand(
|
147 |
+
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
148 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
149 |
+
|
150 |
+
x = x + pos_embed
|
151 |
+
x = self.pos_drop(x)
|
152 |
+
|
153 |
+
for blk in self.blocks:
|
154 |
+
x = blk(x)
|
155 |
+
|
156 |
+
x = self.norm(x)
|
157 |
+
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
activations = {}
|
162 |
+
|
163 |
+
|
164 |
+
def get_activation(name):
|
165 |
+
def hook(model, input, output):
|
166 |
+
activations[name] = output
|
167 |
+
|
168 |
+
return hook
|
169 |
+
|
170 |
+
|
171 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
172 |
+
if use_readout == 'ignore':
|
173 |
+
readout_oper = [Slice(start_index)] * len(features)
|
174 |
+
elif use_readout == 'add':
|
175 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
176 |
+
elif use_readout == 'project':
|
177 |
+
readout_oper = [
|
178 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
179 |
+
]
|
180 |
+
else:
|
181 |
+
assert (
|
182 |
+
False
|
183 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
184 |
+
|
185 |
+
return readout_oper
|
186 |
+
|
187 |
+
|
188 |
+
def _make_vit_b16_backbone(
|
189 |
+
model,
|
190 |
+
features=[96, 192, 384, 768],
|
191 |
+
size=[384, 384],
|
192 |
+
hooks=[2, 5, 8, 11],
|
193 |
+
vit_features=768,
|
194 |
+
use_readout='ignore',
|
195 |
+
start_index=1,
|
196 |
+
):
|
197 |
+
pretrained = nn.Module()
|
198 |
+
|
199 |
+
pretrained.model = model
|
200 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(
|
201 |
+
get_activation('1'))
|
202 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(
|
203 |
+
get_activation('2'))
|
204 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(
|
205 |
+
get_activation('3'))
|
206 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(
|
207 |
+
get_activation('4'))
|
208 |
+
|
209 |
+
pretrained.activations = activations
|
210 |
+
|
211 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout,
|
212 |
+
start_index)
|
213 |
+
|
214 |
+
# 32, 48, 136, 384
|
215 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
216 |
+
readout_oper[0],
|
217 |
+
Transpose(1, 2),
|
218 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
219 |
+
nn.Conv2d(
|
220 |
+
in_channels=vit_features,
|
221 |
+
out_channels=features[0],
|
222 |
+
kernel_size=1,
|
223 |
+
stride=1,
|
224 |
+
padding=0,
|
225 |
+
),
|
226 |
+
nn.ConvTranspose2d(
|
227 |
+
in_channels=features[0],
|
228 |
+
out_channels=features[0],
|
229 |
+
kernel_size=4,
|
230 |
+
stride=4,
|
231 |
+
padding=0,
|
232 |
+
bias=True,
|
233 |
+
dilation=1,
|
234 |
+
groups=1,
|
235 |
+
),
|
236 |
+
)
|
237 |
+
|
238 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
239 |
+
readout_oper[1],
|
240 |
+
Transpose(1, 2),
|
241 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
242 |
+
nn.Conv2d(
|
243 |
+
in_channels=vit_features,
|
244 |
+
out_channels=features[1],
|
245 |
+
kernel_size=1,
|
246 |
+
stride=1,
|
247 |
+
padding=0,
|
248 |
+
),
|
249 |
+
nn.ConvTranspose2d(
|
250 |
+
in_channels=features[1],
|
251 |
+
out_channels=features[1],
|
252 |
+
kernel_size=2,
|
253 |
+
stride=2,
|
254 |
+
padding=0,
|
255 |
+
bias=True,
|
256 |
+
dilation=1,
|
257 |
+
groups=1,
|
258 |
+
),
|
259 |
+
)
|
260 |
+
|
261 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
262 |
+
readout_oper[2],
|
263 |
+
Transpose(1, 2),
|
264 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
265 |
+
nn.Conv2d(
|
266 |
+
in_channels=vit_features,
|
267 |
+
out_channels=features[2],
|
268 |
+
kernel_size=1,
|
269 |
+
stride=1,
|
270 |
+
padding=0,
|
271 |
+
),
|
272 |
+
)
|
273 |
+
|
274 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
275 |
+
readout_oper[3],
|
276 |
+
Transpose(1, 2),
|
277 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
278 |
+
nn.Conv2d(
|
279 |
+
in_channels=vit_features,
|
280 |
+
out_channels=features[3],
|
281 |
+
kernel_size=1,
|
282 |
+
stride=1,
|
283 |
+
padding=0,
|
284 |
+
),
|
285 |
+
nn.Conv2d(
|
286 |
+
in_channels=features[3],
|
287 |
+
out_channels=features[3],
|
288 |
+
kernel_size=3,
|
289 |
+
stride=2,
|
290 |
+
padding=1,
|
291 |
+
),
|
292 |
+
)
|
293 |
+
|
294 |
+
pretrained.model.start_index = start_index
|
295 |
+
pretrained.model.patch_size = [16, 16]
|
296 |
+
|
297 |
+
# We inject this function into the VisionTransformer instances so that
|
298 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
299 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
300 |
+
pretrained.model)
|
301 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
302 |
+
_resize_pos_embed, pretrained.model)
|
303 |
+
|
304 |
+
return pretrained
|
305 |
+
|
306 |
+
|
307 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
|
308 |
+
model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)
|
309 |
+
|
310 |
+
hooks = [5, 11, 17, 23] if hooks is None else hooks
|
311 |
+
return _make_vit_b16_backbone(
|
312 |
+
model,
|
313 |
+
features=[256, 512, 1024, 1024],
|
314 |
+
hooks=hooks,
|
315 |
+
vit_features=1024,
|
316 |
+
use_readout=use_readout,
|
317 |
+
)
|
318 |
+
|
319 |
+
|
320 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
|
321 |
+
model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)
|
322 |
+
|
323 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
324 |
+
return _make_vit_b16_backbone(model,
|
325 |
+
features=[96, 192, 384, 768],
|
326 |
+
hooks=hooks,
|
327 |
+
use_readout=use_readout)
|
328 |
+
|
329 |
+
|
330 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
|
331 |
+
model = timm.create_model('vit_deit_base_patch16_384',
|
332 |
+
pretrained=pretrained)
|
333 |
+
|
334 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
335 |
+
return _make_vit_b16_backbone(model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout)
|
339 |
+
|
340 |
+
|
341 |
+
def _make_pretrained_deitb16_distil_384(pretrained,
|
342 |
+
use_readout='ignore',
|
343 |
+
hooks=None):
|
344 |
+
model = timm.create_model('vit_deit_base_distilled_patch16_384',
|
345 |
+
pretrained=pretrained)
|
346 |
+
|
347 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
348 |
+
return _make_vit_b16_backbone(
|
349 |
+
model,
|
350 |
+
features=[96, 192, 384, 768],
|
351 |
+
hooks=hooks,
|
352 |
+
use_readout=use_readout,
|
353 |
+
start_index=2,
|
354 |
+
)
|
355 |
+
|
356 |
+
|
357 |
+
def _make_vit_b_rn50_backbone(
|
358 |
+
model,
|
359 |
+
features=[256, 512, 768, 768],
|
360 |
+
size=[384, 384],
|
361 |
+
hooks=[0, 1, 8, 11],
|
362 |
+
vit_features=768,
|
363 |
+
use_vit_only=False,
|
364 |
+
use_readout='ignore',
|
365 |
+
start_index=1,
|
366 |
+
):
|
367 |
+
pretrained = nn.Module()
|
368 |
+
|
369 |
+
pretrained.model = model
|
370 |
+
|
371 |
+
if use_vit_only is True:
|
372 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(
|
373 |
+
get_activation('1'))
|
374 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(
|
375 |
+
get_activation('2'))
|
376 |
+
else:
|
377 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
378 |
+
get_activation('1'))
|
379 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
380 |
+
get_activation('2'))
|
381 |
+
|
382 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(
|
383 |
+
get_activation('3'))
|
384 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(
|
385 |
+
get_activation('4'))
|
386 |
+
|
387 |
+
pretrained.activations = activations
|
388 |
+
|
389 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout,
|
390 |
+
start_index)
|
391 |
+
|
392 |
+
if use_vit_only is True:
|
393 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
394 |
+
readout_oper[0],
|
395 |
+
Transpose(1, 2),
|
396 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
397 |
+
nn.Conv2d(
|
398 |
+
in_channels=vit_features,
|
399 |
+
out_channels=features[0],
|
400 |
+
kernel_size=1,
|
401 |
+
stride=1,
|
402 |
+
padding=0,
|
403 |
+
),
|
404 |
+
nn.ConvTranspose2d(
|
405 |
+
in_channels=features[0],
|
406 |
+
out_channels=features[0],
|
407 |
+
kernel_size=4,
|
408 |
+
stride=4,
|
409 |
+
padding=0,
|
410 |
+
bias=True,
|
411 |
+
dilation=1,
|
412 |
+
groups=1,
|
413 |
+
),
|
414 |
+
)
|
415 |
+
|
416 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
417 |
+
readout_oper[1],
|
418 |
+
Transpose(1, 2),
|
419 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
420 |
+
nn.Conv2d(
|
421 |
+
in_channels=vit_features,
|
422 |
+
out_channels=features[1],
|
423 |
+
kernel_size=1,
|
424 |
+
stride=1,
|
425 |
+
padding=0,
|
426 |
+
),
|
427 |
+
nn.ConvTranspose2d(
|
428 |
+
in_channels=features[1],
|
429 |
+
out_channels=features[1],
|
430 |
+
kernel_size=2,
|
431 |
+
stride=2,
|
432 |
+
padding=0,
|
433 |
+
bias=True,
|
434 |
+
dilation=1,
|
435 |
+
groups=1,
|
436 |
+
),
|
437 |
+
)
|
438 |
+
else:
|
439 |
+
pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
|
440 |
+
nn.Identity(),
|
441 |
+
nn.Identity())
|
442 |
+
pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
|
443 |
+
nn.Identity(),
|
444 |
+
nn.Identity())
|
445 |
+
|
446 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
447 |
+
readout_oper[2],
|
448 |
+
Transpose(1, 2),
|
449 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
450 |
+
nn.Conv2d(
|
451 |
+
in_channels=vit_features,
|
452 |
+
out_channels=features[2],
|
453 |
+
kernel_size=1,
|
454 |
+
stride=1,
|
455 |
+
padding=0,
|
456 |
+
),
|
457 |
+
)
|
458 |
+
|
459 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
460 |
+
readout_oper[3],
|
461 |
+
Transpose(1, 2),
|
462 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
463 |
+
nn.Conv2d(
|
464 |
+
in_channels=vit_features,
|
465 |
+
out_channels=features[3],
|
466 |
+
kernel_size=1,
|
467 |
+
stride=1,
|
468 |
+
padding=0,
|
469 |
+
),
|
470 |
+
nn.Conv2d(
|
471 |
+
in_channels=features[3],
|
472 |
+
out_channels=features[3],
|
473 |
+
kernel_size=3,
|
474 |
+
stride=2,
|
475 |
+
padding=1,
|
476 |
+
),
|
477 |
+
)
|
478 |
+
|
479 |
+
pretrained.model.start_index = start_index
|
480 |
+
pretrained.model.patch_size = [16, 16]
|
481 |
+
|
482 |
+
# We inject this function into the VisionTransformer instances so that
|
483 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
484 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
485 |
+
pretrained.model)
|
486 |
+
|
487 |
+
# We inject this function into the VisionTransformer instances so that
|
488 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
489 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
490 |
+
_resize_pos_embed, pretrained.model)
|
491 |
+
|
492 |
+
return pretrained
|
493 |
+
|
494 |
+
|
495 |
+
def _make_pretrained_vitb_rn50_384(pretrained,
|
496 |
+
use_readout='ignore',
|
497 |
+
hooks=None,
|
498 |
+
use_vit_only=False):
|
499 |
+
model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)
|
500 |
+
|
501 |
+
hooks = [0, 1, 8, 11] if hooks is None else hooks
|
502 |
+
return _make_vit_b_rn50_backbone(
|
503 |
+
model,
|
504 |
+
features=[256, 512, 768, 768],
|
505 |
+
size=[384, 384],
|
506 |
+
hooks=hooks,
|
507 |
+
use_vit_only=use_vit_only,
|
508 |
+
use_readout=use_readout,
|
509 |
+
)
|
annotator/midas_op.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
# Midas Depth Estimation
|
4 |
+
# From https://github.com/isl-org/MiDaS
|
5 |
+
# MIT LICENSE
|
6 |
+
from abc import ABCMeta
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from einops import rearrange
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
14 |
+
from scepter.modules.annotator.midas.api import MiDaSInference
|
15 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
16 |
+
from scepter.modules.annotator.utils import resize_image, resize_image_ori
|
17 |
+
from scepter.modules.utils.config import dict_to_yaml
|
18 |
+
from scepter.modules.utils.distribute import we
|
19 |
+
from scepter.modules.utils.file_system import FS
|
20 |
+
|
21 |
+
|
22 |
+
@ANNOTATORS.register_class()
|
23 |
+
class MidasDetector(BaseAnnotator, metaclass=ABCMeta):
|
24 |
+
def __init__(self, cfg, logger=None):
|
25 |
+
super().__init__(cfg, logger=logger)
|
26 |
+
pretrained_model = cfg.get('PRETRAINED_MODEL', None)
|
27 |
+
if pretrained_model:
|
28 |
+
with FS.get_from(pretrained_model, wait_finish=True) as local_path:
|
29 |
+
self.model = MiDaSInference(model_type='dpt_hybrid',
|
30 |
+
model_path=local_path)
|
31 |
+
self.a = cfg.get('A', np.pi * 2.0)
|
32 |
+
self.bg_th = cfg.get('BG_TH', 0.1)
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
@torch.inference_mode()
|
36 |
+
@torch.autocast('cuda', enabled=False)
|
37 |
+
def forward(self, image):
|
38 |
+
if isinstance(image, Image.Image):
|
39 |
+
image = np.array(image)
|
40 |
+
elif isinstance(image, torch.Tensor):
|
41 |
+
image = image.detach().cpu().numpy()
|
42 |
+
elif isinstance(image, np.ndarray):
|
43 |
+
image = image.copy()
|
44 |
+
else:
|
45 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
46 |
+
image_depth = image
|
47 |
+
h, w, c = image.shape
|
48 |
+
image_depth, k = resize_image(image_depth,
|
49 |
+
1024 if min(h, w) > 1024 else min(h, w))
|
50 |
+
image_depth = torch.from_numpy(image_depth).float().to(we.device_id)
|
51 |
+
image_depth = image_depth / 127.5 - 1.0
|
52 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
53 |
+
depth = self.model(image_depth)[0]
|
54 |
+
|
55 |
+
depth_pt = depth.clone()
|
56 |
+
depth_pt -= torch.min(depth_pt)
|
57 |
+
depth_pt /= torch.max(depth_pt)
|
58 |
+
depth_pt = depth_pt.cpu().numpy()
|
59 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
60 |
+
depth_image = depth_image[..., None].repeat(3, 2)
|
61 |
+
|
62 |
+
# depth_np = depth.cpu().numpy() # float16 error
|
63 |
+
# x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
|
64 |
+
# y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
|
65 |
+
# z = np.ones_like(x) * self.a
|
66 |
+
# x[depth_pt < self.bg_th] = 0
|
67 |
+
# y[depth_pt < self.bg_th] = 0
|
68 |
+
# normal = np.stack([x, y, z], axis=2)
|
69 |
+
# normal /= np.sum(normal**2.0, axis=2, keepdims=True)**0.5
|
70 |
+
# normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
71 |
+
depth_image = resize_image_ori(h, w, depth_image, k)
|
72 |
+
return depth_image
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def get_config_template():
|
76 |
+
return dict_to_yaml('ANNOTATORS',
|
77 |
+
__class__.__name__,
|
78 |
+
MidasDetector.para_dict,
|
79 |
+
set_name=True)
|
annotator/mlsd/__init__.py
ADDED
File without changes
|
annotator/mlsd/mbv2_mlsd_large.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
13 |
+
nn.BatchNorm2d(out_c2),
|
14 |
+
nn.ReLU(inplace=True))
|
15 |
+
self.conv2 = nn.Sequential(nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
16 |
+
nn.BatchNorm2d(out_c1),
|
17 |
+
nn.ReLU(inplace=True))
|
18 |
+
self.upscale = upscale
|
19 |
+
|
20 |
+
def forward(self, a, b):
|
21 |
+
b = self.conv1(b)
|
22 |
+
a = self.conv2(a)
|
23 |
+
if self.upscale:
|
24 |
+
b = F.interpolate(b,
|
25 |
+
scale_factor=2.0,
|
26 |
+
mode='bilinear',
|
27 |
+
align_corners=True)
|
28 |
+
return torch.cat((a, b), dim=1)
|
29 |
+
|
30 |
+
|
31 |
+
class BlockTypeB(nn.Module):
|
32 |
+
def __init__(self, in_c, out_c):
|
33 |
+
super(BlockTypeB, self).__init__()
|
34 |
+
self.conv1 = nn.Sequential(
|
35 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
36 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
37 |
+
self.conv2 = nn.Sequential(
|
38 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
39 |
+
nn.BatchNorm2d(out_c), nn.ReLU())
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
x = self.conv1(x) + x
|
43 |
+
x = self.conv2(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class BlockTypeC(nn.Module):
|
48 |
+
def __init__(self, in_c, out_c):
|
49 |
+
super(BlockTypeC, self).__init__()
|
50 |
+
self.conv1 = nn.Sequential(
|
51 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
52 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
53 |
+
self.conv2 = nn.Sequential(
|
54 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
55 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
56 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = self.conv1(x)
|
60 |
+
x = self.conv2(x)
|
61 |
+
x = self.conv3(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
def _make_divisible(v, divisor, min_value=None):
|
66 |
+
"""
|
67 |
+
This function is taken from the original tf repo.
|
68 |
+
It ensures that all layers have a channel number that is divisible by 8
|
69 |
+
It can be seen here:
|
70 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
71 |
+
:param v:
|
72 |
+
:param divisor:
|
73 |
+
:param min_value:
|
74 |
+
:return:
|
75 |
+
"""
|
76 |
+
if min_value is None:
|
77 |
+
min_value = divisor
|
78 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
79 |
+
# Make sure that round down does not go down by more than 10%.
|
80 |
+
if new_v < 0.9 * v:
|
81 |
+
new_v += divisor
|
82 |
+
return new_v
|
83 |
+
|
84 |
+
|
85 |
+
class ConvBNReLU(nn.Sequential):
|
86 |
+
def __init__(self,
|
87 |
+
in_planes,
|
88 |
+
out_planes,
|
89 |
+
kernel_size=3,
|
90 |
+
stride=1,
|
91 |
+
groups=1):
|
92 |
+
self.channel_pad = out_planes - in_planes
|
93 |
+
self.stride = stride
|
94 |
+
# padding = (kernel_size - 1) // 2
|
95 |
+
|
96 |
+
# TFLite uses slightly different padding than PyTorch
|
97 |
+
if stride == 2:
|
98 |
+
padding = 0
|
99 |
+
else:
|
100 |
+
padding = (kernel_size - 1) // 2
|
101 |
+
|
102 |
+
super(ConvBNReLU, self).__init__(
|
103 |
+
nn.Conv2d(in_planes,
|
104 |
+
out_planes,
|
105 |
+
kernel_size,
|
106 |
+
stride,
|
107 |
+
padding,
|
108 |
+
groups=groups,
|
109 |
+
bias=False), nn.BatchNorm2d(out_planes),
|
110 |
+
nn.ReLU6(inplace=True))
|
111 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
# TFLite uses different padding
|
115 |
+
if self.stride == 2:
|
116 |
+
x = F.pad(x, (0, 1, 0, 1), 'constant', 0)
|
117 |
+
# print(x.shape)
|
118 |
+
|
119 |
+
for module in self:
|
120 |
+
if not isinstance(module, nn.MaxPool2d):
|
121 |
+
x = module(x)
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class InvertedResidual(nn.Module):
|
126 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
127 |
+
super(InvertedResidual, self).__init__()
|
128 |
+
self.stride = stride
|
129 |
+
assert stride in [1, 2]
|
130 |
+
|
131 |
+
hidden_dim = int(round(inp * expand_ratio))
|
132 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
133 |
+
|
134 |
+
layers = []
|
135 |
+
if expand_ratio != 1:
|
136 |
+
# pw
|
137 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
138 |
+
layers.extend([
|
139 |
+
# dw
|
140 |
+
ConvBNReLU(hidden_dim,
|
141 |
+
hidden_dim,
|
142 |
+
stride=stride,
|
143 |
+
groups=hidden_dim),
|
144 |
+
# pw-linear
|
145 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
146 |
+
nn.BatchNorm2d(oup),
|
147 |
+
])
|
148 |
+
self.conv = nn.Sequential(*layers)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
if self.use_res_connect:
|
152 |
+
return x + self.conv(x)
|
153 |
+
else:
|
154 |
+
return self.conv(x)
|
155 |
+
|
156 |
+
|
157 |
+
class MobileNetV2(nn.Module):
|
158 |
+
def __init__(self, pretrained=True):
|
159 |
+
"""
|
160 |
+
MobileNet V2 main class
|
161 |
+
Args:
|
162 |
+
num_classes (int): Number of classes
|
163 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
164 |
+
inverted_residual_setting: Network structure
|
165 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
166 |
+
Set to 1 to turn off rounding
|
167 |
+
block: Module specifying inverted residual building block for mobilenet
|
168 |
+
"""
|
169 |
+
super(MobileNetV2, self).__init__()
|
170 |
+
|
171 |
+
block = InvertedResidual
|
172 |
+
input_channel = 32
|
173 |
+
last_channel = 1280
|
174 |
+
width_mult = 1.0
|
175 |
+
round_nearest = 8
|
176 |
+
|
177 |
+
inverted_residual_setting = [
|
178 |
+
# t, c, n, s
|
179 |
+
[1, 16, 1, 1],
|
180 |
+
[6, 24, 2, 2],
|
181 |
+
[6, 32, 3, 2],
|
182 |
+
[6, 64, 4, 2],
|
183 |
+
[6, 96, 3, 1],
|
184 |
+
# [6, 160, 3, 2],
|
185 |
+
# [6, 320, 1, 1],
|
186 |
+
]
|
187 |
+
|
188 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
189 |
+
if len(inverted_residual_setting) == 0 or len(
|
190 |
+
inverted_residual_setting[0]) != 4:
|
191 |
+
raise ValueError('inverted_residual_setting should be non-empty '
|
192 |
+
'or a 4-element list, got {}'.format(
|
193 |
+
inverted_residual_setting))
|
194 |
+
|
195 |
+
# building first layer
|
196 |
+
input_channel = _make_divisible(input_channel * width_mult,
|
197 |
+
round_nearest)
|
198 |
+
self.last_channel = _make_divisible(
|
199 |
+
last_channel * max(1.0, width_mult), round_nearest)
|
200 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
201 |
+
# building inverted residual blocks
|
202 |
+
for t, c, n, s in inverted_residual_setting:
|
203 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
204 |
+
for i in range(n):
|
205 |
+
stride = s if i == 0 else 1
|
206 |
+
features.append(
|
207 |
+
block(input_channel,
|
208 |
+
output_channel,
|
209 |
+
stride,
|
210 |
+
expand_ratio=t))
|
211 |
+
input_channel = output_channel
|
212 |
+
|
213 |
+
self.features = nn.Sequential(*features)
|
214 |
+
self.fpn_selected = [1, 3, 6, 10, 13]
|
215 |
+
# weight initialization
|
216 |
+
for m in self.modules():
|
217 |
+
if isinstance(m, nn.Conv2d):
|
218 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
219 |
+
if m.bias is not None:
|
220 |
+
nn.init.zeros_(m.bias)
|
221 |
+
elif isinstance(m, nn.BatchNorm2d):
|
222 |
+
nn.init.ones_(m.weight)
|
223 |
+
nn.init.zeros_(m.bias)
|
224 |
+
elif isinstance(m, nn.Linear):
|
225 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
226 |
+
nn.init.zeros_(m.bias)
|
227 |
+
if pretrained:
|
228 |
+
self._load_pretrained_model()
|
229 |
+
|
230 |
+
def _forward_impl(self, x):
|
231 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
232 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
233 |
+
fpn_features = []
|
234 |
+
for i, f in enumerate(self.features):
|
235 |
+
if i > self.fpn_selected[-1]:
|
236 |
+
break
|
237 |
+
x = f(x)
|
238 |
+
if i in self.fpn_selected:
|
239 |
+
fpn_features.append(x)
|
240 |
+
|
241 |
+
c1, c2, c3, c4, c5 = fpn_features
|
242 |
+
return c1, c2, c3, c4, c5
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
return self._forward_impl(x)
|
246 |
+
|
247 |
+
def _load_pretrained_model(self):
|
248 |
+
pretrain_dict = model_zoo.load_url(
|
249 |
+
'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
250 |
+
model_dict = {}
|
251 |
+
state_dict = self.state_dict()
|
252 |
+
for k, v in pretrain_dict.items():
|
253 |
+
if k in state_dict:
|
254 |
+
model_dict[k] = v
|
255 |
+
state_dict.update(model_dict)
|
256 |
+
self.load_state_dict(state_dict)
|
257 |
+
|
258 |
+
|
259 |
+
class MobileV2_MLSD_Large(nn.Module):
|
260 |
+
def __init__(self):
|
261 |
+
super(MobileV2_MLSD_Large, self).__init__()
|
262 |
+
|
263 |
+
self.backbone = MobileNetV2(pretrained=False)
|
264 |
+
# A, B
|
265 |
+
self.block15 = BlockTypeA(in_c1=64,
|
266 |
+
in_c2=96,
|
267 |
+
out_c1=64,
|
268 |
+
out_c2=64,
|
269 |
+
upscale=False)
|
270 |
+
self.block16 = BlockTypeB(128, 64)
|
271 |
+
|
272 |
+
# A, B
|
273 |
+
self.block17 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
|
274 |
+
self.block18 = BlockTypeB(128, 64)
|
275 |
+
|
276 |
+
# A, B
|
277 |
+
self.block19 = BlockTypeA(in_c1=24, in_c2=64, out_c1=64, out_c2=64)
|
278 |
+
self.block20 = BlockTypeB(128, 64)
|
279 |
+
|
280 |
+
# A, B, C
|
281 |
+
self.block21 = BlockTypeA(in_c1=16, in_c2=64, out_c1=64, out_c2=64)
|
282 |
+
self.block22 = BlockTypeB(128, 64)
|
283 |
+
|
284 |
+
self.block23 = BlockTypeC(64, 16)
|
285 |
+
|
286 |
+
def forward(self, x):
|
287 |
+
c1, c2, c3, c4, c5 = self.backbone(x)
|
288 |
+
|
289 |
+
x = self.block15(c4, c5)
|
290 |
+
x = self.block16(x)
|
291 |
+
|
292 |
+
x = self.block17(c3, x)
|
293 |
+
x = self.block18(x)
|
294 |
+
|
295 |
+
x = self.block19(c2, x)
|
296 |
+
x = self.block20(x)
|
297 |
+
|
298 |
+
x = self.block21(c1, x)
|
299 |
+
x = self.block22(x)
|
300 |
+
x = self.block23(x)
|
301 |
+
x = x[:, 7:, :, :]
|
302 |
+
|
303 |
+
return x
|
annotator/mlsd/mbv2_mlsd_tiny.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.utils.model_zoo as model_zoo
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class BlockTypeA(nn.Module):
|
9 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
|
10 |
+
super(BlockTypeA, self).__init__()
|
11 |
+
self.conv1 = nn.Sequential(nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
12 |
+
nn.BatchNorm2d(out_c2),
|
13 |
+
nn.ReLU(inplace=True))
|
14 |
+
self.conv2 = nn.Sequential(nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
15 |
+
nn.BatchNorm2d(out_c1),
|
16 |
+
nn.ReLU(inplace=True))
|
17 |
+
self.upscale = upscale
|
18 |
+
|
19 |
+
def forward(self, a, b):
|
20 |
+
b = self.conv1(b)
|
21 |
+
a = self.conv2(a)
|
22 |
+
b = F.interpolate(b,
|
23 |
+
scale_factor=2.0,
|
24 |
+
mode='bilinear',
|
25 |
+
align_corners=True)
|
26 |
+
return torch.cat((a, b), dim=1)
|
27 |
+
|
28 |
+
|
29 |
+
class BlockTypeB(nn.Module):
|
30 |
+
def __init__(self, in_c, out_c):
|
31 |
+
super(BlockTypeB, self).__init__()
|
32 |
+
self.conv1 = nn.Sequential(
|
33 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
34 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
35 |
+
self.conv2 = nn.Sequential(
|
36 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
37 |
+
nn.BatchNorm2d(out_c), nn.ReLU())
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = self.conv1(x) + x
|
41 |
+
x = self.conv2(x)
|
42 |
+
return x
|
43 |
+
|
44 |
+
|
45 |
+
class BlockTypeC(nn.Module):
|
46 |
+
def __init__(self, in_c, out_c):
|
47 |
+
super(BlockTypeC, self).__init__()
|
48 |
+
self.conv1 = nn.Sequential(
|
49 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
50 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
51 |
+
self.conv2 = nn.Sequential(
|
52 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
53 |
+
nn.BatchNorm2d(in_c), nn.ReLU())
|
54 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
x = self.conv1(x)
|
58 |
+
x = self.conv2(x)
|
59 |
+
x = self.conv3(x)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
def _make_divisible(v, divisor, min_value=None):
|
64 |
+
"""
|
65 |
+
This function is taken from the original tf repo.
|
66 |
+
It ensures that all layers have a channel number that is divisible by 8
|
67 |
+
It can be seen here:
|
68 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
69 |
+
:param v:
|
70 |
+
:param divisor:
|
71 |
+
:param min_value:
|
72 |
+
:return:
|
73 |
+
"""
|
74 |
+
if min_value is None:
|
75 |
+
min_value = divisor
|
76 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
77 |
+
# Make sure that round down does not go down by more than 10%.
|
78 |
+
if new_v < 0.9 * v:
|
79 |
+
new_v += divisor
|
80 |
+
return new_v
|
81 |
+
|
82 |
+
|
83 |
+
class ConvBNReLU(nn.Sequential):
|
84 |
+
def __init__(self,
|
85 |
+
in_planes,
|
86 |
+
out_planes,
|
87 |
+
kernel_size=3,
|
88 |
+
stride=1,
|
89 |
+
groups=1):
|
90 |
+
self.channel_pad = out_planes - in_planes
|
91 |
+
self.stride = stride
|
92 |
+
# padding = (kernel_size - 1) // 2
|
93 |
+
|
94 |
+
# TFLite uses slightly different padding than PyTorch
|
95 |
+
if stride == 2:
|
96 |
+
padding = 0
|
97 |
+
else:
|
98 |
+
padding = (kernel_size - 1) // 2
|
99 |
+
|
100 |
+
super(ConvBNReLU, self).__init__(
|
101 |
+
nn.Conv2d(in_planes,
|
102 |
+
out_planes,
|
103 |
+
kernel_size,
|
104 |
+
stride,
|
105 |
+
padding,
|
106 |
+
groups=groups,
|
107 |
+
bias=False), nn.BatchNorm2d(out_planes),
|
108 |
+
nn.ReLU6(inplace=True))
|
109 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
# TFLite uses different padding
|
113 |
+
if self.stride == 2:
|
114 |
+
x = F.pad(x, (0, 1, 0, 1), 'constant', 0)
|
115 |
+
# print(x.shape)
|
116 |
+
|
117 |
+
for module in self:
|
118 |
+
if not isinstance(module, nn.MaxPool2d):
|
119 |
+
x = module(x)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class InvertedResidual(nn.Module):
|
124 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
125 |
+
super(InvertedResidual, self).__init__()
|
126 |
+
self.stride = stride
|
127 |
+
assert stride in [1, 2]
|
128 |
+
|
129 |
+
hidden_dim = int(round(inp * expand_ratio))
|
130 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
131 |
+
|
132 |
+
layers = []
|
133 |
+
if expand_ratio != 1:
|
134 |
+
# pw
|
135 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
136 |
+
layers.extend([
|
137 |
+
# dw
|
138 |
+
ConvBNReLU(hidden_dim,
|
139 |
+
hidden_dim,
|
140 |
+
stride=stride,
|
141 |
+
groups=hidden_dim),
|
142 |
+
# pw-linear
|
143 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
144 |
+
nn.BatchNorm2d(oup),
|
145 |
+
])
|
146 |
+
self.conv = nn.Sequential(*layers)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
if self.use_res_connect:
|
150 |
+
return x + self.conv(x)
|
151 |
+
else:
|
152 |
+
return self.conv(x)
|
153 |
+
|
154 |
+
|
155 |
+
class MobileNetV2(nn.Module):
|
156 |
+
def __init__(self, pretrained=True):
|
157 |
+
"""
|
158 |
+
MobileNet V2 main class
|
159 |
+
Args:
|
160 |
+
num_classes (int): Number of classes
|
161 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
162 |
+
inverted_residual_setting: Network structure
|
163 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
164 |
+
Set to 1 to turn off rounding
|
165 |
+
block: Module specifying inverted residual building block for mobilenet
|
166 |
+
"""
|
167 |
+
super(MobileNetV2, self).__init__()
|
168 |
+
|
169 |
+
block = InvertedResidual
|
170 |
+
input_channel = 32
|
171 |
+
last_channel = 1280
|
172 |
+
width_mult = 1.0
|
173 |
+
round_nearest = 8
|
174 |
+
|
175 |
+
inverted_residual_setting = [
|
176 |
+
# t, c, n, s
|
177 |
+
[1, 16, 1, 1],
|
178 |
+
[6, 24, 2, 2],
|
179 |
+
[6, 32, 3, 2],
|
180 |
+
[6, 64, 4, 2],
|
181 |
+
# [6, 96, 3, 1],
|
182 |
+
# [6, 160, 3, 2],
|
183 |
+
# [6, 320, 1, 1],
|
184 |
+
]
|
185 |
+
|
186 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
187 |
+
if len(inverted_residual_setting) == 0 or len(
|
188 |
+
inverted_residual_setting[0]) != 4:
|
189 |
+
raise ValueError('inverted_residual_setting should be non-empty '
|
190 |
+
'or a 4-element list, got {}'.format(
|
191 |
+
inverted_residual_setting))
|
192 |
+
|
193 |
+
# building first layer
|
194 |
+
input_channel = _make_divisible(input_channel * width_mult,
|
195 |
+
round_nearest)
|
196 |
+
self.last_channel = _make_divisible(
|
197 |
+
last_channel * max(1.0, width_mult), round_nearest)
|
198 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
199 |
+
# building inverted residual blocks
|
200 |
+
for t, c, n, s in inverted_residual_setting:
|
201 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
202 |
+
for i in range(n):
|
203 |
+
stride = s if i == 0 else 1
|
204 |
+
features.append(
|
205 |
+
block(input_channel,
|
206 |
+
output_channel,
|
207 |
+
stride,
|
208 |
+
expand_ratio=t))
|
209 |
+
input_channel = output_channel
|
210 |
+
self.features = nn.Sequential(*features)
|
211 |
+
|
212 |
+
self.fpn_selected = [3, 6, 10]
|
213 |
+
# weight initialization
|
214 |
+
for m in self.modules():
|
215 |
+
if isinstance(m, nn.Conv2d):
|
216 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
217 |
+
if m.bias is not None:
|
218 |
+
nn.init.zeros_(m.bias)
|
219 |
+
elif isinstance(m, nn.BatchNorm2d):
|
220 |
+
nn.init.ones_(m.weight)
|
221 |
+
nn.init.zeros_(m.bias)
|
222 |
+
elif isinstance(m, nn.Linear):
|
223 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
224 |
+
nn.init.zeros_(m.bias)
|
225 |
+
|
226 |
+
# if pretrained:
|
227 |
+
# self._load_pretrained_model()
|
228 |
+
|
229 |
+
def _forward_impl(self, x):
|
230 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
231 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
232 |
+
fpn_features = []
|
233 |
+
for i, f in enumerate(self.features):
|
234 |
+
if i > self.fpn_selected[-1]:
|
235 |
+
break
|
236 |
+
x = f(x)
|
237 |
+
if i in self.fpn_selected:
|
238 |
+
fpn_features.append(x)
|
239 |
+
|
240 |
+
c2, c3, c4 = fpn_features
|
241 |
+
return c2, c3, c4
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
return self._forward_impl(x)
|
245 |
+
|
246 |
+
def _load_pretrained_model(self):
|
247 |
+
pretrain_dict = model_zoo.load_url(
|
248 |
+
'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
249 |
+
model_dict = {}
|
250 |
+
state_dict = self.state_dict()
|
251 |
+
for k, v in pretrain_dict.items():
|
252 |
+
if k in state_dict:
|
253 |
+
model_dict[k] = v
|
254 |
+
state_dict.update(model_dict)
|
255 |
+
self.load_state_dict(state_dict)
|
256 |
+
|
257 |
+
|
258 |
+
class MobileV2_MLSD_Tiny(nn.Module):
|
259 |
+
def __init__(self):
|
260 |
+
super(MobileV2_MLSD_Tiny, self).__init__()
|
261 |
+
|
262 |
+
self.backbone = MobileNetV2(pretrained=True)
|
263 |
+
|
264 |
+
self.block12 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
|
265 |
+
self.block13 = BlockTypeB(128, 64)
|
266 |
+
|
267 |
+
self.block14 = BlockTypeA(in_c1=24, in_c2=64, out_c1=32, out_c2=32)
|
268 |
+
self.block15 = BlockTypeB(64, 64)
|
269 |
+
|
270 |
+
self.block16 = BlockTypeC(64, 16)
|
271 |
+
|
272 |
+
def forward(self, x):
|
273 |
+
c2, c3, c4 = self.backbone(x)
|
274 |
+
|
275 |
+
x = self.block12(c3, c4)
|
276 |
+
x = self.block13(x)
|
277 |
+
x = self.block14(c2, x)
|
278 |
+
x = self.block15(x)
|
279 |
+
x = self.block16(x)
|
280 |
+
x = x[:, 7:, :, :]
|
281 |
+
# print(x.shape)
|
282 |
+
x = F.interpolate(x,
|
283 |
+
scale_factor=2.0,
|
284 |
+
mode='bilinear',
|
285 |
+
align_corners=True)
|
286 |
+
|
287 |
+
return x
|
annotator/mlsd/utils.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# modified by lihaoweicv
|
4 |
+
# pytorch version
|
5 |
+
#
|
6 |
+
# M-LSD
|
7 |
+
# Copyright 2021-present NAVER Corp.
|
8 |
+
# Apache License v2.0
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
|
16 |
+
def deccode_output_score_and_ptss(tpMap, topk_n=200, ksize=5):
|
17 |
+
'''
|
18 |
+
tpMap:
|
19 |
+
center: tpMap[1, 0, :, :]
|
20 |
+
displacement: tpMap[1, 1:5, :, :]
|
21 |
+
'''
|
22 |
+
b, c, h, w = tpMap.shape
|
23 |
+
assert b == 1, 'only support bsize==1'
|
24 |
+
displacement = tpMap[:, 1:5, :, :][0]
|
25 |
+
center = tpMap[:, 0, :, :]
|
26 |
+
heat = torch.sigmoid(center)
|
27 |
+
hmax = F.max_pool2d(heat, (ksize, ksize),
|
28 |
+
stride=1,
|
29 |
+
padding=(ksize - 1) // 2)
|
30 |
+
keep = (hmax == heat).float()
|
31 |
+
heat = heat * keep
|
32 |
+
heat = heat.reshape(-1, )
|
33 |
+
|
34 |
+
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
|
35 |
+
yy = torch.floor_divide(indices, w).unsqueeze(-1)
|
36 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
37 |
+
ptss = torch.cat((yy, xx), dim=-1)
|
38 |
+
|
39 |
+
ptss = ptss.detach().cpu().numpy()
|
40 |
+
scores = scores.detach().cpu().numpy()
|
41 |
+
displacement = displacement.detach().cpu().numpy()
|
42 |
+
displacement = displacement.transpose((1, 2, 0))
|
43 |
+
return ptss, scores, displacement
|
44 |
+
|
45 |
+
|
46 |
+
def pred_lines(image,
|
47 |
+
model,
|
48 |
+
input_shape=[512, 512],
|
49 |
+
score_thr=0.10,
|
50 |
+
dist_thr=20.0,
|
51 |
+
device='cuda'):
|
52 |
+
h, w, _ = image.shape
|
53 |
+
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
54 |
+
|
55 |
+
resized_image = np.concatenate([
|
56 |
+
cv2.resize(image, (input_shape[1], input_shape[0]),
|
57 |
+
interpolation=cv2.INTER_AREA),
|
58 |
+
np.ones([input_shape[0], input_shape[1], 1])
|
59 |
+
],
|
60 |
+
axis=-1)
|
61 |
+
|
62 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
63 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
64 |
+
batch_image = (batch_image / 127.5) - 1.0
|
65 |
+
|
66 |
+
batch_image = torch.from_numpy(batch_image).float().to(device)
|
67 |
+
outputs = model(batch_image)
|
68 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
69 |
+
start = vmap[:, :, :2]
|
70 |
+
end = vmap[:, :, 2:]
|
71 |
+
dist_map = np.sqrt(np.sum((start - end)**2, axis=-1))
|
72 |
+
|
73 |
+
segments_list = []
|
74 |
+
for center, score in zip(pts, pts_score):
|
75 |
+
y, x = center
|
76 |
+
distance = dist_map[y, x]
|
77 |
+
if score > score_thr and distance > dist_thr:
|
78 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
79 |
+
x_start = x + disp_x_start
|
80 |
+
y_start = y + disp_y_start
|
81 |
+
x_end = x + disp_x_end
|
82 |
+
y_end = y + disp_y_end
|
83 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
84 |
+
|
85 |
+
lines = 2 * np.array(segments_list) # 256 > 512
|
86 |
+
lines[:, 0] = lines[:, 0] * w_ratio
|
87 |
+
lines[:, 1] = lines[:, 1] * h_ratio
|
88 |
+
lines[:, 2] = lines[:, 2] * w_ratio
|
89 |
+
lines[:, 3] = lines[:, 3] * h_ratio
|
90 |
+
|
91 |
+
return lines
|
92 |
+
|
93 |
+
|
94 |
+
def pred_squares(
|
95 |
+
image,
|
96 |
+
model,
|
97 |
+
input_shape=[512, 512],
|
98 |
+
device='cuda',
|
99 |
+
params={
|
100 |
+
'score': 0.06,
|
101 |
+
'outside_ratio': 0.28,
|
102 |
+
'inside_ratio': 0.45,
|
103 |
+
'w_overlap': 0.0,
|
104 |
+
'w_degree': 1.95,
|
105 |
+
'w_length': 0.0,
|
106 |
+
'w_area': 1.86,
|
107 |
+
'w_center': 0.14
|
108 |
+
}): # noqa
|
109 |
+
# shape = [height, width]
|
110 |
+
h, w, _ = image.shape
|
111 |
+
original_shape = [h, w]
|
112 |
+
|
113 |
+
resized_image = np.concatenate([
|
114 |
+
cv2.resize(image, (input_shape[0], input_shape[1]),
|
115 |
+
interpolation=cv2.INTER_AREA),
|
116 |
+
np.ones([input_shape[0], input_shape[1], 1])
|
117 |
+
],
|
118 |
+
axis=-1)
|
119 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
120 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
121 |
+
batch_image = (batch_image / 127.5) - 1.0
|
122 |
+
|
123 |
+
batch_image = torch.from_numpy(batch_image).float().to(device)
|
124 |
+
outputs = model(batch_image)
|
125 |
+
|
126 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
127 |
+
start = vmap[:, :, :2] # (x, y)
|
128 |
+
end = vmap[:, :, 2:] # (x, y)
|
129 |
+
dist_map = np.sqrt(np.sum((start - end)**2, axis=-1))
|
130 |
+
|
131 |
+
junc_list = []
|
132 |
+
segments_list = []
|
133 |
+
for junc, score in zip(pts, pts_score):
|
134 |
+
y, x = junc
|
135 |
+
distance = dist_map[y, x]
|
136 |
+
if score > params['score'] and distance > 20.0:
|
137 |
+
junc_list.append([x, y])
|
138 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
139 |
+
d_arrow = 1.0
|
140 |
+
x_start = x + d_arrow * disp_x_start
|
141 |
+
y_start = y + d_arrow * disp_y_start
|
142 |
+
x_end = x + d_arrow * disp_x_end
|
143 |
+
y_end = y + d_arrow * disp_y_end
|
144 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
145 |
+
|
146 |
+
segments = np.array(segments_list)
|
147 |
+
|
148 |
+
# post processing for squares
|
149 |
+
# 1. get unique lines
|
150 |
+
point = np.array([[0, 0]])
|
151 |
+
point = point[0]
|
152 |
+
start = segments[:, :2]
|
153 |
+
end = segments[:, 2:]
|
154 |
+
diff = start - end
|
155 |
+
a = diff[:, 1]
|
156 |
+
b = -diff[:, 0]
|
157 |
+
c = a * start[:, 0] + b * start[:, 1]
|
158 |
+
|
159 |
+
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a**2 + b**2 + 1e-10)
|
160 |
+
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
|
161 |
+
theta[theta < 0.0] += 180
|
162 |
+
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
|
163 |
+
|
164 |
+
d_quant = 1
|
165 |
+
theta_quant = 2
|
166 |
+
hough[:, 0] //= d_quant
|
167 |
+
hough[:, 1] //= theta_quant
|
168 |
+
_, indices, counts = np.unique(hough,
|
169 |
+
axis=0,
|
170 |
+
return_index=True,
|
171 |
+
return_counts=True)
|
172 |
+
|
173 |
+
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1],
|
174 |
+
dtype='float32')
|
175 |
+
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1],
|
176 |
+
dtype='int32') - 1
|
177 |
+
yx_indices = hough[indices, :].astype('int32')
|
178 |
+
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
|
179 |
+
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
|
180 |
+
|
181 |
+
acc_map_np = acc_map
|
182 |
+
# acc_map = acc_map[None, :, :, None]
|
183 |
+
#
|
184 |
+
# ### fast suppression using tensorflow op
|
185 |
+
# acc_map = tf.constant(acc_map, dtype=tf.float32)
|
186 |
+
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
|
187 |
+
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
|
188 |
+
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
|
189 |
+
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
|
190 |
+
# _, h, w, _ = acc_map.shape
|
191 |
+
# y = tf.expand_dims(topk_indices // w, axis=-1)
|
192 |
+
# x = tf.expand_dims(topk_indices % w, axis=-1)
|
193 |
+
# yx = tf.concat([y, x], axis=-1)
|
194 |
+
|
195 |
+
# fast suppression using pytorch op
|
196 |
+
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
|
197 |
+
_, _, h, w = acc_map.shape
|
198 |
+
max_acc_map = F.max_pool2d(acc_map, kernel_size=5, stride=1, padding=2)
|
199 |
+
acc_map = acc_map * ((acc_map == max_acc_map).float())
|
200 |
+
flatten_acc_map = acc_map.reshape([
|
201 |
+
-1,
|
202 |
+
])
|
203 |
+
|
204 |
+
scores, indices = torch.topk(flatten_acc_map,
|
205 |
+
len(pts),
|
206 |
+
dim=-1,
|
207 |
+
largest=True)
|
208 |
+
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
|
209 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
210 |
+
yx = torch.cat((yy, xx), dim=-1)
|
211 |
+
|
212 |
+
yx = yx.detach().cpu().numpy()
|
213 |
+
|
214 |
+
topk_values = scores.detach().cpu().numpy()
|
215 |
+
indices = idx_map[yx[:, 0], yx[:, 1]]
|
216 |
+
basis = 5 // 2
|
217 |
+
|
218 |
+
merged_segments = []
|
219 |
+
for yx_pt, max_indice, value in zip(yx, indices, topk_values):
|
220 |
+
y, x = yx_pt
|
221 |
+
if max_indice == -1 or value == 0:
|
222 |
+
continue
|
223 |
+
segment_list = []
|
224 |
+
for y_offset in range(-basis, basis + 1):
|
225 |
+
for x_offset in range(-basis, basis + 1):
|
226 |
+
indice = idx_map[y + y_offset, x + x_offset]
|
227 |
+
cnt = int(acc_map_np[y + y_offset, x + x_offset])
|
228 |
+
if indice != -1:
|
229 |
+
segment_list.append(segments[indice])
|
230 |
+
if cnt > 1:
|
231 |
+
check_cnt = 1
|
232 |
+
current_hough = hough[indice]
|
233 |
+
for new_indice, new_hough in enumerate(hough):
|
234 |
+
if (current_hough
|
235 |
+
== new_hough).all() and indice != new_indice:
|
236 |
+
segment_list.append(segments[new_indice])
|
237 |
+
check_cnt += 1
|
238 |
+
if check_cnt == cnt:
|
239 |
+
break
|
240 |
+
group_segments = np.array(segment_list).reshape([-1, 2])
|
241 |
+
sorted_group_segments = np.sort(group_segments, axis=0)
|
242 |
+
x_min, y_min = sorted_group_segments[0, :]
|
243 |
+
x_max, y_max = sorted_group_segments[-1, :]
|
244 |
+
|
245 |
+
deg = theta[max_indice]
|
246 |
+
if deg >= 90:
|
247 |
+
merged_segments.append([x_min, y_max, x_max, y_min])
|
248 |
+
else:
|
249 |
+
merged_segments.append([x_min, y_min, x_max, y_max])
|
250 |
+
|
251 |
+
# 2. get intersections
|
252 |
+
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
|
253 |
+
start = new_segments[:, :2] # (x1, y1)
|
254 |
+
end = new_segments[:, 2:] # (x2, y2)
|
255 |
+
new_centers = (start + end) / 2.0
|
256 |
+
diff = start - end
|
257 |
+
dist_segments = np.sqrt(np.sum(diff**2, axis=-1))
|
258 |
+
|
259 |
+
# ax + by = c
|
260 |
+
a = diff[:, 1]
|
261 |
+
b = -diff[:, 0]
|
262 |
+
c = a * start[:, 0] + b * start[:, 1]
|
263 |
+
pre_det = a[:, None] * b[None, :]
|
264 |
+
det = pre_det - np.transpose(pre_det)
|
265 |
+
|
266 |
+
pre_inter_y = a[:, None] * c[None, :]
|
267 |
+
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
|
268 |
+
pre_inter_x = c[:, None] * b[None, :]
|
269 |
+
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
|
270 |
+
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]],
|
271 |
+
axis=-1).astype('int32')
|
272 |
+
|
273 |
+
# 3. get corner information
|
274 |
+
# 3.1 get distance
|
275 |
+
'''
|
276 |
+
dist_segments:
|
277 |
+
| dist(0), dist(1), dist(2), ...|
|
278 |
+
dist_inter_to_segment1:
|
279 |
+
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
|
280 |
+
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
|
281 |
+
...
|
282 |
+
dist_inter_to_semgnet2:
|
283 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
284 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
285 |
+
...
|
286 |
+
'''
|
287 |
+
|
288 |
+
dist_inter_to_segment1_start = np.sqrt(
|
289 |
+
np.sum(((inter_pts - start[:, None, :])**2), axis=-1,
|
290 |
+
keepdims=True)) # [n_batch, n_batch, 1]
|
291 |
+
dist_inter_to_segment1_end = np.sqrt(
|
292 |
+
np.sum(((inter_pts - end[:, None, :])**2), axis=-1,
|
293 |
+
keepdims=True)) # [n_batch, n_batch, 1]
|
294 |
+
dist_inter_to_segment2_start = np.sqrt(
|
295 |
+
np.sum(((inter_pts - start[None, :, :])**2), axis=-1,
|
296 |
+
keepdims=True)) # [n_batch, n_batch, 1]
|
297 |
+
dist_inter_to_segment2_end = np.sqrt(
|
298 |
+
np.sum(((inter_pts - end[None, :, :])**2), axis=-1,
|
299 |
+
keepdims=True)) # [n_batch, n_batch, 1]
|
300 |
+
|
301 |
+
# sort ascending
|
302 |
+
dist_inter_to_segment1 = np.sort(np.concatenate(
|
303 |
+
[dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
|
304 |
+
axis=-1) # [n_batch, n_batch, 2]
|
305 |
+
dist_inter_to_segment2 = np.sort(np.concatenate(
|
306 |
+
[dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
|
307 |
+
axis=-1) # [n_batch, n_batch, 2]
|
308 |
+
|
309 |
+
# 3.2 get degree
|
310 |
+
inter_to_start = new_centers[:, None, :] - inter_pts
|
311 |
+
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1],
|
312 |
+
inter_to_start[:, :, 0]) * 180 / np.pi
|
313 |
+
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
|
314 |
+
inter_to_end = new_centers[None, :, :] - inter_pts
|
315 |
+
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1],
|
316 |
+
inter_to_end[:, :, 0]) * 180 / np.pi
|
317 |
+
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
|
318 |
+
'''
|
319 |
+
B -- G
|
320 |
+
| |
|
321 |
+
C -- R
|
322 |
+
B : blue / G: green / C: cyan / R: red
|
323 |
+
|
324 |
+
0 -- 1
|
325 |
+
| |
|
326 |
+
3 -- 2
|
327 |
+
'''
|
328 |
+
# rename variables
|
329 |
+
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
|
330 |
+
# sort deg ascending
|
331 |
+
deg_sort = np.sort(np.concatenate(
|
332 |
+
[deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1),
|
333 |
+
axis=-1)
|
334 |
+
|
335 |
+
deg_diff_map = np.abs(deg1_map - deg2_map)
|
336 |
+
# we only consider the smallest degree of intersect
|
337 |
+
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
|
338 |
+
|
339 |
+
# define available degree range
|
340 |
+
deg_range = [60, 120]
|
341 |
+
|
342 |
+
corner_dict = {corner_info: [] for corner_info in range(4)}
|
343 |
+
inter_points = []
|
344 |
+
for i in range(inter_pts.shape[0]):
|
345 |
+
for j in range(i + 1, inter_pts.shape[1]):
|
346 |
+
# i, j > line index, always i < j
|
347 |
+
x, y = inter_pts[i, j, :]
|
348 |
+
deg1, deg2 = deg_sort[i, j, :]
|
349 |
+
deg_diff = deg_diff_map[i, j]
|
350 |
+
|
351 |
+
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
|
352 |
+
|
353 |
+
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
|
354 |
+
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
|
355 |
+
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and
|
356 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or
|
357 |
+
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and
|
358 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
|
359 |
+
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and
|
360 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or
|
361 |
+
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and
|
362 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
|
363 |
+
|
364 |
+
if check_degree and check_distance:
|
365 |
+
corner_info = None # noqa
|
366 |
+
|
367 |
+
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
|
368 |
+
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
|
369 |
+
corner_info, color_info = 0, 'blue'
|
370 |
+
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125
|
371 |
+
and deg2 <= 225):
|
372 |
+
corner_info, color_info = 1, 'green'
|
373 |
+
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225
|
374 |
+
and deg2 <= 315):
|
375 |
+
corner_info, color_info = 2, 'black'
|
376 |
+
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
|
377 |
+
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
|
378 |
+
corner_info, color_info = 3, 'cyan'
|
379 |
+
else:
|
380 |
+
corner_info, color_info = 4, 'red' # we don't use it # noqa
|
381 |
+
continue
|
382 |
+
|
383 |
+
corner_dict[corner_info].append([x, y, i, j])
|
384 |
+
inter_points.append([x, y])
|
385 |
+
|
386 |
+
square_list = []
|
387 |
+
connect_list = []
|
388 |
+
segments_list = []
|
389 |
+
for corner0 in corner_dict[0]:
|
390 |
+
for corner1 in corner_dict[1]:
|
391 |
+
connect01 = False
|
392 |
+
for corner0_line in corner0[2:]:
|
393 |
+
if corner0_line in corner1[2:]:
|
394 |
+
connect01 = True
|
395 |
+
break
|
396 |
+
if connect01:
|
397 |
+
for corner2 in corner_dict[2]:
|
398 |
+
connect12 = False
|
399 |
+
for corner1_line in corner1[2:]:
|
400 |
+
if corner1_line in corner2[2:]:
|
401 |
+
connect12 = True
|
402 |
+
break
|
403 |
+
if connect12:
|
404 |
+
for corner3 in corner_dict[3]:
|
405 |
+
connect23 = False
|
406 |
+
for corner2_line in corner2[2:]:
|
407 |
+
if corner2_line in corner3[2:]:
|
408 |
+
connect23 = True
|
409 |
+
break
|
410 |
+
if connect23:
|
411 |
+
for corner3_line in corner3[2:]:
|
412 |
+
if corner3_line in corner0[2:]:
|
413 |
+
# SQUARE!!!
|
414 |
+
'''
|
415 |
+
0 -- 1
|
416 |
+
| |
|
417 |
+
3 -- 2
|
418 |
+
square_list:
|
419 |
+
order: 0 > 1 > 2 > 3
|
420 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
421 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
422 |
+
...
|
423 |
+
connect_list:
|
424 |
+
order: 01 > 12 > 23 > 30
|
425 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
426 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
427 |
+
...
|
428 |
+
segments_list:
|
429 |
+
order: 0 > 1 > 2 > 3
|
430 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i,
|
431 |
+
line_idx2_j, line_idx3_i, line_idx3_j |
|
432 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i,
|
433 |
+
line_idx2_j, line_idx3_i, line_idx3_j |
|
434 |
+
...
|
435 |
+
'''
|
436 |
+
square_list.append(corner0[:2] +
|
437 |
+
corner1[:2] +
|
438 |
+
corner2[:2] +
|
439 |
+
corner3[:2])
|
440 |
+
connect_list.append([
|
441 |
+
corner0_line, corner1_line,
|
442 |
+
corner2_line, corner3_line
|
443 |
+
])
|
444 |
+
segments_list.append(corner0[2:] +
|
445 |
+
corner1[2:] +
|
446 |
+
corner2[2:] +
|
447 |
+
corner3[2:])
|
448 |
+
|
449 |
+
def check_outside_inside(segments_info, connect_idx):
|
450 |
+
# return 'outside or inside', min distance, cover_param, peri_param
|
451 |
+
if connect_idx == segments_info[0]:
|
452 |
+
check_dist_mat = dist_inter_to_segment1
|
453 |
+
else:
|
454 |
+
check_dist_mat = dist_inter_to_segment2
|
455 |
+
|
456 |
+
i, j = segments_info
|
457 |
+
min_dist, max_dist = check_dist_mat[i, j, :]
|
458 |
+
connect_dist = dist_segments[connect_idx]
|
459 |
+
if max_dist > connect_dist:
|
460 |
+
return 'outside', min_dist, 0, 1
|
461 |
+
else:
|
462 |
+
return 'inside', min_dist, -1, -1
|
463 |
+
|
464 |
+
top_square = None # noqa
|
465 |
+
|
466 |
+
try:
|
467 |
+
map_size = input_shape[0] / 2
|
468 |
+
squares = np.array(square_list).reshape([-1, 4, 2])
|
469 |
+
score_array = []
|
470 |
+
connect_array = np.array(connect_list)
|
471 |
+
segments_array = np.array(segments_list).reshape([-1, 4, 2])
|
472 |
+
|
473 |
+
# get degree of corners:
|
474 |
+
squares_rollup = np.roll(squares, 1, axis=1)
|
475 |
+
squares_rolldown = np.roll(squares, -1, axis=1)
|
476 |
+
vec1 = squares_rollup - squares
|
477 |
+
normalized_vec1 = vec1 / (
|
478 |
+
np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
|
479 |
+
vec2 = squares_rolldown - squares
|
480 |
+
normalized_vec2 = vec2 / (
|
481 |
+
np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
|
482 |
+
inner_products = np.sum(normalized_vec1 * normalized_vec2,
|
483 |
+
axis=-1) # [n_squares, 4]
|
484 |
+
squares_degree = np.arccos(
|
485 |
+
inner_products) * 180 / np.pi # [n_squares, 4]
|
486 |
+
|
487 |
+
# get square score
|
488 |
+
overlap_scores = []
|
489 |
+
degree_scores = []
|
490 |
+
length_scores = []
|
491 |
+
|
492 |
+
for connects, segments, square, degree in zip(connect_array,
|
493 |
+
segments_array, squares,
|
494 |
+
squares_degree):
|
495 |
+
'''
|
496 |
+
0 -- 1
|
497 |
+
| |
|
498 |
+
3 -- 2
|
499 |
+
|
500 |
+
# segments: [4, 2]
|
501 |
+
# connects: [4]
|
502 |
+
'''
|
503 |
+
|
504 |
+
# OVERLAP SCORES
|
505 |
+
cover = 0
|
506 |
+
perimeter = 0
|
507 |
+
# check 0 > 1 > 2 > 3
|
508 |
+
square_length = []
|
509 |
+
|
510 |
+
for start_idx in range(4):
|
511 |
+
end_idx = (start_idx + 1) % 4
|
512 |
+
|
513 |
+
connect_idx = connects[start_idx] # segment idx of segment01
|
514 |
+
start_segments = segments[start_idx]
|
515 |
+
end_segments = segments[end_idx]
|
516 |
+
|
517 |
+
start_point = square[start_idx] # noqa
|
518 |
+
end_point = square[end_idx] # noqa
|
519 |
+
|
520 |
+
# check whether outside or inside
|
521 |
+
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(
|
522 |
+
start_segments, connect_idx)
|
523 |
+
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(
|
524 |
+
end_segments, connect_idx)
|
525 |
+
|
526 |
+
cover += dist_segments[
|
527 |
+
connect_idx] + start_cover_param * start_min + end_cover_param * end_min
|
528 |
+
perimeter += dist_segments[
|
529 |
+
connect_idx] + start_peri_param * start_min + end_peri_param * end_min
|
530 |
+
|
531 |
+
square_length.append(dist_segments[connect_idx] +
|
532 |
+
start_peri_param * start_min +
|
533 |
+
end_peri_param * end_min)
|
534 |
+
|
535 |
+
overlap_scores.append(cover / perimeter)
|
536 |
+
# DEGREE SCORES
|
537 |
+
'''
|
538 |
+
deg0 vs deg2
|
539 |
+
deg1 vs deg3
|
540 |
+
'''
|
541 |
+
deg0, deg1, deg2, deg3 = degree
|
542 |
+
deg_ratio1 = deg0 / deg2
|
543 |
+
if deg_ratio1 > 1.0:
|
544 |
+
deg_ratio1 = 1 / deg_ratio1
|
545 |
+
deg_ratio2 = deg1 / deg3
|
546 |
+
if deg_ratio2 > 1.0:
|
547 |
+
deg_ratio2 = 1 / deg_ratio2
|
548 |
+
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
|
549 |
+
# LENGTH SCORES
|
550 |
+
'''
|
551 |
+
len0 vs len2
|
552 |
+
len1 vs len3
|
553 |
+
'''
|
554 |
+
len0, len1, len2, len3 = square_length
|
555 |
+
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
|
556 |
+
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
|
557 |
+
length_scores.append((len_ratio1 + len_ratio2) / 2)
|
558 |
+
|
559 |
+
######################################
|
560 |
+
|
561 |
+
overlap_scores = np.array(overlap_scores)
|
562 |
+
overlap_scores /= np.max(overlap_scores)
|
563 |
+
|
564 |
+
degree_scores = np.array(degree_scores)
|
565 |
+
# degree_scores /= np.max(degree_scores)
|
566 |
+
|
567 |
+
length_scores = np.array(length_scores)
|
568 |
+
|
569 |
+
# AREA SCORES
|
570 |
+
area_scores = np.reshape(squares, [-1, 4, 2])
|
571 |
+
area_x = area_scores[:, :, 0]
|
572 |
+
area_y = area_scores[:, :, 1]
|
573 |
+
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:,
|
574 |
+
0]
|
575 |
+
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(
|
576 |
+
area_y[:, :-1] * area_x[:, 1:], axis=-1)
|
577 |
+
area_scores = 0.5 * np.abs(area_scores + correction)
|
578 |
+
area_scores /= (map_size * map_size) # np.max(area_scores)
|
579 |
+
|
580 |
+
# CENTER SCORES
|
581 |
+
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
|
582 |
+
# squares: [n, 4, 2]
|
583 |
+
square_centers = np.mean(squares, axis=1) # [n, 2]
|
584 |
+
center2center = np.sqrt(np.sum((centers - square_centers)**2))
|
585 |
+
center_scores = center2center / (map_size / np.sqrt(2.0))
|
586 |
+
'''
|
587 |
+
score_w = [overlap, degree, area, center, length]
|
588 |
+
'''
|
589 |
+
score_w = [0.0, 1.0, 10.0, 0.5, 1.0] # noqa
|
590 |
+
score_array = (params['w_overlap'] * overlap_scores +
|
591 |
+
params['w_degree'] * degree_scores +
|
592 |
+
params['w_area'] * area_scores -
|
593 |
+
params['w_center'] * center_scores +
|
594 |
+
params['w_length'] * length_scores)
|
595 |
+
|
596 |
+
best_square = [] # noqa
|
597 |
+
|
598 |
+
sorted_idx = np.argsort(score_array)[::-1]
|
599 |
+
score_array = score_array[sorted_idx]
|
600 |
+
squares = squares[sorted_idx]
|
601 |
+
|
602 |
+
except Exception:
|
603 |
+
pass
|
604 |
+
'''return list
|
605 |
+
merged_lines, squares, scores
|
606 |
+
'''
|
607 |
+
|
608 |
+
try:
|
609 |
+
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[
|
610 |
+
1] * original_shape[1]
|
611 |
+
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[
|
612 |
+
0] * original_shape[0]
|
613 |
+
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[
|
614 |
+
1] * original_shape[1]
|
615 |
+
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[
|
616 |
+
0] * original_shape[0]
|
617 |
+
except Exception:
|
618 |
+
new_segments = []
|
619 |
+
|
620 |
+
try:
|
621 |
+
squares[:, :,
|
622 |
+
0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
|
623 |
+
squares[:, :,
|
624 |
+
1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
|
625 |
+
except Exception:
|
626 |
+
squares = []
|
627 |
+
score_array = []
|
628 |
+
|
629 |
+
try:
|
630 |
+
inter_points = np.array(inter_points)
|
631 |
+
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[
|
632 |
+
1] * original_shape[1]
|
633 |
+
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[
|
634 |
+
0] * original_shape[0]
|
635 |
+
except Exception:
|
636 |
+
inter_points = []
|
637 |
+
|
638 |
+
return new_segments, squares, score_array, inter_points
|
annotator/mlsd_op.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
# MLSD Line Detection
|
4 |
+
# From https://github.com/navervision/mlsd
|
5 |
+
# Apache-2.0 license
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
from abc import ABCMeta
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
16 |
+
from scepter.modules.annotator.mlsd.mbv2_mlsd_large import MobileV2_MLSD_Large
|
17 |
+
from scepter.modules.annotator.mlsd.utils import pred_lines
|
18 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
19 |
+
from scepter.modules.annotator.utils import resize_image, resize_image_ori
|
20 |
+
from scepter.modules.utils.config import dict_to_yaml
|
21 |
+
from scepter.modules.utils.distribute import we
|
22 |
+
from scepter.modules.utils.file_system import FS
|
23 |
+
|
24 |
+
|
25 |
+
@ANNOTATORS.register_class()
|
26 |
+
class MLSDdetector(BaseAnnotator, metaclass=ABCMeta):
|
27 |
+
def __init__(self, cfg, logger=None):
|
28 |
+
super().__init__(cfg, logger=logger)
|
29 |
+
model = MobileV2_MLSD_Large()
|
30 |
+
pretrained_model = cfg.get('PRETRAINED_MODEL', None)
|
31 |
+
if pretrained_model:
|
32 |
+
with FS.get_from(pretrained_model, wait_finish=True) as local_path:
|
33 |
+
model.load_state_dict(torch.load(local_path), strict=True)
|
34 |
+
self.model = model.eval()
|
35 |
+
self.thr_v = cfg.get('THR_V', 0.1)
|
36 |
+
self.thr_d = cfg.get('THR_D', 0.1)
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
@torch.inference_mode()
|
40 |
+
@torch.autocast('cuda', enabled=False)
|
41 |
+
def forward(self, image):
|
42 |
+
if isinstance(image, Image.Image):
|
43 |
+
image = np.array(image)
|
44 |
+
elif isinstance(image, torch.Tensor):
|
45 |
+
image = image.detach().cpu().numpy()
|
46 |
+
elif isinstance(image, np.ndarray):
|
47 |
+
image = image.copy()
|
48 |
+
else:
|
49 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
50 |
+
h, w, c = image.shape
|
51 |
+
image, k = resize_image(image, 1024 if min(h, w) > 1024 else min(h, w))
|
52 |
+
img_output = np.zeros_like(image)
|
53 |
+
try:
|
54 |
+
lines = pred_lines(image,
|
55 |
+
self.model, [image.shape[0], image.shape[1]],
|
56 |
+
self.thr_v,
|
57 |
+
self.thr_d,
|
58 |
+
device=we.device_id)
|
59 |
+
for line in lines:
|
60 |
+
x_start, y_start, x_end, y_end = [int(val) for val in line]
|
61 |
+
cv2.line(img_output, (x_start, y_start), (x_end, y_end),
|
62 |
+
[255, 255, 255], 1)
|
63 |
+
except Exception as e:
|
64 |
+
warnings.warn(f'{e}')
|
65 |
+
return None
|
66 |
+
img_output = resize_image_ori(h, w, img_output, k)
|
67 |
+
return img_output[:, :, 0]
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def get_config_template():
|
71 |
+
return dict_to_yaml('ANNOTATORS',
|
72 |
+
__class__.__name__,
|
73 |
+
MLSDdetector.para_dict,
|
74 |
+
set_name=True)
|
annotator/openpose.py
ADDED
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
# Openpose
|
4 |
+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
5 |
+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
6 |
+
# The implementation is modified from 3rd Edited Version by ControlNet
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from abc import ABCMeta
|
10 |
+
from collections import OrderedDict
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import matplotlib
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from PIL import Image
|
18 |
+
from scipy.ndimage.filters import gaussian_filter
|
19 |
+
from skimage.measure import label
|
20 |
+
|
21 |
+
from scepter.modules.annotator.base_annotator import BaseAnnotator
|
22 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
23 |
+
from scepter.modules.utils.config import dict_to_yaml
|
24 |
+
from scepter.modules.utils.file_system import FS
|
25 |
+
|
26 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
27 |
+
|
28 |
+
|
29 |
+
def padRightDownCorner(img, stride, padValue):
|
30 |
+
h = img.shape[0]
|
31 |
+
w = img.shape[1]
|
32 |
+
|
33 |
+
pad = 4 * [None]
|
34 |
+
pad[0] = 0 # up
|
35 |
+
pad[1] = 0 # left
|
36 |
+
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
37 |
+
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
38 |
+
|
39 |
+
img_padded = img
|
40 |
+
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
|
41 |
+
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
42 |
+
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
|
43 |
+
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
44 |
+
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
|
45 |
+
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
46 |
+
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
|
47 |
+
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
48 |
+
|
49 |
+
return img_padded, pad
|
50 |
+
|
51 |
+
|
52 |
+
# transfer caffe model to pytorch which will match the layer name
|
53 |
+
def transfer(model, model_weights):
|
54 |
+
transfered_model_weights = {}
|
55 |
+
for weights_name in model.state_dict().keys():
|
56 |
+
transfered_model_weights[weights_name] = model_weights['.'.join(
|
57 |
+
weights_name.split('.')[1:])]
|
58 |
+
return transfered_model_weights
|
59 |
+
|
60 |
+
|
61 |
+
# draw the body keypoint and lims
|
62 |
+
def draw_bodypose(canvas, candidate, subset):
|
63 |
+
stickwidth = 4
|
64 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
|
65 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15],
|
66 |
+
[15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
|
67 |
+
|
68 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
|
69 |
+
[170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
|
70 |
+
[0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
|
71 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
|
72 |
+
[255, 0, 170], [255, 0, 85]]
|
73 |
+
for i in range(18):
|
74 |
+
for n in range(len(subset)):
|
75 |
+
index = int(subset[n][i])
|
76 |
+
if index == -1:
|
77 |
+
continue
|
78 |
+
x, y = candidate[index][0:2]
|
79 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
80 |
+
for i in range(17):
|
81 |
+
for n in range(len(subset)):
|
82 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
83 |
+
if -1 in index:
|
84 |
+
continue
|
85 |
+
cur_canvas = canvas.copy()
|
86 |
+
Y = candidate[index.astype(int), 0]
|
87 |
+
X = candidate[index.astype(int), 1]
|
88 |
+
mX = np.mean(X)
|
89 |
+
mY = np.mean(Y)
|
90 |
+
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
|
91 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
92 |
+
polygon = cv2.ellipse2Poly(
|
93 |
+
(int(mY), int(mX)), (int(length / 2), stickwidth), int(angle),
|
94 |
+
0, 360, 1)
|
95 |
+
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
|
96 |
+
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
97 |
+
# plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
|
98 |
+
# plt.imshow(canvas[:, :, [2, 1, 0]])
|
99 |
+
return canvas
|
100 |
+
|
101 |
+
|
102 |
+
# image drawed by opencv is not good.
|
103 |
+
def draw_handpose(canvas, all_hand_peaks, show_number=False):
|
104 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8],
|
105 |
+
[0, 9], [9, 10], [10, 11], [11, 12], [0, 13], [13, 14], [14, 15],
|
106 |
+
[15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
107 |
+
|
108 |
+
for peaks in all_hand_peaks:
|
109 |
+
for ie, e in enumerate(edges):
|
110 |
+
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
|
111 |
+
x1, y1 = peaks[e[0]]
|
112 |
+
x2, y2 = peaks[e[1]]
|
113 |
+
cv2.line(canvas, (x1, y1), (x2, y2),
|
114 |
+
matplotlib.colors.hsv_to_rgb(
|
115 |
+
[ie / float(len(edges)), 1.0, 1.0]) * 255,
|
116 |
+
thickness=2)
|
117 |
+
|
118 |
+
for i, keyponit in enumerate(peaks):
|
119 |
+
x, y = keyponit
|
120 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
121 |
+
if show_number:
|
122 |
+
cv2.putText(canvas,
|
123 |
+
str(i), (x, y),
|
124 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
125 |
+
0.3, (0, 0, 0),
|
126 |
+
lineType=cv2.LINE_AA)
|
127 |
+
return canvas
|
128 |
+
|
129 |
+
|
130 |
+
# detect hand according to body pose keypoints
|
131 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/
|
132 |
+
# master/src/openpose/hand/handDetector.cpp
|
133 |
+
def handDetect(candidate, subset, oriImg):
|
134 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
135 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
136 |
+
ratioWristElbow = 0.33
|
137 |
+
detect_result = []
|
138 |
+
image_height, image_width = oriImg.shape[0:2]
|
139 |
+
for person in subset.astype(int):
|
140 |
+
# if any of three not detected
|
141 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
142 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
143 |
+
if not (has_left or has_right):
|
144 |
+
continue
|
145 |
+
hands = []
|
146 |
+
# left hand
|
147 |
+
if has_left:
|
148 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[
|
149 |
+
5, 6, 7
|
150 |
+
]]
|
151 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
152 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
153 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
154 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
155 |
+
# right hand
|
156 |
+
if has_right:
|
157 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[
|
158 |
+
[2, 3, 4]]
|
159 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
160 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
161 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
162 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
163 |
+
|
164 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
165 |
+
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
166 |
+
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
167 |
+
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
168 |
+
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
169 |
+
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
170 |
+
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
171 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
172 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
173 |
+
distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2)
|
174 |
+
distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2)
|
175 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
176 |
+
# x-y refers to the center --> offset to topLeft point
|
177 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
178 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
179 |
+
x -= width / 2
|
180 |
+
y -= width / 2 # width = height
|
181 |
+
# overflow the image
|
182 |
+
if x < 0:
|
183 |
+
x = 0
|
184 |
+
if y < 0:
|
185 |
+
y = 0
|
186 |
+
width1 = width
|
187 |
+
width2 = width
|
188 |
+
if x + width > image_width:
|
189 |
+
width1 = image_width - x
|
190 |
+
if y + width > image_height:
|
191 |
+
width2 = image_height - y
|
192 |
+
width = min(width1, width2)
|
193 |
+
# the max hand box value is 20 pixels
|
194 |
+
if width >= 20:
|
195 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
196 |
+
'''
|
197 |
+
return value: [[x, y, w, True if left hand else False]].
|
198 |
+
width=height since the network require squared input.
|
199 |
+
x, y is the coordinate of top left
|
200 |
+
'''
|
201 |
+
return detect_result
|
202 |
+
|
203 |
+
|
204 |
+
# get max index of 2d array
|
205 |
+
def npmax(array):
|
206 |
+
arrayindex = array.argmax(1)
|
207 |
+
arrayvalue = array.max(1)
|
208 |
+
i = arrayvalue.argmax()
|
209 |
+
j = arrayindex[i]
|
210 |
+
return i, j
|
211 |
+
|
212 |
+
|
213 |
+
def make_layers(block, no_relu_layers):
|
214 |
+
layers = []
|
215 |
+
for layer_name, v in block.items():
|
216 |
+
if 'pool' in layer_name:
|
217 |
+
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
|
218 |
+
layers.append((layer_name, layer))
|
219 |
+
else:
|
220 |
+
conv2d = nn.Conv2d(in_channels=v[0],
|
221 |
+
out_channels=v[1],
|
222 |
+
kernel_size=v[2],
|
223 |
+
stride=v[3],
|
224 |
+
padding=v[4])
|
225 |
+
layers.append((layer_name, conv2d))
|
226 |
+
if layer_name not in no_relu_layers:
|
227 |
+
layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
|
228 |
+
|
229 |
+
return nn.Sequential(OrderedDict(layers))
|
230 |
+
|
231 |
+
|
232 |
+
class bodypose_model(nn.Module):
|
233 |
+
def __init__(self):
|
234 |
+
super(bodypose_model, self).__init__()
|
235 |
+
|
236 |
+
# these layers have no relu layer
|
237 |
+
no_relu_layers = [
|
238 |
+
'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
|
239 |
+
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
|
240 |
+
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
|
241 |
+
'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'
|
242 |
+
]
|
243 |
+
blocks = {}
|
244 |
+
block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
|
245 |
+
('conv1_2', [64, 64, 3, 1, 1]),
|
246 |
+
('pool1_stage1', [2, 2, 0]),
|
247 |
+
('conv2_1', [64, 128, 3, 1, 1]),
|
248 |
+
('conv2_2', [128, 128, 3, 1, 1]),
|
249 |
+
('pool2_stage1', [2, 2, 0]),
|
250 |
+
('conv3_1', [128, 256, 3, 1, 1]),
|
251 |
+
('conv3_2', [256, 256, 3, 1, 1]),
|
252 |
+
('conv3_3', [256, 256, 3, 1, 1]),
|
253 |
+
('conv3_4', [256, 256, 3, 1, 1]),
|
254 |
+
('pool3_stage1', [2, 2, 0]),
|
255 |
+
('conv4_1', [256, 512, 3, 1, 1]),
|
256 |
+
('conv4_2', [512, 512, 3, 1, 1]),
|
257 |
+
('conv4_3_CPM', [512, 256, 3, 1, 1]),
|
258 |
+
('conv4_4_CPM', [256, 128, 3, 1, 1])])
|
259 |
+
|
260 |
+
# Stage 1
|
261 |
+
block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
|
262 |
+
('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
|
263 |
+
('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
|
264 |
+
('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
|
265 |
+
('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])
|
266 |
+
|
267 |
+
block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
|
268 |
+
('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
|
269 |
+
('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
|
270 |
+
('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
|
271 |
+
('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
|
272 |
+
blocks['block1_1'] = block1_1
|
273 |
+
blocks['block1_2'] = block1_2
|
274 |
+
|
275 |
+
self.model0 = make_layers(block0, no_relu_layers)
|
276 |
+
|
277 |
+
# Stages 2 - 6
|
278 |
+
for i in range(2, 7):
|
279 |
+
blocks['block%d_1' % i] = OrderedDict([
|
280 |
+
('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
|
281 |
+
('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
282 |
+
('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
283 |
+
('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
284 |
+
('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
285 |
+
('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
|
286 |
+
('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
|
287 |
+
])
|
288 |
+
|
289 |
+
blocks['block%d_2' % i] = OrderedDict([
|
290 |
+
('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
|
291 |
+
('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
292 |
+
('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
293 |
+
('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
294 |
+
('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
295 |
+
('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
|
296 |
+
('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
|
297 |
+
])
|
298 |
+
|
299 |
+
for k in blocks.keys():
|
300 |
+
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
301 |
+
|
302 |
+
self.model1_1 = blocks['block1_1']
|
303 |
+
self.model2_1 = blocks['block2_1']
|
304 |
+
self.model3_1 = blocks['block3_1']
|
305 |
+
self.model4_1 = blocks['block4_1']
|
306 |
+
self.model5_1 = blocks['block5_1']
|
307 |
+
self.model6_1 = blocks['block6_1']
|
308 |
+
|
309 |
+
self.model1_2 = blocks['block1_2']
|
310 |
+
self.model2_2 = blocks['block2_2']
|
311 |
+
self.model3_2 = blocks['block3_2']
|
312 |
+
self.model4_2 = blocks['block4_2']
|
313 |
+
self.model5_2 = blocks['block5_2']
|
314 |
+
self.model6_2 = blocks['block6_2']
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
|
318 |
+
out1 = self.model0(x)
|
319 |
+
|
320 |
+
out1_1 = self.model1_1(out1)
|
321 |
+
out1_2 = self.model1_2(out1)
|
322 |
+
out2 = torch.cat([out1_1, out1_2, out1], 1)
|
323 |
+
|
324 |
+
out2_1 = self.model2_1(out2)
|
325 |
+
out2_2 = self.model2_2(out2)
|
326 |
+
out3 = torch.cat([out2_1, out2_2, out1], 1)
|
327 |
+
|
328 |
+
out3_1 = self.model3_1(out3)
|
329 |
+
out3_2 = self.model3_2(out3)
|
330 |
+
out4 = torch.cat([out3_1, out3_2, out1], 1)
|
331 |
+
|
332 |
+
out4_1 = self.model4_1(out4)
|
333 |
+
out4_2 = self.model4_2(out4)
|
334 |
+
out5 = torch.cat([out4_1, out4_2, out1], 1)
|
335 |
+
|
336 |
+
out5_1 = self.model5_1(out5)
|
337 |
+
out5_2 = self.model5_2(out5)
|
338 |
+
out6 = torch.cat([out5_1, out5_2, out1], 1)
|
339 |
+
|
340 |
+
out6_1 = self.model6_1(out6)
|
341 |
+
out6_2 = self.model6_2(out6)
|
342 |
+
|
343 |
+
return out6_1, out6_2
|
344 |
+
|
345 |
+
|
346 |
+
class handpose_model(nn.Module):
|
347 |
+
def __init__(self):
|
348 |
+
super(handpose_model, self).__init__()
|
349 |
+
|
350 |
+
# these layers have no relu layer
|
351 |
+
no_relu_layers = [
|
352 |
+
'conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3', 'Mconv7_stage4',
|
353 |
+
'Mconv7_stage5', 'Mconv7_stage6'
|
354 |
+
]
|
355 |
+
# stage 1
|
356 |
+
block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
|
357 |
+
('conv1_2', [64, 64, 3, 1, 1]),
|
358 |
+
('pool1_stage1', [2, 2, 0]),
|
359 |
+
('conv2_1', [64, 128, 3, 1, 1]),
|
360 |
+
('conv2_2', [128, 128, 3, 1, 1]),
|
361 |
+
('pool2_stage1', [2, 2, 0]),
|
362 |
+
('conv3_1', [128, 256, 3, 1, 1]),
|
363 |
+
('conv3_2', [256, 256, 3, 1, 1]),
|
364 |
+
('conv3_3', [256, 256, 3, 1, 1]),
|
365 |
+
('conv3_4', [256, 256, 3, 1, 1]),
|
366 |
+
('pool3_stage1', [2, 2, 0]),
|
367 |
+
('conv4_1', [256, 512, 3, 1, 1]),
|
368 |
+
('conv4_2', [512, 512, 3, 1, 1]),
|
369 |
+
('conv4_3', [512, 512, 3, 1, 1]),
|
370 |
+
('conv4_4', [512, 512, 3, 1, 1]),
|
371 |
+
('conv5_1', [512, 512, 3, 1, 1]),
|
372 |
+
('conv5_2', [512, 512, 3, 1, 1]),
|
373 |
+
('conv5_3_CPM', [512, 128, 3, 1, 1])])
|
374 |
+
|
375 |
+
block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]),
|
376 |
+
('conv6_2_CPM', [512, 22, 1, 1, 0])])
|
377 |
+
|
378 |
+
blocks = {}
|
379 |
+
blocks['block1_0'] = block1_0
|
380 |
+
blocks['block1_1'] = block1_1
|
381 |
+
|
382 |
+
# stage 2-6
|
383 |
+
for i in range(2, 7):
|
384 |
+
blocks['block%d' % i] = OrderedDict([
|
385 |
+
('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
|
386 |
+
('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
|
387 |
+
('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
|
388 |
+
('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
|
389 |
+
('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
|
390 |
+
('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
|
391 |
+
('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
|
392 |
+
])
|
393 |
+
|
394 |
+
for k in blocks.keys():
|
395 |
+
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
396 |
+
|
397 |
+
self.model1_0 = blocks['block1_0']
|
398 |
+
self.model1_1 = blocks['block1_1']
|
399 |
+
self.model2 = blocks['block2']
|
400 |
+
self.model3 = blocks['block3']
|
401 |
+
self.model4 = blocks['block4']
|
402 |
+
self.model5 = blocks['block5']
|
403 |
+
self.model6 = blocks['block6']
|
404 |
+
|
405 |
+
def forward(self, x):
|
406 |
+
out1_0 = self.model1_0(x)
|
407 |
+
out1_1 = self.model1_1(out1_0)
|
408 |
+
concat_stage2 = torch.cat([out1_1, out1_0], 1)
|
409 |
+
out_stage2 = self.model2(concat_stage2)
|
410 |
+
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
|
411 |
+
out_stage3 = self.model3(concat_stage3)
|
412 |
+
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
|
413 |
+
out_stage4 = self.model4(concat_stage4)
|
414 |
+
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
|
415 |
+
out_stage5 = self.model5(concat_stage5)
|
416 |
+
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
|
417 |
+
out_stage6 = self.model6(concat_stage6)
|
418 |
+
return out_stage6
|
419 |
+
|
420 |
+
|
421 |
+
class Hand(object):
|
422 |
+
def __init__(self, model_path, device='cuda'):
|
423 |
+
self.model = handpose_model()
|
424 |
+
if torch.cuda.is_available():
|
425 |
+
self.model = self.model.to(device)
|
426 |
+
model_dict = transfer(self.model, torch.load(model_path))
|
427 |
+
self.model.load_state_dict(model_dict)
|
428 |
+
self.model.eval()
|
429 |
+
self.device = device
|
430 |
+
|
431 |
+
def __call__(self, oriImg):
|
432 |
+
scale_search = [0.5, 1.0, 1.5, 2.0]
|
433 |
+
# scale_search = [0.5]
|
434 |
+
boxsize = 368
|
435 |
+
stride = 8
|
436 |
+
padValue = 128
|
437 |
+
thre = 0.05
|
438 |
+
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
439 |
+
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
|
440 |
+
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
441 |
+
|
442 |
+
for m in range(len(multiplier)):
|
443 |
+
scale = multiplier[m]
|
444 |
+
imageToTest = cv2.resize(oriImg, (0, 0),
|
445 |
+
fx=scale,
|
446 |
+
fy=scale,
|
447 |
+
interpolation=cv2.INTER_CUBIC)
|
448 |
+
imageToTest_padded, pad = padRightDownCorner(
|
449 |
+
imageToTest, stride, padValue)
|
450 |
+
im = np.transpose(
|
451 |
+
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
452 |
+
(3, 2, 0, 1)) / 256 - 0.5
|
453 |
+
im = np.ascontiguousarray(im)
|
454 |
+
|
455 |
+
data = torch.from_numpy(im).float()
|
456 |
+
if torch.cuda.is_available():
|
457 |
+
data = data.to(self.device)
|
458 |
+
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
459 |
+
with torch.no_grad():
|
460 |
+
output = self.model(data).cpu().numpy()
|
461 |
+
# output = self.model(data).numpy()q
|
462 |
+
|
463 |
+
# extract outputs, resize, and remove padding
|
464 |
+
heatmap = np.transpose(np.squeeze(output),
|
465 |
+
(1, 2, 0)) # output 1 is heatmaps
|
466 |
+
heatmap = cv2.resize(heatmap, (0, 0),
|
467 |
+
fx=stride,
|
468 |
+
fy=stride,
|
469 |
+
interpolation=cv2.INTER_CUBIC)
|
470 |
+
heatmap = heatmap[:imageToTest_padded.shape[0] -
|
471 |
+
pad[2], :imageToTest_padded.shape[1] - pad[3], :]
|
472 |
+
heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]),
|
473 |
+
interpolation=cv2.INTER_CUBIC)
|
474 |
+
|
475 |
+
heatmap_avg += heatmap / len(multiplier)
|
476 |
+
|
477 |
+
all_peaks = []
|
478 |
+
for part in range(21):
|
479 |
+
map_ori = heatmap_avg[:, :, part]
|
480 |
+
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
481 |
+
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
|
482 |
+
# 全部小于阈值
|
483 |
+
if np.sum(binary) == 0:
|
484 |
+
all_peaks.append([0, 0])
|
485 |
+
continue
|
486 |
+
label_img, label_numbers = label(binary,
|
487 |
+
return_num=True,
|
488 |
+
connectivity=binary.ndim)
|
489 |
+
max_index = np.argmax([
|
490 |
+
np.sum(map_ori[label_img == i])
|
491 |
+
for i in range(1, label_numbers + 1)
|
492 |
+
]) + 1
|
493 |
+
label_img[label_img != max_index] = 0
|
494 |
+
map_ori[label_img == 0] = 0
|
495 |
+
|
496 |
+
y, x = npmax(map_ori)
|
497 |
+
all_peaks.append([x, y])
|
498 |
+
return np.array(all_peaks)
|
499 |
+
|
500 |
+
|
501 |
+
class Body(object):
|
502 |
+
def __init__(self, model_path, device='cuda'):
|
503 |
+
self.model = bodypose_model()
|
504 |
+
if torch.cuda.is_available():
|
505 |
+
self.model = self.model.to(device)
|
506 |
+
model_dict = transfer(self.model, torch.load(model_path))
|
507 |
+
self.model.load_state_dict(model_dict)
|
508 |
+
self.model.eval()
|
509 |
+
self.device = device
|
510 |
+
|
511 |
+
def __call__(self, oriImg):
|
512 |
+
# scale_search = [0.5, 1.0, 1.5, 2.0]
|
513 |
+
scale_search = [0.5]
|
514 |
+
boxsize = 368
|
515 |
+
stride = 8
|
516 |
+
padValue = 128
|
517 |
+
thre1 = 0.1
|
518 |
+
thre2 = 0.05
|
519 |
+
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
520 |
+
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
|
521 |
+
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
522 |
+
|
523 |
+
for m in range(len(multiplier)):
|
524 |
+
scale = multiplier[m]
|
525 |
+
imageToTest = cv2.resize(oriImg, (0, 0),
|
526 |
+
fx=scale,
|
527 |
+
fy=scale,
|
528 |
+
interpolation=cv2.INTER_CUBIC)
|
529 |
+
imageToTest_padded, pad = padRightDownCorner(
|
530 |
+
imageToTest, stride, padValue)
|
531 |
+
im = np.transpose(
|
532 |
+
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
533 |
+
(3, 2, 0, 1)) / 256 - 0.5
|
534 |
+
im = np.ascontiguousarray(im)
|
535 |
+
|
536 |
+
data = torch.from_numpy(im).float()
|
537 |
+
if torch.cuda.is_available():
|
538 |
+
data = data.to(self.device)
|
539 |
+
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
540 |
+
with torch.no_grad():
|
541 |
+
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
|
542 |
+
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
|
543 |
+
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
|
544 |
+
|
545 |
+
# extract outputs, resize, and remove padding
|
546 |
+
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))
|
547 |
+
# output 1 is heatmaps
|
548 |
+
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2),
|
549 |
+
(1, 2, 0)) # output 1 is heatmaps
|
550 |
+
heatmap = cv2.resize(heatmap, (0, 0),
|
551 |
+
fx=stride,
|
552 |
+
fy=stride,
|
553 |
+
interpolation=cv2.INTER_CUBIC)
|
554 |
+
heatmap = heatmap[:imageToTest_padded.shape[0] -
|
555 |
+
pad[2], :imageToTest_padded.shape[1] - pad[3], :]
|
556 |
+
heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]),
|
557 |
+
interpolation=cv2.INTER_CUBIC)
|
558 |
+
|
559 |
+
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
|
560 |
+
paf = np.transpose(np.squeeze(Mconv7_stage6_L1),
|
561 |
+
(1, 2, 0)) # output 0 is PAFs
|
562 |
+
paf = cv2.resize(paf, (0, 0),
|
563 |
+
fx=stride,
|
564 |
+
fy=stride,
|
565 |
+
interpolation=cv2.INTER_CUBIC)
|
566 |
+
paf = paf[:imageToTest_padded.shape[0] -
|
567 |
+
pad[2], :imageToTest_padded.shape[1] - pad[3], :]
|
568 |
+
paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]),
|
569 |
+
interpolation=cv2.INTER_CUBIC)
|
570 |
+
|
571 |
+
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
|
572 |
+
paf_avg += +paf / len(multiplier)
|
573 |
+
|
574 |
+
all_peaks = []
|
575 |
+
peak_counter = 0
|
576 |
+
|
577 |
+
for part in range(18):
|
578 |
+
map_ori = heatmap_avg[:, :, part]
|
579 |
+
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
580 |
+
|
581 |
+
map_left = np.zeros(one_heatmap.shape)
|
582 |
+
map_left[1:, :] = one_heatmap[:-1, :]
|
583 |
+
map_right = np.zeros(one_heatmap.shape)
|
584 |
+
map_right[:-1, :] = one_heatmap[1:, :]
|
585 |
+
map_up = np.zeros(one_heatmap.shape)
|
586 |
+
map_up[:, 1:] = one_heatmap[:, :-1]
|
587 |
+
map_down = np.zeros(one_heatmap.shape)
|
588 |
+
map_down[:, :-1] = one_heatmap[:, 1:]
|
589 |
+
|
590 |
+
peaks_binary = np.logical_and.reduce(
|
591 |
+
(one_heatmap >= map_left, one_heatmap >= map_right,
|
592 |
+
one_heatmap >= map_up, one_heatmap >= map_down,
|
593 |
+
one_heatmap > thre1))
|
594 |
+
peaks = list(
|
595 |
+
zip(np.nonzero(peaks_binary)[1],
|
596 |
+
np.nonzero(peaks_binary)[0])) # note reverse
|
597 |
+
peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
|
598 |
+
peak_id = range(peak_counter, peak_counter + len(peaks))
|
599 |
+
peaks_with_score_and_id = [
|
600 |
+
peaks_with_score[i] + (peak_id[i], )
|
601 |
+
for i in range(len(peak_id))
|
602 |
+
]
|
603 |
+
|
604 |
+
all_peaks.append(peaks_with_score_and_id)
|
605 |
+
peak_counter += len(peaks)
|
606 |
+
|
607 |
+
# find connection in the specified sequence, center 29 is in the position 15
|
608 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9],
|
609 |
+
[9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1],
|
610 |
+
[1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
|
611 |
+
# the middle joints heatmap correpondence
|
612 |
+
mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44],
|
613 |
+
[19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30],
|
614 |
+
[47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38],
|
615 |
+
[45, 46]]
|
616 |
+
|
617 |
+
connection_all = []
|
618 |
+
special_k = []
|
619 |
+
mid_num = 10
|
620 |
+
|
621 |
+
for k in range(len(mapIdx)):
|
622 |
+
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
|
623 |
+
candA = all_peaks[limbSeq[k][0] - 1]
|
624 |
+
candB = all_peaks[limbSeq[k][1] - 1]
|
625 |
+
nA = len(candA)
|
626 |
+
nB = len(candB)
|
627 |
+
indexA, indexB = limbSeq[k]
|
628 |
+
if (nA != 0 and nB != 0):
|
629 |
+
connection_candidate = []
|
630 |
+
for i in range(nA):
|
631 |
+
for j in range(nB):
|
632 |
+
vec = np.subtract(candB[j][:2], candA[i][:2])
|
633 |
+
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
|
634 |
+
norm = max(0.001, norm)
|
635 |
+
vec = np.divide(vec, norm)
|
636 |
+
|
637 |
+
startend = list(
|
638 |
+
zip(
|
639 |
+
np.linspace(candA[i][0],
|
640 |
+
candB[j][0],
|
641 |
+
num=mid_num),
|
642 |
+
np.linspace(candA[i][1],
|
643 |
+
candB[j][1],
|
644 |
+
num=mid_num)))
|
645 |
+
|
646 |
+
vec_x = np.array([
|
647 |
+
score_mid[int(round(startend[ii][1])),
|
648 |
+
int(round(startend[ii][0])), 0]
|
649 |
+
for ii in range(len(startend))
|
650 |
+
])
|
651 |
+
vec_y = np.array([
|
652 |
+
score_mid[int(round(startend[ii][1])),
|
653 |
+
int(round(startend[ii][0])), 1]
|
654 |
+
for ii in range(len(startend))
|
655 |
+
])
|
656 |
+
|
657 |
+
score_midpts = np.multiply(
|
658 |
+
vec_x, vec[0]) + np.multiply(vec_y, vec[1])
|
659 |
+
score_with_dist_prior = sum(score_midpts) / len(
|
660 |
+
score_midpts) + min(
|
661 |
+
0.5 * oriImg.shape[0] / norm - 1, 0)
|
662 |
+
criterion1 = len(np.nonzero(
|
663 |
+
score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
|
664 |
+
criterion2 = score_with_dist_prior > 0
|
665 |
+
if criterion1 and criterion2:
|
666 |
+
connection_candidate.append([
|
667 |
+
i, j, score_with_dist_prior,
|
668 |
+
score_with_dist_prior + candA[i][2] +
|
669 |
+
candB[j][2]
|
670 |
+
])
|
671 |
+
|
672 |
+
connection_candidate = sorted(connection_candidate,
|
673 |
+
key=lambda x: x[2],
|
674 |
+
reverse=True)
|
675 |
+
connection = np.zeros((0, 5))
|
676 |
+
for c in range(len(connection_candidate)):
|
677 |
+
i, j, s = connection_candidate[c][0:3]
|
678 |
+
if (i not in connection[:, 3]
|
679 |
+
and j not in connection[:, 4]):
|
680 |
+
connection = np.vstack(
|
681 |
+
[connection, [candA[i][3], candB[j][3], s, i, j]])
|
682 |
+
if (len(connection) >= min(nA, nB)):
|
683 |
+
break
|
684 |
+
|
685 |
+
connection_all.append(connection)
|
686 |
+
else:
|
687 |
+
special_k.append(k)
|
688 |
+
connection_all.append([])
|
689 |
+
|
690 |
+
# last number in each row is the total parts number of that person
|
691 |
+
# the second last number in each row is the score of the overall configuration
|
692 |
+
subset = -1 * np.ones((0, 20))
|
693 |
+
candidate = np.array(
|
694 |
+
[item for sublist in all_peaks for item in sublist])
|
695 |
+
|
696 |
+
for k in range(len(mapIdx)):
|
697 |
+
if k not in special_k:
|
698 |
+
partAs = connection_all[k][:, 0]
|
699 |
+
partBs = connection_all[k][:, 1]
|
700 |
+
indexA, indexB = np.array(limbSeq[k]) - 1
|
701 |
+
|
702 |
+
for i in range(len(connection_all[k])): # = 1:size(temp,1)
|
703 |
+
found = 0
|
704 |
+
subset_idx = [-1, -1]
|
705 |
+
for j in range(len(subset)): # 1:size(subset,1):
|
706 |
+
if subset[j][indexA] == partAs[i] or subset[j][
|
707 |
+
indexB] == partBs[i]:
|
708 |
+
subset_idx[found] = j
|
709 |
+
found += 1
|
710 |
+
|
711 |
+
if found == 1:
|
712 |
+
j = subset_idx[0]
|
713 |
+
if subset[j][indexB] != partBs[i]:
|
714 |
+
subset[j][indexB] = partBs[i]
|
715 |
+
subset[j][-1] += 1
|
716 |
+
subset[j][-2] += candidate[
|
717 |
+
partBs[i].astype(int),
|
718 |
+
2] + connection_all[k][i][2]
|
719 |
+
elif found == 2: # if found 2 and disjoint, merge them
|
720 |
+
j1, j2 = subset_idx
|
721 |
+
membership = ((subset[j1] >= 0).astype(int) +
|
722 |
+
(subset[j2] >= 0).astype(int))[:-2]
|
723 |
+
if len(np.nonzero(membership == 2)[0]) == 0: # merge
|
724 |
+
subset[j1][:-2] += (subset[j2][:-2] + 1)
|
725 |
+
subset[j1][-2:] += subset[j2][-2:]
|
726 |
+
subset[j1][-2] += connection_all[k][i][2]
|
727 |
+
subset = np.delete(subset, j2, 0)
|
728 |
+
else: # as like found == 1
|
729 |
+
subset[j1][indexB] = partBs[i]
|
730 |
+
subset[j1][-1] += 1
|
731 |
+
subset[j1][-2] += candidate[
|
732 |
+
partBs[i].astype(int),
|
733 |
+
2] + connection_all[k][i][2]
|
734 |
+
|
735 |
+
# if find no partA in the subset, create a new subset
|
736 |
+
elif not found and k < 17:
|
737 |
+
row = -1 * np.ones(20)
|
738 |
+
row[indexA] = partAs[i]
|
739 |
+
row[indexB] = partBs[i]
|
740 |
+
row[-1] = 2
|
741 |
+
row[-2] = sum(
|
742 |
+
candidate[connection_all[k][i, :2].astype(int),
|
743 |
+
2]) + connection_all[k][i][2]
|
744 |
+
subset = np.vstack([subset, row])
|
745 |
+
# delete some rows of subset which has few parts occur
|
746 |
+
deleteIdx = []
|
747 |
+
for i in range(len(subset)):
|
748 |
+
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
|
749 |
+
deleteIdx.append(i)
|
750 |
+
subset = np.delete(subset, deleteIdx, axis=0)
|
751 |
+
|
752 |
+
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
|
753 |
+
# candidate: x, y, score, id
|
754 |
+
return candidate, subset
|
755 |
+
|
756 |
+
|
757 |
+
@ANNOTATORS.register_class()
|
758 |
+
class OpenposeAnnotator(BaseAnnotator, metaclass=ABCMeta):
|
759 |
+
para_dict = {}
|
760 |
+
|
761 |
+
def __init__(self, cfg, logger=None):
|
762 |
+
super().__init__(cfg, logger=logger)
|
763 |
+
with FS.get_from(cfg.BODY_MODEL_PATH,
|
764 |
+
wait_finish=True) as body_model_path:
|
765 |
+
self.body_estimation = Body(body_model_path, device='cpu')
|
766 |
+
with FS.get_from(cfg.HAND_MODEL_PATH,
|
767 |
+
wait_finish=True) as hand_model_path:
|
768 |
+
self.hand_estimation = Hand(hand_model_path, device='cpu')
|
769 |
+
self.use_hand = cfg.get('USE_HAND', False)
|
770 |
+
|
771 |
+
def to(self, device):
|
772 |
+
self.body_estimation.model = self.body_estimation.model.to(device)
|
773 |
+
self.body_estimation.device = device
|
774 |
+
self.hand_estimation.model = self.hand_estimation.model.to(device)
|
775 |
+
self.hand_estimation.device = device
|
776 |
+
return self
|
777 |
+
|
778 |
+
@torch.no_grad()
|
779 |
+
@torch.inference_mode()
|
780 |
+
@torch.autocast('cuda', enabled=False)
|
781 |
+
def forward(self, image):
|
782 |
+
if isinstance(image, Image.Image):
|
783 |
+
image = np.array(image)
|
784 |
+
elif isinstance(image, torch.Tensor):
|
785 |
+
image = image.detach().cpu().numpy()
|
786 |
+
elif isinstance(image, np.ndarray):
|
787 |
+
image = image.copy()
|
788 |
+
else:
|
789 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
790 |
+
image = image[:, :, ::-1]
|
791 |
+
candidate, subset = self.body_estimation(image)
|
792 |
+
canvas = np.zeros_like(image)
|
793 |
+
canvas = draw_bodypose(canvas, candidate, subset)
|
794 |
+
if self.use_hand:
|
795 |
+
hands_list = handDetect(candidate, subset, image)
|
796 |
+
all_hand_peaks = []
|
797 |
+
for x, y, w, is_left in hands_list:
|
798 |
+
peaks = self.hand_estimation(image[y:y + w, x:x + w, :])
|
799 |
+
peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0],
|
800 |
+
peaks[:, 0] + x)
|
801 |
+
peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1],
|
802 |
+
peaks[:, 1] + y)
|
803 |
+
all_hand_peaks.append(peaks)
|
804 |
+
canvas = draw_handpose(canvas, all_hand_peaks)
|
805 |
+
return canvas
|
806 |
+
|
807 |
+
@staticmethod
|
808 |
+
def get_config_template():
|
809 |
+
return dict_to_yaml('ANNOTATORS',
|
810 |
+
__class__.__name__,
|
811 |
+
OpenposeAnnotator.para_dict,
|
812 |
+
set_name=True)
|
annotator/registry.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from scepter.modules.utils.config import Config
|
4 |
+
from scepter.modules.utils.registry import Registry, build_from_config
|
5 |
+
|
6 |
+
|
7 |
+
def build_annotator(cfg, registry, logger=None, *args, **kwargs):
|
8 |
+
""" After build model, load pretrained model if exists key `pretrain`.
|
9 |
+
|
10 |
+
pretrain (str, dict): Describes how to load pretrained model.
|
11 |
+
str, treat pretrain as model path;
|
12 |
+
dict: should contains key `path`, and other parameters token by function load_pretrained();
|
13 |
+
"""
|
14 |
+
if not isinstance(cfg, Config):
|
15 |
+
raise TypeError(f'Config must be type dict, got {type(cfg)}')
|
16 |
+
if cfg.have('PRETRAINED_MODEL'):
|
17 |
+
pretrain_cfg = cfg.PRETRAINED_MODEL
|
18 |
+
if pretrain_cfg is not None and not isinstance(pretrain_cfg, (str)):
|
19 |
+
raise TypeError('Pretrain parameter must be a string')
|
20 |
+
else:
|
21 |
+
pretrain_cfg = None
|
22 |
+
|
23 |
+
model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
|
24 |
+
if pretrain_cfg is not None:
|
25 |
+
if hasattr(model, 'load_pretrained_model'):
|
26 |
+
model.load_pretrained_model(pretrain_cfg)
|
27 |
+
return model
|
28 |
+
|
29 |
+
|
30 |
+
ANNOTATORS = Registry('ANNOTATORS', build_func=build_annotator)
|
annotator/utils.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def resize_image(input_image, resolution):
|
8 |
+
H, W, C = input_image.shape
|
9 |
+
H = float(H)
|
10 |
+
W = float(W)
|
11 |
+
k = float(resolution) / min(H, W)
|
12 |
+
H *= k
|
13 |
+
W *= k
|
14 |
+
H = int(np.round(H / 64.0)) * 64
|
15 |
+
W = int(np.round(W / 64.0)) * 64
|
16 |
+
img = cv2.resize(
|
17 |
+
input_image, (W, H),
|
18 |
+
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
19 |
+
return img, k
|
20 |
+
|
21 |
+
|
22 |
+
def resize_image_ori(h, w, image, k):
|
23 |
+
img = cv2.resize(
|
24 |
+
image, (w, h),
|
25 |
+
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
class AnnotatorProcessor():
|
30 |
+
canny_cfg = {
|
31 |
+
'NAME': 'CannyAnnotator',
|
32 |
+
'LOW_THRESHOLD': 100,
|
33 |
+
'HIGH_THRESHOLD': 200,
|
34 |
+
'INPUT_KEYS': ['img'],
|
35 |
+
'OUTPUT_KEYS': ['canny']
|
36 |
+
}
|
37 |
+
hed_cfg = {
|
38 |
+
'NAME': 'HedAnnotator',
|
39 |
+
'PRETRAINED_MODEL':
|
40 |
+
'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth',
|
41 |
+
'INPUT_KEYS': ['img'],
|
42 |
+
'OUTPUT_KEYS': ['hed']
|
43 |
+
}
|
44 |
+
openpose_cfg = {
|
45 |
+
'NAME': 'OpenposeAnnotator',
|
46 |
+
'BODY_MODEL_PATH':
|
47 |
+
'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth',
|
48 |
+
'HAND_MODEL_PATH':
|
49 |
+
'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth',
|
50 |
+
'INPUT_KEYS': ['img'],
|
51 |
+
'OUTPUT_KEYS': ['openpose']
|
52 |
+
}
|
53 |
+
midas_cfg = {
|
54 |
+
'NAME': 'MidasDetector',
|
55 |
+
'PRETRAINED_MODEL':
|
56 |
+
'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt',
|
57 |
+
'INPUT_KEYS': ['img'],
|
58 |
+
'OUTPUT_KEYS': ['depth']
|
59 |
+
}
|
60 |
+
mlsd_cfg = {
|
61 |
+
'NAME': 'MLSDdetector',
|
62 |
+
'PRETRAINED_MODEL':
|
63 |
+
'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
|
64 |
+
'INPUT_KEYS': ['img'],
|
65 |
+
'OUTPUT_KEYS': ['mlsd']
|
66 |
+
}
|
67 |
+
color_cfg = {
|
68 |
+
'NAME': 'ColorAnnotator',
|
69 |
+
'RATIO': 64,
|
70 |
+
'INPUT_KEYS': ['img'],
|
71 |
+
'OUTPUT_KEYS': ['color']
|
72 |
+
}
|
73 |
+
|
74 |
+
anno_type_map = {
|
75 |
+
'canny': canny_cfg,
|
76 |
+
'hed': hed_cfg,
|
77 |
+
'pose': openpose_cfg,
|
78 |
+
'depth': midas_cfg,
|
79 |
+
'mlsd': mlsd_cfg,
|
80 |
+
'color': color_cfg
|
81 |
+
}
|
82 |
+
|
83 |
+
def __init__(self, anno_type):
|
84 |
+
from scepter.modules.annotator.registry import ANNOTATORS
|
85 |
+
from scepter.modules.utils.config import Config
|
86 |
+
from scepter.modules.utils.distribute import we
|
87 |
+
|
88 |
+
if isinstance(anno_type, str):
|
89 |
+
assert anno_type in self.anno_type_map.keys()
|
90 |
+
anno_type = [anno_type]
|
91 |
+
elif isinstance(anno_type, (list, tuple)):
|
92 |
+
assert all(tp in self.anno_type_map.keys() for tp in anno_type)
|
93 |
+
else:
|
94 |
+
raise Exception(f'Error anno_type: {anno_type}')
|
95 |
+
|
96 |
+
general_dict = {
|
97 |
+
'NAME': 'GeneralAnnotator',
|
98 |
+
'ANNOTATORS': [self.anno_type_map[tp] for tp in anno_type]
|
99 |
+
}
|
100 |
+
general_anno = Config(cfg_dict=general_dict, load=False)
|
101 |
+
self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id)
|
102 |
+
|
103 |
+
def run(self, image, anno_type=None):
|
104 |
+
output_image = self.general_ins({'img': image})
|
105 |
+
if anno_type is not None:
|
106 |
+
if isinstance(anno_type, str) and anno_type in output_image:
|
107 |
+
return output_image[anno_type]
|
108 |
+
else:
|
109 |
+
return {
|
110 |
+
tp: output_image[tp]
|
111 |
+
for tp in anno_type if tp in output_image
|
112 |
+
}
|
113 |
+
else:
|
114 |
+
return output_image
|
app.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
from PIL import Image
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from pipeline import prepare_white_image, MultiViewGenerator
|
7 |
+
from util import download_file, unzip_file
|
8 |
+
|
9 |
+
download_file("https://huggingface.co/aki-0421/character-360/resolve/main/v2.ckpt", "v2.ckpt")
|
10 |
+
download_file("https://huggingface.co/hbyang/Hi3D/resolve/main/ckpts.zip", "ckpts.zip")
|
11 |
+
|
12 |
+
unzip_file("ckpts.zip", ".")
|
13 |
+
|
14 |
+
multi_view_generator = MultiViewGenerator(checkpoint_path="v2.ckpt")
|
15 |
+
|
16 |
+
@spaces.GPU(duration=120)
|
17 |
+
def generate_images(input_image: Image.Image) -> List[Image.Image]:
|
18 |
+
white_image = prepare_white_image(input_image=input_image)
|
19 |
+
|
20 |
+
return multi_view_generator.infer(white_image=white_image)
|
21 |
+
|
22 |
+
|
23 |
+
with gr.Blocks() as demo:
|
24 |
+
gr.Markdown("# GPU-accelerated Image Processing")
|
25 |
+
with gr.Row():
|
26 |
+
input_image = gr.Image(label="Input Image", type="pil") # 入力はPIL形式
|
27 |
+
output_gallery = gr.Gallery(label="Output Images (25 Variations)").style(grid=(5, 5))
|
28 |
+
submit_button = gr.Button("Generate")
|
29 |
+
|
30 |
+
submit_button.click(generate_images, inputs=input_image, outputs=output_gallery)
|
31 |
+
|
32 |
+
demo.launch()
|
dataset/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
benchmarks/benchmarking_Random_grayscale.png
|
dataset/opencv_transforms/__init__.py
ADDED
File without changes
|
dataset/opencv_transforms/functional.py
ADDED
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image, ImageEnhance, ImageOps
|
6 |
+
|
7 |
+
try:
|
8 |
+
import accimage
|
9 |
+
except ImportError:
|
10 |
+
accimage = None
|
11 |
+
import collections
|
12 |
+
import numbers
|
13 |
+
import types
|
14 |
+
import warnings
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import numpy as np
|
18 |
+
from PIL import Image
|
19 |
+
|
20 |
+
_cv2_pad_to_str = {
|
21 |
+
'constant': cv2.BORDER_CONSTANT,
|
22 |
+
'edge': cv2.BORDER_REPLICATE,
|
23 |
+
'reflect': cv2.BORDER_REFLECT_101,
|
24 |
+
'symmetric': cv2.BORDER_REFLECT
|
25 |
+
}
|
26 |
+
_cv2_interpolation_to_str = {
|
27 |
+
'nearest': cv2.INTER_NEAREST,
|
28 |
+
'bilinear': cv2.INTER_LINEAR,
|
29 |
+
'area': cv2.INTER_AREA,
|
30 |
+
'bicubic': cv2.INTER_CUBIC,
|
31 |
+
'lanczos': cv2.INTER_LANCZOS4
|
32 |
+
}
|
33 |
+
_cv2_interpolation_from_str = {v: k for k, v in _cv2_interpolation_to_str.items()}
|
34 |
+
|
35 |
+
|
36 |
+
def _is_pil_image(img):
|
37 |
+
if accimage is not None:
|
38 |
+
return isinstance(img, (Image.Image, accimage.Image))
|
39 |
+
else:
|
40 |
+
return isinstance(img, Image.Image)
|
41 |
+
|
42 |
+
|
43 |
+
def _is_tensor_image(img):
|
44 |
+
return torch.is_tensor(img) and img.ndimension() == 3
|
45 |
+
|
46 |
+
|
47 |
+
def _is_numpy_image(img):
|
48 |
+
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
49 |
+
|
50 |
+
|
51 |
+
def to_tensor(pic):
|
52 |
+
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
53 |
+
See ``ToTensor`` for more details.
|
54 |
+
Args:
|
55 |
+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
56 |
+
Returns:
|
57 |
+
Tensor: Converted image.
|
58 |
+
"""
|
59 |
+
if not (_is_numpy_image(pic)):
|
60 |
+
raise TypeError('pic should be ndarray. Got {}'.format(type(pic)))
|
61 |
+
|
62 |
+
# handle numpy array
|
63 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
64 |
+
# backward compatibility
|
65 |
+
if isinstance(img, torch.ByteTensor) or img.dtype == torch.uint8:
|
66 |
+
return img.float().div(255)
|
67 |
+
else:
|
68 |
+
return img
|
69 |
+
|
70 |
+
|
71 |
+
def normalize(tensor, mean, std):
|
72 |
+
"""Normalize a tensor image with mean and standard deviation.
|
73 |
+
.. note::
|
74 |
+
This transform acts in-place, i.e., it mutates the input tensor.
|
75 |
+
See :class:`~torchvision.transforms.Normalize` for more details.
|
76 |
+
Args:
|
77 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
78 |
+
mean (sequence): Sequence of means for each channel.
|
79 |
+
std (sequence): Sequence of standard deviations for each channely.
|
80 |
+
Returns:
|
81 |
+
Tensor: Normalized Tensor image.
|
82 |
+
"""
|
83 |
+
if not _is_tensor_image(tensor):
|
84 |
+
raise TypeError('tensor is not a torch image.')
|
85 |
+
|
86 |
+
# This is faster than using broadcasting, don't change without benchmarking
|
87 |
+
for t, m, s in zip(tensor, mean, std):
|
88 |
+
t.sub_(m).div_(s)
|
89 |
+
return tensor
|
90 |
+
|
91 |
+
|
92 |
+
def resize(img, size, interpolation=cv2.INTER_LINEAR):
|
93 |
+
r"""Resize the input numpy ndarray to the given size.
|
94 |
+
Args:
|
95 |
+
img (numpy ndarray): Image to be resized.
|
96 |
+
size (sequence or int): Desired output size. If size is a sequence like
|
97 |
+
(h, w), the output size will be matched to this. If size is an int,
|
98 |
+
the smaller edge of the image will be matched to this number maintaing
|
99 |
+
the aspect ratio. i.e, if height > width, then image will be rescaled to
|
100 |
+
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
|
101 |
+
interpolation (int, optional): Desired interpolation. Default is
|
102 |
+
``cv2.INTER_LINEAR``
|
103 |
+
Returns:
|
104 |
+
PIL Image: Resized image.
|
105 |
+
"""
|
106 |
+
if not _is_numpy_image(img):
|
107 |
+
raise TypeError('img should be numpy image. Got {}'.format(type(img)))
|
108 |
+
if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
|
109 |
+
raise TypeError('Got inappropriate size arg: {}'.format(size))
|
110 |
+
h, w = img.shape[0], img.shape[1]
|
111 |
+
|
112 |
+
if isinstance(size, int):
|
113 |
+
if (w <= h and w == size) or (h <= w and h == size):
|
114 |
+
return img
|
115 |
+
if w < h:
|
116 |
+
ow = size
|
117 |
+
oh = int(size * h / w)
|
118 |
+
else:
|
119 |
+
oh = size
|
120 |
+
ow = int(size * w / h)
|
121 |
+
else:
|
122 |
+
ow, oh = size[1], size[0]
|
123 |
+
output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
|
124 |
+
if img.shape[2] == 1:
|
125 |
+
return output[:, :, np.newaxis]
|
126 |
+
else:
|
127 |
+
return output
|
128 |
+
|
129 |
+
|
130 |
+
def scale(*args, **kwargs):
|
131 |
+
warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
|
132 |
+
return resize(*args, **kwargs)
|
133 |
+
|
134 |
+
|
135 |
+
def pad(img, padding, fill=0, padding_mode='constant'):
|
136 |
+
r"""Pad the given numpy ndarray on all sides with specified padding mode and fill value.
|
137 |
+
Args:
|
138 |
+
img (numpy ndarray): image to be padded.
|
139 |
+
padding (int or tuple): Padding on each border. If a single int is provided this
|
140 |
+
is used to pad all borders. If tuple of length 2 is provided this is the padding
|
141 |
+
on left/right and top/bottom respectively. If a tuple of length 4 is provided
|
142 |
+
this is the padding for the left, top, right and bottom borders
|
143 |
+
respectively.
|
144 |
+
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
|
145 |
+
length 3, it is used to fill R, G, B channels respectively.
|
146 |
+
This value is only used when the padding_mode is constant
|
147 |
+
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
|
148 |
+
- constant: pads with a constant value, this value is specified with fill
|
149 |
+
- edge: pads with the last value on the edge of the image
|
150 |
+
- reflect: pads with reflection of image (without repeating the last value on the edge)
|
151 |
+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
152 |
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
153 |
+
- symmetric: pads with reflection of image (repeating the last value on the edge)
|
154 |
+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
155 |
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
156 |
+
Returns:
|
157 |
+
Numpy image: padded image.
|
158 |
+
"""
|
159 |
+
if not _is_numpy_image(img):
|
160 |
+
raise TypeError('img should be numpy ndarray. Got {}'.format(type(img)))
|
161 |
+
if not isinstance(padding, (numbers.Number, tuple, list)):
|
162 |
+
raise TypeError('Got inappropriate padding arg')
|
163 |
+
if not isinstance(fill, (numbers.Number, str, tuple)):
|
164 |
+
raise TypeError('Got inappropriate fill arg')
|
165 |
+
if not isinstance(padding_mode, str):
|
166 |
+
raise TypeError('Got inappropriate padding_mode arg')
|
167 |
+
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
|
168 |
+
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
|
169 |
+
"{} element tuple".format(len(padding)))
|
170 |
+
|
171 |
+
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
|
172 |
+
'Padding mode should be either constant, edge, reflect or symmetric'
|
173 |
+
|
174 |
+
if isinstance(padding, int):
|
175 |
+
pad_left = pad_right = pad_top = pad_bottom = padding
|
176 |
+
if isinstance(padding, collections.Sequence) and len(padding) == 2:
|
177 |
+
pad_left = pad_right = padding[0]
|
178 |
+
pad_top = pad_bottom = padding[1]
|
179 |
+
if isinstance(padding, collections.Sequence) and len(padding) == 4:
|
180 |
+
pad_left = padding[0]
|
181 |
+
pad_top = padding[1]
|
182 |
+
pad_right = padding[2]
|
183 |
+
pad_bottom = padding[3]
|
184 |
+
if img.shape[2] == 1:
|
185 |
+
return cv2.copyMakeBorder(img,
|
186 |
+
top=pad_top,
|
187 |
+
bottom=pad_bottom,
|
188 |
+
left=pad_left,
|
189 |
+
right=pad_right,
|
190 |
+
borderType=_cv2_pad_to_str[padding_mode],
|
191 |
+
value=fill)[:, :, np.newaxis]
|
192 |
+
else:
|
193 |
+
return cv2.copyMakeBorder(img,
|
194 |
+
top=pad_top,
|
195 |
+
bottom=pad_bottom,
|
196 |
+
left=pad_left,
|
197 |
+
right=pad_right,
|
198 |
+
borderType=_cv2_pad_to_str[padding_mode],
|
199 |
+
value=fill)
|
200 |
+
|
201 |
+
|
202 |
+
def crop(img, i, j, h, w):
|
203 |
+
"""Crop the given PIL Image.
|
204 |
+
Args:
|
205 |
+
img (numpy ndarray): Image to be cropped.
|
206 |
+
i: Upper pixel coordinate.
|
207 |
+
j: Left pixel coordinate.
|
208 |
+
h: Height of the cropped image.
|
209 |
+
w: Width of the cropped image.
|
210 |
+
Returns:
|
211 |
+
numpy ndarray: Cropped image.
|
212 |
+
"""
|
213 |
+
if not _is_numpy_image(img):
|
214 |
+
raise TypeError('img should be numpy image. Got {}'.format(type(img)))
|
215 |
+
|
216 |
+
return img[i:i + h, j:j + w, :]
|
217 |
+
|
218 |
+
|
219 |
+
def center_crop(img, output_size):
|
220 |
+
if isinstance(output_size, numbers.Number):
|
221 |
+
output_size = (int(output_size), int(output_size))
|
222 |
+
h, w = img.shape[0:2]
|
223 |
+
th, tw = output_size
|
224 |
+
i = int(round((h - th) / 2.))
|
225 |
+
j = int(round((w - tw) / 2.))
|
226 |
+
return crop(img, i, j, th, tw)
|
227 |
+
|
228 |
+
|
229 |
+
def resized_crop(img, i, j, h, w, size, interpolation=cv2.INTER_LINEAR):
|
230 |
+
"""Crop the given numpy ndarray and resize it to desired size.
|
231 |
+
Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
|
232 |
+
Args:
|
233 |
+
img (numpy ndarray): Image to be cropped.
|
234 |
+
i: Upper pixel coordinate.
|
235 |
+
j: Left pixel coordinate.
|
236 |
+
h: Height of the cropped image.
|
237 |
+
w: Width of the cropped image.
|
238 |
+
size (sequence or int): Desired output size. Same semantics as ``scale``.
|
239 |
+
interpolation (int, optional): Desired interpolation. Default is
|
240 |
+
``cv2.INTER_CUBIC``.
|
241 |
+
Returns:
|
242 |
+
PIL Image: Cropped image.
|
243 |
+
"""
|
244 |
+
assert _is_numpy_image(img), 'img should be numpy image'
|
245 |
+
img = crop(img, i, j, h, w)
|
246 |
+
img = resize(img, size, interpolation=interpolation)
|
247 |
+
return img
|
248 |
+
|
249 |
+
|
250 |
+
def hflip(img):
|
251 |
+
"""Horizontally flip the given numpy ndarray.
|
252 |
+
Args:
|
253 |
+
img (numpy ndarray): image to be flipped.
|
254 |
+
Returns:
|
255 |
+
numpy ndarray: Horizontally flipped image.
|
256 |
+
"""
|
257 |
+
if not _is_numpy_image(img):
|
258 |
+
raise TypeError('img should be numpy image. Got {}'.format(type(img)))
|
259 |
+
# img[:,::-1] is much faster, but doesn't work with torch.from_numpy()!
|
260 |
+
if img.shape[2] == 1:
|
261 |
+
return cv2.flip(img, 1)[:, :, np.newaxis]
|
262 |
+
else:
|
263 |
+
return cv2.flip(img, 1)
|
264 |
+
|
265 |
+
|
266 |
+
def vflip(img):
|
267 |
+
"""Vertically flip the given numpy ndarray.
|
268 |
+
Args:
|
269 |
+
img (numpy ndarray): Image to be flipped.
|
270 |
+
Returns:
|
271 |
+
numpy ndarray: Vertically flipped image.
|
272 |
+
"""
|
273 |
+
if not _is_numpy_image(img):
|
274 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
275 |
+
if img.shape[2] == 1:
|
276 |
+
return cv2.flip(img, 0)[:, :, np.newaxis]
|
277 |
+
else:
|
278 |
+
return cv2.flip(img, 0)
|
279 |
+
# img[::-1] is much faster, but doesn't work with torch.from_numpy()!
|
280 |
+
|
281 |
+
|
282 |
+
def five_crop(img, size):
|
283 |
+
"""Crop the given numpy ndarray into four corners and the central crop.
|
284 |
+
.. Note::
|
285 |
+
This transform returns a tuple of images and there may be a
|
286 |
+
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
287 |
+
Args:
|
288 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
289 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
290 |
+
made.
|
291 |
+
Returns:
|
292 |
+
tuple: tuple (tl, tr, bl, br, center)
|
293 |
+
Corresponding top left, top right, bottom left, bottom right and center crop.
|
294 |
+
"""
|
295 |
+
if isinstance(size, numbers.Number):
|
296 |
+
size = (int(size), int(size))
|
297 |
+
else:
|
298 |
+
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
299 |
+
|
300 |
+
h, w = img.shape[0:2]
|
301 |
+
crop_h, crop_w = size
|
302 |
+
if crop_w > w or crop_h > h:
|
303 |
+
raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))
|
304 |
+
tl = crop(img, 0, 0, crop_h, crop_w)
|
305 |
+
tr = crop(img, 0, w - crop_w, crop_h, crop_w)
|
306 |
+
bl = crop(img, h - crop_h, 0, crop_h, crop_w)
|
307 |
+
br = crop(img, h - crop_h, w - crop_w, crop_h, crop_w)
|
308 |
+
center = center_crop(img, (crop_h, crop_w))
|
309 |
+
return tl, tr, bl, br, center
|
310 |
+
|
311 |
+
|
312 |
+
def ten_crop(img, size, vertical_flip=False):
|
313 |
+
r"""Crop the given numpy ndarray into four corners and the central crop plus the
|
314 |
+
flipped version of these (horizontal flipping is used by default).
|
315 |
+
.. Note::
|
316 |
+
This transform returns a tuple of images and there may be a
|
317 |
+
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
318 |
+
Args:
|
319 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
320 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
321 |
+
made.
|
322 |
+
vertical_flip (bool): Use vertical flipping instead of horizontal
|
323 |
+
Returns:
|
324 |
+
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
|
325 |
+
Corresponding top left, top right, bottom left, bottom right and center crop
|
326 |
+
and same for the flipped image.
|
327 |
+
"""
|
328 |
+
if isinstance(size, numbers.Number):
|
329 |
+
size = (int(size), int(size))
|
330 |
+
else:
|
331 |
+
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
332 |
+
|
333 |
+
first_five = five_crop(img, size)
|
334 |
+
|
335 |
+
if vertical_flip:
|
336 |
+
img = vflip(img)
|
337 |
+
else:
|
338 |
+
img = hflip(img)
|
339 |
+
|
340 |
+
second_five = five_crop(img, size)
|
341 |
+
return first_five + second_five
|
342 |
+
|
343 |
+
|
344 |
+
def adjust_brightness(img, brightness_factor):
|
345 |
+
"""Adjust brightness of an Image.
|
346 |
+
Args:
|
347 |
+
img (numpy ndarray): numpy ndarray to be adjusted.
|
348 |
+
brightness_factor (float): How much to adjust the brightness. Can be
|
349 |
+
any non negative number. 0 gives a black image, 1 gives the
|
350 |
+
original image while 2 increases the brightness by a factor of 2.
|
351 |
+
Returns:
|
352 |
+
numpy ndarray: Brightness adjusted image.
|
353 |
+
"""
|
354 |
+
if not _is_numpy_image(img):
|
355 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
356 |
+
table = np.array([i * brightness_factor for i in range(0, 256)]).clip(0, 255).astype('uint8')
|
357 |
+
# same thing but a bit slower
|
358 |
+
# cv2.convertScaleAbs(img, alpha=brightness_factor, beta=0)
|
359 |
+
if img.shape[2] == 1:
|
360 |
+
return cv2.LUT(img, table)[:, :, np.newaxis]
|
361 |
+
else:
|
362 |
+
return cv2.LUT(img, table)
|
363 |
+
|
364 |
+
|
365 |
+
def adjust_contrast(img, contrast_factor):
|
366 |
+
"""Adjust contrast of an mage.
|
367 |
+
Args:
|
368 |
+
img (numpy ndarray): numpy ndarray to be adjusted.
|
369 |
+
contrast_factor (float): How much to adjust the contrast. Can be any
|
370 |
+
non negative number. 0 gives a solid gray image, 1 gives the
|
371 |
+
original image while 2 increases the contrast by a factor of 2.
|
372 |
+
Returns:
|
373 |
+
numpy ndarray: Contrast adjusted image.
|
374 |
+
"""
|
375 |
+
# much faster to use the LUT construction than anything else I've tried
|
376 |
+
# it's because you have to change dtypes multiple times
|
377 |
+
if not _is_numpy_image(img):
|
378 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
379 |
+
|
380 |
+
# input is RGB
|
381 |
+
if img.ndim > 2 and img.shape[2] == 3:
|
382 |
+
mean_value = round(cv2.mean(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))[0])
|
383 |
+
elif img.ndim == 2:
|
384 |
+
# grayscale input
|
385 |
+
mean_value = round(cv2.mean(img)[0])
|
386 |
+
else:
|
387 |
+
# multichannel input
|
388 |
+
mean_value = round(np.mean(img))
|
389 |
+
|
390 |
+
table = np.array([(i - mean_value) * contrast_factor + mean_value for i in range(0, 256)]).clip(0,
|
391 |
+
255).astype('uint8')
|
392 |
+
# enhancer = ImageEnhance.Contrast(img)
|
393 |
+
# img = enhancer.enhance(contrast_factor)
|
394 |
+
if img.ndim == 2 or img.shape[2] == 1:
|
395 |
+
return cv2.LUT(img, table)[:, :, np.newaxis]
|
396 |
+
else:
|
397 |
+
return cv2.LUT(img, table)
|
398 |
+
|
399 |
+
|
400 |
+
def adjust_saturation(img, saturation_factor):
|
401 |
+
"""Adjust color saturation of an image.
|
402 |
+
Args:
|
403 |
+
img (numpy ndarray): numpy ndarray to be adjusted.
|
404 |
+
saturation_factor (float): How much to adjust the saturation. 0 will
|
405 |
+
give a black and white image, 1 will give the original image while
|
406 |
+
2 will enhance the saturation by a factor of 2.
|
407 |
+
Returns:
|
408 |
+
numpy ndarray: Saturation adjusted image.
|
409 |
+
"""
|
410 |
+
# ~10ms slower than PIL!
|
411 |
+
if not _is_numpy_image(img):
|
412 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
413 |
+
img = Image.fromarray(img)
|
414 |
+
enhancer = ImageEnhance.Color(img)
|
415 |
+
img = enhancer.enhance(saturation_factor)
|
416 |
+
return np.array(img)
|
417 |
+
|
418 |
+
|
419 |
+
def adjust_hue(img, hue_factor):
|
420 |
+
"""Adjust hue of an image.
|
421 |
+
The image hue is adjusted by converting the image to HSV and
|
422 |
+
cyclically shifting the intensities in the hue channel (H).
|
423 |
+
The image is then converted back to original image mode.
|
424 |
+
`hue_factor` is the amount of shift in H channel and must be in the
|
425 |
+
interval `[-0.5, 0.5]`.
|
426 |
+
See `Hue`_ for more details.
|
427 |
+
.. _Hue: https://en.wikipedia.org/wiki/Hue
|
428 |
+
Args:
|
429 |
+
img (numpy ndarray): numpy ndarray to be adjusted.
|
430 |
+
hue_factor (float): How much to shift the hue channel. Should be in
|
431 |
+
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
|
432 |
+
HSV space in positive and negative direction respectively.
|
433 |
+
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
|
434 |
+
with complementary colors while 0 gives the original image.
|
435 |
+
Returns:
|
436 |
+
numpy ndarray: Hue adjusted image.
|
437 |
+
"""
|
438 |
+
# After testing, found that OpenCV calculates the Hue in a call to
|
439 |
+
# cv2.cvtColor(..., cv2.COLOR_BGR2HSV) differently from PIL
|
440 |
+
|
441 |
+
# This function takes 160ms! should be avoided
|
442 |
+
if not (-0.5 <= hue_factor <= 0.5):
|
443 |
+
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
|
444 |
+
if not _is_numpy_image(img):
|
445 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
446 |
+
img = Image.fromarray(img)
|
447 |
+
input_mode = img.mode
|
448 |
+
if input_mode in {'L', '1', 'I', 'F'}:
|
449 |
+
return np.array(img)
|
450 |
+
|
451 |
+
h, s, v = img.convert('HSV').split()
|
452 |
+
|
453 |
+
np_h = np.array(h, dtype=np.uint8)
|
454 |
+
# uint8 addition take cares of rotation across boundaries
|
455 |
+
with np.errstate(over='ignore'):
|
456 |
+
np_h += np.uint8(hue_factor * 255)
|
457 |
+
h = Image.fromarray(np_h, 'L')
|
458 |
+
|
459 |
+
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
|
460 |
+
return np.array(img)
|
461 |
+
|
462 |
+
|
463 |
+
def adjust_gamma(img, gamma, gain=1):
|
464 |
+
r"""Perform gamma correction on an image.
|
465 |
+
Also known as Power Law Transform. Intensities in RGB mode are adjusted
|
466 |
+
based on the following equation:
|
467 |
+
.. math::
|
468 |
+
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
|
469 |
+
See `Gamma Correction`_ for more details.
|
470 |
+
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
|
471 |
+
Args:
|
472 |
+
img (numpy ndarray): numpy ndarray to be adjusted.
|
473 |
+
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
|
474 |
+
gamma larger than 1 make the shadows darker,
|
475 |
+
while gamma smaller than 1 make dark regions lighter.
|
476 |
+
gain (float): The constant multiplier.
|
477 |
+
"""
|
478 |
+
if not _is_numpy_image(img):
|
479 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
480 |
+
|
481 |
+
if gamma < 0:
|
482 |
+
raise ValueError('Gamma should be a non-negative real number')
|
483 |
+
# from here
|
484 |
+
# https://stackoverflow.com/questions/33322488/how-to-change-image-illumination-in-opencv-python/41061351
|
485 |
+
table = np.array([((i / 255.0)**gamma) * 255 * gain for i in np.arange(0, 256)]).astype('uint8')
|
486 |
+
if img.shape[2] == 1:
|
487 |
+
return cv2.LUT(img, table)[:, :, np.newaxis]
|
488 |
+
else:
|
489 |
+
return cv2.LUT(img, table)
|
490 |
+
|
491 |
+
|
492 |
+
def rotate(img, angle, resample=False, expand=False, center=None):
|
493 |
+
"""Rotate the image by angle.
|
494 |
+
Args:
|
495 |
+
img (numpy ndarray): numpy ndarray to be rotated.
|
496 |
+
angle (float or int): In degrees degrees counter clockwise order.
|
497 |
+
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
|
498 |
+
An optional resampling filter. See `filters`_ for more information.
|
499 |
+
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
|
500 |
+
expand (bool, optional): Optional expansion flag.
|
501 |
+
If true, expands the output image to make it large enough to hold the entire rotated image.
|
502 |
+
If false or omitted, make the output image the same size as the input image.
|
503 |
+
Note that the expand flag assumes rotation around the center and no translation.
|
504 |
+
center (2-tuple, optional): Optional center of rotation.
|
505 |
+
Origin is the upper left corner.
|
506 |
+
Default is the center of the image.
|
507 |
+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
|
508 |
+
"""
|
509 |
+
if not _is_numpy_image(img):
|
510 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
511 |
+
rows, cols = img.shape[0:2]
|
512 |
+
if center is None:
|
513 |
+
center = (cols / 2, rows / 2)
|
514 |
+
M = cv2.getRotationMatrix2D(center, angle, 1)
|
515 |
+
if img.shape[2] == 1:
|
516 |
+
return cv2.warpAffine(img, M, (cols, rows))[:, :, np.newaxis]
|
517 |
+
else:
|
518 |
+
return cv2.warpAffine(img, M, (cols, rows))
|
519 |
+
|
520 |
+
|
521 |
+
def _get_affine_matrix(center, angle, translate, scale, shear):
|
522 |
+
# Helper method to compute matrix for affine transformation
|
523 |
+
# We need compute affine transformation matrix: M = T * C * RSS * C^-1
|
524 |
+
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
|
525 |
+
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
|
526 |
+
# RSS is rotation with scale and shear matrix
|
527 |
+
# RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0]
|
528 |
+
# [ sin(a)*scale cos(a + shear)*scale 0]
|
529 |
+
# [ 0 0 1]
|
530 |
+
|
531 |
+
angle = math.radians(angle)
|
532 |
+
shear = math.radians(shear)
|
533 |
+
# scale = 1.0 / scale
|
534 |
+
|
535 |
+
T = np.array([[1, 0, translate[0]], [0, 1, translate[1]], [0, 0, 1]])
|
536 |
+
C = np.array([[1, 0, center[0]], [0, 1, center[1]], [0, 0, 1]])
|
537 |
+
RSS = np.array([[math.cos(angle) * scale, -math.sin(angle + shear) * scale, 0],
|
538 |
+
[math.sin(angle) * scale, math.cos(angle + shear) * scale, 0], [0, 0, 1]])
|
539 |
+
matrix = T @ C @ RSS @ np.linalg.inv(C)
|
540 |
+
|
541 |
+
return matrix[:2, :]
|
542 |
+
|
543 |
+
|
544 |
+
def affine(img, angle, translate, scale, shear, interpolation=cv2.INTER_LINEAR, mode=cv2.BORDER_CONSTANT, fillcolor=0):
|
545 |
+
"""Apply affine transformation on the image keeping image center invariant
|
546 |
+
Args:
|
547 |
+
img (numpy ndarray): numpy ndarray to be transformed.
|
548 |
+
angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
|
549 |
+
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
|
550 |
+
scale (float): overall scale
|
551 |
+
shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
|
552 |
+
interpolation (``cv2.INTER_NEAREST` or ``cv2.INTER_LINEAR`` or ``cv2.INTER_AREA``, ``cv2.INTER_CUBIC``):
|
553 |
+
An optional resampling filter.
|
554 |
+
See `filters`_ for more information.
|
555 |
+
If omitted, it is set to ``cv2.INTER_CUBIC``, for bicubic interpolation.
|
556 |
+
mode (``cv2.BORDER_CONSTANT`` or ``cv2.BORDER_REPLICATE`` or ``cv2.BORDER_REFLECT`` or ``cv2.BORDER_REFLECT_101``)
|
557 |
+
Method for filling in border regions.
|
558 |
+
Defaults to cv2.BORDER_CONSTANT, meaning areas outside the image are filled with a value (val, default 0)
|
559 |
+
val (int): Optional fill color for the area outside the transform in the output image. Default: 0
|
560 |
+
"""
|
561 |
+
if not _is_numpy_image(img):
|
562 |
+
raise TypeError('img should be numpy Image. Got {}'.format(type(img)))
|
563 |
+
|
564 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
565 |
+
"Argument translate should be a list or tuple of length 2"
|
566 |
+
|
567 |
+
assert scale > 0.0, "Argument scale should be positive"
|
568 |
+
|
569 |
+
output_size = img.shape[0:2]
|
570 |
+
center = (img.shape[1] * 0.5 + 0.5, img.shape[0] * 0.5 + 0.5)
|
571 |
+
matrix = _get_affine_matrix(center, angle, translate, scale, shear)
|
572 |
+
|
573 |
+
if img.shape[2] == 1:
|
574 |
+
return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode,
|
575 |
+
borderValue=fillcolor)[:, :, np.newaxis]
|
576 |
+
else:
|
577 |
+
return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode, borderValue=fillcolor)
|
578 |
+
|
579 |
+
|
580 |
+
def to_grayscale(img, num_output_channels: int = 1):
|
581 |
+
"""Convert image to grayscale version of image.
|
582 |
+
Args:
|
583 |
+
img (numpy ndarray): Image to be converted to grayscale.
|
584 |
+
num_output_channels: int
|
585 |
+
if 1 : returned image is single channel
|
586 |
+
if 3 : returned image is 3 channel with r = g = b
|
587 |
+
Returns:
|
588 |
+
numpy ndarray: Grayscale version of the image.
|
589 |
+
"""
|
590 |
+
if not _is_numpy_image(img):
|
591 |
+
raise TypeError('img should be numpy ndarray. Got {}'.format(type(img)))
|
592 |
+
|
593 |
+
if num_output_channels == 1:
|
594 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
|
595 |
+
elif num_output_channels == 3:
|
596 |
+
# much faster than doing cvtColor to go back to gray
|
597 |
+
img = np.broadcast_to(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis], img.shape)
|
598 |
+
return img
|
dataset/opencv_transforms/transforms.py
ADDED
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import division
|
2 |
+
|
3 |
+
import collections
|
4 |
+
import math
|
5 |
+
import numbers
|
6 |
+
import random
|
7 |
+
import types
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
# from PIL import Image, ImageOps, ImageEnhance
|
11 |
+
try:
|
12 |
+
import accimage
|
13 |
+
except ImportError:
|
14 |
+
accimage = None
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from . import functional as F
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"Compose", "ToTensor", "Normalize", "Resize", "Scale",
|
23 |
+
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice",
|
24 |
+
"RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip",
|
25 |
+
"RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
|
26 |
+
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine",
|
27 |
+
"Grayscale", "RandomGrayscale"
|
28 |
+
]
|
29 |
+
|
30 |
+
_cv2_pad_to_str = {
|
31 |
+
'constant': cv2.BORDER_CONSTANT,
|
32 |
+
'edge': cv2.BORDER_REPLICATE,
|
33 |
+
'reflect': cv2.BORDER_REFLECT_101,
|
34 |
+
'symmetric': cv2.BORDER_REFLECT
|
35 |
+
}
|
36 |
+
_cv2_interpolation_to_str = {
|
37 |
+
'nearest': cv2.INTER_NEAREST,
|
38 |
+
'bilinear': cv2.INTER_LINEAR,
|
39 |
+
'area': cv2.INTER_AREA,
|
40 |
+
'bicubic': cv2.INTER_CUBIC,
|
41 |
+
'lanczos': cv2.INTER_LANCZOS4
|
42 |
+
}
|
43 |
+
_cv2_interpolation_from_str = {
|
44 |
+
v: k
|
45 |
+
for k, v in _cv2_interpolation_to_str.items()
|
46 |
+
}
|
47 |
+
|
48 |
+
|
49 |
+
class Compose(object):
|
50 |
+
"""Composes several transforms together.
|
51 |
+
Args:
|
52 |
+
transforms (list of ``Transform`` objects): list of transforms to compose.
|
53 |
+
Example:
|
54 |
+
>>> transforms.Compose([
|
55 |
+
>>> transforms.CenterCrop(10),
|
56 |
+
>>> transforms.ToTensor(),
|
57 |
+
>>> ])
|
58 |
+
"""
|
59 |
+
def __init__(self, transforms):
|
60 |
+
self.transforms = transforms
|
61 |
+
|
62 |
+
def __call__(self, img):
|
63 |
+
for t in self.transforms:
|
64 |
+
img = t(img)
|
65 |
+
return img
|
66 |
+
|
67 |
+
def __repr__(self):
|
68 |
+
format_string = self.__class__.__name__ + '('
|
69 |
+
for t in self.transforms:
|
70 |
+
format_string += '\n'
|
71 |
+
format_string += ' {0}'.format(t)
|
72 |
+
format_string += '\n)'
|
73 |
+
return format_string
|
74 |
+
|
75 |
+
|
76 |
+
class ToTensor(object):
|
77 |
+
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
78 |
+
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
|
79 |
+
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
|
80 |
+
"""
|
81 |
+
def __call__(self, pic):
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
85 |
+
Returns:
|
86 |
+
Tensor: Converted image.
|
87 |
+
"""
|
88 |
+
return F.to_tensor(pic)
|
89 |
+
|
90 |
+
def __repr__(self):
|
91 |
+
return self.__class__.__name__ + '()'
|
92 |
+
|
93 |
+
|
94 |
+
class Normalize(object):
|
95 |
+
"""Normalize a tensor image with mean and standard deviation.
|
96 |
+
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
|
97 |
+
will normalize each channel of the input ``torch.*Tensor`` i.e.
|
98 |
+
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
|
99 |
+
.. note::
|
100 |
+
This transform acts in-place, i.e., it mutates the input tensor.
|
101 |
+
Args:
|
102 |
+
mean (sequence): Sequence of means for each channel.
|
103 |
+
std (sequence): Sequence of standard deviations for each channel.
|
104 |
+
"""
|
105 |
+
def __init__(self, mean, std):
|
106 |
+
self.mean = mean
|
107 |
+
self.std = std
|
108 |
+
|
109 |
+
def __call__(self, tensor):
|
110 |
+
"""
|
111 |
+
Args:
|
112 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
113 |
+
Returns:
|
114 |
+
Tensor: Normalized Tensor image.
|
115 |
+
"""
|
116 |
+
return F.normalize(tensor, self.mean, self.std)
|
117 |
+
|
118 |
+
def __repr__(self):
|
119 |
+
return self.__class__.__name__ + '(mean={0}, std={1})'.format(
|
120 |
+
self.mean, self.std)
|
121 |
+
|
122 |
+
|
123 |
+
class Resize(object):
|
124 |
+
"""Resize the input numpy ndarray to the given size.
|
125 |
+
Args:
|
126 |
+
size (sequence or int): Desired output size. If size is a sequence like
|
127 |
+
(h, w), output size will be matched to this. If size is an int,
|
128 |
+
smaller edge of the image will be matched to this number.
|
129 |
+
i.e, if height > width, then image will be rescaled to
|
130 |
+
(size * height / width, size)
|
131 |
+
interpolation (int, optional): Desired interpolation. Default is
|
132 |
+
``cv2.INTER_CUBIC``, bicubic interpolation
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, size, interpolation=cv2.INTER_LINEAR):
|
136 |
+
# assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
|
137 |
+
if isinstance(size, int):
|
138 |
+
self.size = size
|
139 |
+
elif isinstance(size, collections.abc.Iterable) and len(size) == 2:
|
140 |
+
if type(size) == list:
|
141 |
+
size = tuple(size)
|
142 |
+
self.size = size
|
143 |
+
else:
|
144 |
+
raise ValueError('Unknown inputs for size: {}'.format(size))
|
145 |
+
self.interpolation = interpolation
|
146 |
+
|
147 |
+
def __call__(self, img):
|
148 |
+
"""
|
149 |
+
Args:
|
150 |
+
img (numpy ndarray): Image to be scaled.
|
151 |
+
Returns:
|
152 |
+
numpy ndarray: Rescaled image.
|
153 |
+
"""
|
154 |
+
return F.resize(img, self.size, self.interpolation)
|
155 |
+
|
156 |
+
def __repr__(self):
|
157 |
+
interpolate_str = _cv2_interpolation_from_str[self.interpolation]
|
158 |
+
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
|
159 |
+
self.size, interpolate_str)
|
160 |
+
|
161 |
+
|
162 |
+
class Scale(Resize):
|
163 |
+
"""
|
164 |
+
Note: This transform is deprecated in favor of Resize.
|
165 |
+
"""
|
166 |
+
def __init__(self, *args, **kwargs):
|
167 |
+
warnings.warn(
|
168 |
+
"The use of the transforms.Scale transform is deprecated, " +
|
169 |
+
"please use transforms.Resize instead.")
|
170 |
+
super(Scale, self).__init__(*args, **kwargs)
|
171 |
+
|
172 |
+
|
173 |
+
class CenterCrop(object):
|
174 |
+
"""Crops the given numpy ndarray at the center.
|
175 |
+
Args:
|
176 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
177 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
178 |
+
made.
|
179 |
+
"""
|
180 |
+
def __init__(self, size):
|
181 |
+
if isinstance(size, numbers.Number):
|
182 |
+
self.size = (int(size), int(size))
|
183 |
+
else:
|
184 |
+
self.size = size
|
185 |
+
|
186 |
+
def __call__(self, img):
|
187 |
+
"""
|
188 |
+
Args:
|
189 |
+
img (numpy ndarray): Image to be cropped.
|
190 |
+
Returns:
|
191 |
+
numpy ndarray: Cropped image.
|
192 |
+
"""
|
193 |
+
return F.center_crop(img, self.size)
|
194 |
+
|
195 |
+
def __repr__(self):
|
196 |
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
197 |
+
|
198 |
+
|
199 |
+
class Pad(object):
|
200 |
+
"""Pad the given numpy ndarray on all sides with the given "pad" value.
|
201 |
+
Args:
|
202 |
+
padding (int or tuple): Padding on each border. If a single int is provided this
|
203 |
+
is used to pad all borders. If tuple of length 2 is provided this is the padding
|
204 |
+
on left/right and top/bottom respectively. If a tuple of length 4 is provided
|
205 |
+
this is the padding for the left, top, right and bottom borders
|
206 |
+
respectively.
|
207 |
+
fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
|
208 |
+
length 3, it is used to fill R, G, B channels respectively.
|
209 |
+
This value is only used when the padding_mode is constant
|
210 |
+
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
|
211 |
+
Default is constant.
|
212 |
+
- constant: pads with a constant value, this value is specified with fill
|
213 |
+
- edge: pads with the last value at the edge of the image
|
214 |
+
- reflect: pads with reflection of image without repeating the last value on the edge
|
215 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
216 |
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
217 |
+
- symmetric: pads with reflection of image repeating the last value on the edge
|
218 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
219 |
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
220 |
+
"""
|
221 |
+
def __init__(self, padding, fill=0, padding_mode='constant'):
|
222 |
+
assert isinstance(padding, (numbers.Number, tuple, list))
|
223 |
+
assert isinstance(fill, (numbers.Number, str, tuple))
|
224 |
+
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
225 |
+
if isinstance(padding,
|
226 |
+
collections.Sequence) and len(padding) not in [2, 4]:
|
227 |
+
raise ValueError(
|
228 |
+
"Padding must be an int or a 2, or 4 element tuple, not a " +
|
229 |
+
"{} element tuple".format(len(padding)))
|
230 |
+
|
231 |
+
self.padding = padding
|
232 |
+
self.fill = fill
|
233 |
+
self.padding_mode = padding_mode
|
234 |
+
|
235 |
+
def __call__(self, img):
|
236 |
+
"""
|
237 |
+
Args:
|
238 |
+
img (numpy ndarray): Image to be padded.
|
239 |
+
Returns:
|
240 |
+
numpy ndarray: Padded image.
|
241 |
+
"""
|
242 |
+
return F.pad(img, self.padding, self.fill, self.padding_mode)
|
243 |
+
|
244 |
+
def __repr__(self):
|
245 |
+
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
|
246 |
+
format(self.padding, self.fill, self.padding_mode)
|
247 |
+
|
248 |
+
|
249 |
+
class Lambda(object):
|
250 |
+
"""Apply a user-defined lambda as a transform.
|
251 |
+
Args:
|
252 |
+
lambd (function): Lambda/function to be used for transform.
|
253 |
+
"""
|
254 |
+
def __init__(self, lambd):
|
255 |
+
assert isinstance(lambd, types.LambdaType)
|
256 |
+
self.lambd = lambd
|
257 |
+
|
258 |
+
def __call__(self, img):
|
259 |
+
return self.lambd(img)
|
260 |
+
|
261 |
+
def __repr__(self):
|
262 |
+
return self.__class__.__name__ + '()'
|
263 |
+
|
264 |
+
|
265 |
+
class RandomTransforms(object):
|
266 |
+
"""Base class for a list of transformations with randomness
|
267 |
+
Args:
|
268 |
+
transforms (list or tuple): list of transformations
|
269 |
+
"""
|
270 |
+
def __init__(self, transforms):
|
271 |
+
assert isinstance(transforms, (list, tuple))
|
272 |
+
self.transforms = transforms
|
273 |
+
|
274 |
+
def __call__(self, *args, **kwargs):
|
275 |
+
raise NotImplementedError()
|
276 |
+
|
277 |
+
def __repr__(self):
|
278 |
+
format_string = self.__class__.__name__ + '('
|
279 |
+
for t in self.transforms:
|
280 |
+
format_string += '\n'
|
281 |
+
format_string += ' {0}'.format(t)
|
282 |
+
format_string += '\n)'
|
283 |
+
return format_string
|
284 |
+
|
285 |
+
|
286 |
+
class RandomApply(RandomTransforms):
|
287 |
+
"""Apply randomly a list of transformations with a given probability
|
288 |
+
Args:
|
289 |
+
transforms (list or tuple): list of transformations
|
290 |
+
p (float): probability
|
291 |
+
"""
|
292 |
+
def __init__(self, transforms, p=0.5):
|
293 |
+
super(RandomApply, self).__init__(transforms)
|
294 |
+
self.p = p
|
295 |
+
|
296 |
+
def __call__(self, img):
|
297 |
+
if self.p < random.random():
|
298 |
+
return img
|
299 |
+
for t in self.transforms:
|
300 |
+
img = t(img)
|
301 |
+
return img
|
302 |
+
|
303 |
+
def __repr__(self):
|
304 |
+
format_string = self.__class__.__name__ + '('
|
305 |
+
format_string += '\n p={}'.format(self.p)
|
306 |
+
for t in self.transforms:
|
307 |
+
format_string += '\n'
|
308 |
+
format_string += ' {0}'.format(t)
|
309 |
+
format_string += '\n)'
|
310 |
+
return format_string
|
311 |
+
|
312 |
+
|
313 |
+
class RandomOrder(RandomTransforms):
|
314 |
+
"""Apply a list of transformations in a random order
|
315 |
+
"""
|
316 |
+
def __call__(self, img):
|
317 |
+
order = list(range(len(self.transforms)))
|
318 |
+
random.shuffle(order)
|
319 |
+
for i in order:
|
320 |
+
img = self.transforms[i](img)
|
321 |
+
return img
|
322 |
+
|
323 |
+
|
324 |
+
class RandomChoice(RandomTransforms):
|
325 |
+
"""Apply single transformation randomly picked from a list
|
326 |
+
"""
|
327 |
+
def __call__(self, img):
|
328 |
+
t = random.choice(self.transforms)
|
329 |
+
return t(img)
|
330 |
+
|
331 |
+
|
332 |
+
class RandomCrop(object):
|
333 |
+
"""Crop the given numpy ndarray at a random location.
|
334 |
+
Args:
|
335 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
336 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
337 |
+
made.
|
338 |
+
padding (int or sequence, optional): Optional padding on each border
|
339 |
+
of the image. Default is None, i.e no padding. If a sequence of length
|
340 |
+
4 is provided, it is used to pad left, top, right, bottom borders
|
341 |
+
respectively. If a sequence of length 2 is provided, it is used to
|
342 |
+
pad left/right, top/bottom borders, respectively.
|
343 |
+
pad_if_needed (boolean): It will pad the image if smaller than the
|
344 |
+
desired size to avoid raising an exception.
|
345 |
+
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
|
346 |
+
length 3, it is used to fill R, G, B channels respectively.
|
347 |
+
This value is only used when the padding_mode is constant
|
348 |
+
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
|
349 |
+
- constant: pads with a constant value, this value is specified with fill
|
350 |
+
- edge: pads with the last value on the edge of the image
|
351 |
+
- reflect: pads with reflection of image (without repeating the last value on the edge)
|
352 |
+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
353 |
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
354 |
+
- symmetric: pads with reflection of image (repeating the last value on the edge)
|
355 |
+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
356 |
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
357 |
+
"""
|
358 |
+
def __init__(self,
|
359 |
+
size,
|
360 |
+
padding=None,
|
361 |
+
pad_if_needed=False,
|
362 |
+
fill=0,
|
363 |
+
padding_mode='constant'):
|
364 |
+
if isinstance(size, numbers.Number):
|
365 |
+
self.size = (int(size), int(size))
|
366 |
+
else:
|
367 |
+
self.size = size
|
368 |
+
self.padding = padding
|
369 |
+
self.pad_if_needed = pad_if_needed
|
370 |
+
self.fill = fill
|
371 |
+
self.padding_mode = padding_mode
|
372 |
+
|
373 |
+
@staticmethod
|
374 |
+
def get_params(img, output_size):
|
375 |
+
"""Get parameters for ``crop`` for a random crop.
|
376 |
+
Args:
|
377 |
+
img (numpy ndarray): Image to be cropped.
|
378 |
+
output_size (tuple): Expected output size of the crop.
|
379 |
+
Returns:
|
380 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
|
381 |
+
"""
|
382 |
+
h, w = img.shape[0:2]
|
383 |
+
th, tw = output_size
|
384 |
+
if w == tw and h == th:
|
385 |
+
return 0, 0, h, w
|
386 |
+
|
387 |
+
i = random.randint(0, h - th)
|
388 |
+
j = random.randint(0, w - tw)
|
389 |
+
return i, j, th, tw
|
390 |
+
|
391 |
+
def __call__(self, img):
|
392 |
+
"""
|
393 |
+
Args:
|
394 |
+
img (numpy ndarray): Image to be cropped.
|
395 |
+
Returns:
|
396 |
+
numpy ndarray: Cropped image.
|
397 |
+
"""
|
398 |
+
if self.padding is not None:
|
399 |
+
img = F.pad(img, self.padding, self.fill, self.padding_mode)
|
400 |
+
|
401 |
+
# pad the width if needed
|
402 |
+
if self.pad_if_needed and img.shape[1] < self.size[1]:
|
403 |
+
img = F.pad(img, (self.size[1] - img.shape[1], 0), self.fill,
|
404 |
+
self.padding_mode)
|
405 |
+
# pad the height if needed
|
406 |
+
if self.pad_if_needed and img.shape[0] < self.size[0]:
|
407 |
+
img = F.pad(img, (0, self.size[0] - img.shape[0]), self.fill,
|
408 |
+
self.padding_mode)
|
409 |
+
|
410 |
+
i, j, h, w = self.get_params(img, self.size)
|
411 |
+
|
412 |
+
return F.crop(img, i, j, h, w)
|
413 |
+
|
414 |
+
def __repr__(self):
|
415 |
+
return self.__class__.__name__ + '(size={0}, padding={1})'.format(
|
416 |
+
self.size, self.padding)
|
417 |
+
|
418 |
+
|
419 |
+
class RandomHorizontalFlip(object):
|
420 |
+
"""Horizontally flip the given PIL Image randomly with a given probability.
|
421 |
+
Args:
|
422 |
+
p (float): probability of the image being flipped. Default value is 0.5
|
423 |
+
"""
|
424 |
+
def __init__(self, p=0.5):
|
425 |
+
self.p = p
|
426 |
+
|
427 |
+
def __call__(self, img):
|
428 |
+
"""random
|
429 |
+
Args:
|
430 |
+
img (numpy ndarray): Image to be flipped.
|
431 |
+
Returns:
|
432 |
+
numpy ndarray: Randomly flipped image.
|
433 |
+
"""
|
434 |
+
# if random.random() < self.p:
|
435 |
+
# print('flip')
|
436 |
+
# return F.hflip(img)
|
437 |
+
if random.random() < self.p:
|
438 |
+
return F.hflip(img)
|
439 |
+
return img
|
440 |
+
|
441 |
+
def __repr__(self):
|
442 |
+
return self.__class__.__name__ + '(p={})'.format(self.p)
|
443 |
+
|
444 |
+
|
445 |
+
class RandomVerticalFlip(object):
|
446 |
+
"""Vertically flip the given PIL Image randomly with a given probability.
|
447 |
+
Args:
|
448 |
+
p (float): probability of the image being flipped. Default value is 0.5
|
449 |
+
"""
|
450 |
+
def __init__(self, p=0.5):
|
451 |
+
self.p = p
|
452 |
+
|
453 |
+
def __call__(self, img):
|
454 |
+
"""
|
455 |
+
Args:
|
456 |
+
img (numpy ndarray): Image to be flipped.
|
457 |
+
Returns:
|
458 |
+
numpy ndarray: Randomly flipped image.
|
459 |
+
"""
|
460 |
+
if random.random() < self.p:
|
461 |
+
return F.vflip(img)
|
462 |
+
return img
|
463 |
+
|
464 |
+
def __repr__(self):
|
465 |
+
return self.__class__.__name__ + '(p={})'.format(self.p)
|
466 |
+
|
467 |
+
|
468 |
+
class RandomResizedCrop(object):
|
469 |
+
"""Crop the given numpy ndarray to random size and aspect ratio.
|
470 |
+
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
471 |
+
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
|
472 |
+
is finally resized to given size.
|
473 |
+
This is popularly used to train the Inception networks.
|
474 |
+
Args:
|
475 |
+
size: expected output size of each edge
|
476 |
+
scale: range of size of the origin size cropped
|
477 |
+
ratio: range of aspect ratio of the origin aspect ratio cropped
|
478 |
+
interpolation: Default: cv2.INTER_CUBIC
|
479 |
+
"""
|
480 |
+
def __init__(self,
|
481 |
+
size,
|
482 |
+
scale=(0.08, 1.0),
|
483 |
+
ratio=(3. / 4., 4. / 3.),
|
484 |
+
interpolation=cv2.INTER_LINEAR):
|
485 |
+
self.size = (size, size)
|
486 |
+
self.interpolation = interpolation
|
487 |
+
self.scale = scale
|
488 |
+
self.ratio = ratio
|
489 |
+
|
490 |
+
@staticmethod
|
491 |
+
def get_params(img, scale, ratio):
|
492 |
+
"""Get parameters for ``crop`` for a random sized crop.
|
493 |
+
Args:
|
494 |
+
img (numpy ndarray): Image to be cropped.
|
495 |
+
scale (tuple): range of size of the origin size cropped
|
496 |
+
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
497 |
+
Returns:
|
498 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
499 |
+
sized crop.
|
500 |
+
"""
|
501 |
+
for attempt in range(10):
|
502 |
+
area = img.shape[0] * img.shape[1]
|
503 |
+
target_area = random.uniform(*scale) * area
|
504 |
+
aspect_ratio = random.uniform(*ratio)
|
505 |
+
|
506 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
507 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
508 |
+
|
509 |
+
if random.random() < 0.5:
|
510 |
+
w, h = h, w
|
511 |
+
|
512 |
+
if w <= img.shape[1] and h <= img.shape[0]:
|
513 |
+
i = random.randint(0, img.shape[0] - h)
|
514 |
+
j = random.randint(0, img.shape[1] - w)
|
515 |
+
return i, j, h, w
|
516 |
+
|
517 |
+
# Fallback
|
518 |
+
w = min(img.shape[0], img.shape[1])
|
519 |
+
i = (img.shape[0] - w) // 2
|
520 |
+
j = (img.shape[1] - w) // 2
|
521 |
+
return i, j, w, w
|
522 |
+
|
523 |
+
def __call__(self, img):
|
524 |
+
"""
|
525 |
+
Args:
|
526 |
+
img (numpy ndarray): Image to be cropped and resized.
|
527 |
+
Returns:
|
528 |
+
numpy ndarray: Randomly cropped and resized image.
|
529 |
+
"""
|
530 |
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
531 |
+
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
|
532 |
+
|
533 |
+
def __repr__(self):
|
534 |
+
interpolate_str = _cv2_interpolation_from_str[self.interpolation]
|
535 |
+
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
536 |
+
format_string += ', scale={0}'.format(
|
537 |
+
tuple(round(s, 4) for s in self.scale))
|
538 |
+
format_string += ', ratio={0}'.format(
|
539 |
+
tuple(round(r, 4) for r in self.ratio))
|
540 |
+
format_string += ', interpolation={0})'.format(interpolate_str)
|
541 |
+
return format_string
|
542 |
+
|
543 |
+
|
544 |
+
class RandomSizedCrop(RandomResizedCrop):
|
545 |
+
"""
|
546 |
+
Note: This transform is deprecated in favor of RandomResizedCrop.
|
547 |
+
"""
|
548 |
+
def __init__(self, *args, **kwargs):
|
549 |
+
warnings.warn(
|
550 |
+
"The use of the transforms.RandomSizedCrop transform is deprecated, "
|
551 |
+
+ "please use transforms.RandomResizedCrop instead.")
|
552 |
+
super(RandomSizedCrop, self).__init__(*args, **kwargs)
|
553 |
+
|
554 |
+
|
555 |
+
class FiveCrop(object):
|
556 |
+
"""Crop the given numpy ndarray into four corners and the central crop
|
557 |
+
.. Note::
|
558 |
+
This transform returns a tuple of images and there may be a mismatch in the number of
|
559 |
+
inputs and targets your Dataset returns. See below for an example of how to deal with
|
560 |
+
this.
|
561 |
+
Args:
|
562 |
+
size (sequence or int): Desired output size of the crop. If size is an ``int``
|
563 |
+
instead of sequence like (h, w), a square crop of size (size, size) is made.
|
564 |
+
Example:
|
565 |
+
>>> transform = Compose([
|
566 |
+
>>> FiveCrop(size), # this is a list of numpy ndarrays
|
567 |
+
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
|
568 |
+
>>> ])
|
569 |
+
>>> #In your test loop you can do the following:
|
570 |
+
>>> input, target = batch # input is a 5d tensor, target is 2d
|
571 |
+
>>> bs, ncrops, c, h, w = input.size()
|
572 |
+
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
|
573 |
+
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
|
574 |
+
"""
|
575 |
+
def __init__(self, size):
|
576 |
+
self.size = size
|
577 |
+
if isinstance(size, numbers.Number):
|
578 |
+
self.size = (int(size), int(size))
|
579 |
+
else:
|
580 |
+
assert len(
|
581 |
+
size
|
582 |
+
) == 2, "Please provide only two dimensions (h, w) for size."
|
583 |
+
self.size = size
|
584 |
+
|
585 |
+
def __call__(self, img):
|
586 |
+
return F.five_crop(img, self.size)
|
587 |
+
|
588 |
+
def __repr__(self):
|
589 |
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
590 |
+
|
591 |
+
|
592 |
+
class TenCrop(object):
|
593 |
+
"""Crop the given numpy ndarray into four corners and the central crop plus the flipped version of
|
594 |
+
these (horizontal flipping is used by default)
|
595 |
+
.. Note::
|
596 |
+
This transform returns a tuple of images and there may be a mismatch in the number of
|
597 |
+
inputs and targets your Dataset returns. See below for an example of how to deal with
|
598 |
+
this.
|
599 |
+
Args:
|
600 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
601 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
602 |
+
made.
|
603 |
+
vertical_flip(bool): Use vertical flipping instead of horizontal
|
604 |
+
Example:
|
605 |
+
>>> transform = Compose([
|
606 |
+
>>> TenCrop(size), # this is a list of PIL Images
|
607 |
+
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
|
608 |
+
>>> ])
|
609 |
+
>>> #In your test loop you can do the following:
|
610 |
+
>>> input, target = batch # input is a 5d tensor, target is 2d
|
611 |
+
>>> bs, ncrops, c, h, w = input.size()
|
612 |
+
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
|
613 |
+
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
|
614 |
+
"""
|
615 |
+
def __init__(self, size, vertical_flip=False):
|
616 |
+
self.size = size
|
617 |
+
if isinstance(size, numbers.Number):
|
618 |
+
self.size = (int(size), int(size))
|
619 |
+
else:
|
620 |
+
assert len(
|
621 |
+
size
|
622 |
+
) == 2, "Please provide only two dimensions (h, w) for size."
|
623 |
+
self.size = size
|
624 |
+
self.vertical_flip = vertical_flip
|
625 |
+
|
626 |
+
def __call__(self, img):
|
627 |
+
return F.ten_crop(img, self.size, self.vertical_flip)
|
628 |
+
|
629 |
+
def __repr__(self):
|
630 |
+
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(
|
631 |
+
self.size, self.vertical_flip)
|
632 |
+
|
633 |
+
|
634 |
+
class LinearTransformation(object):
|
635 |
+
"""Transform a tensor image with a square transformation matrix computed
|
636 |
+
offline.
|
637 |
+
Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
|
638 |
+
product with the transformation matrix and reshape the tensor to its
|
639 |
+
original shape.
|
640 |
+
Applications:
|
641 |
+
- whitening: zero-center the data, compute the data covariance matrix
|
642 |
+
[D x D] with np.dot(X.T, X), perform SVD on this matrix and
|
643 |
+
pass it as transformation_matrix.
|
644 |
+
Args:
|
645 |
+
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
|
646 |
+
"""
|
647 |
+
def __init__(self, transformation_matrix):
|
648 |
+
if transformation_matrix.size(0) != transformation_matrix.size(1):
|
649 |
+
raise ValueError("transformation_matrix should be square. Got " +
|
650 |
+
"[{} x {}] rectangular matrix.".format(
|
651 |
+
*transformation_matrix.size()))
|
652 |
+
self.transformation_matrix = transformation_matrix
|
653 |
+
|
654 |
+
def __call__(self, tensor):
|
655 |
+
"""
|
656 |
+
Args:
|
657 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
|
658 |
+
Returns:
|
659 |
+
Tensor: Transformed image.
|
660 |
+
"""
|
661 |
+
if tensor.size(0) * tensor.size(1) * tensor.size(
|
662 |
+
2) != self.transformation_matrix.size(0):
|
663 |
+
raise ValueError(
|
664 |
+
"tensor and transformation matrix have incompatible shape." +
|
665 |
+
"[{} x {} x {}] != ".format(*tensor.size()) +
|
666 |
+
"{}".format(self.transformation_matrix.size(0)))
|
667 |
+
flat_tensor = tensor.view(1, -1)
|
668 |
+
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
|
669 |
+
tensor = transformed_tensor.view(tensor.size())
|
670 |
+
return tensor
|
671 |
+
|
672 |
+
def __repr__(self):
|
673 |
+
format_string = self.__class__.__name__ + '('
|
674 |
+
format_string += (str(self.transformation_matrix.numpy().tolist()) +
|
675 |
+
')')
|
676 |
+
return format_string
|
677 |
+
|
678 |
+
|
679 |
+
class ColorJitter(object):
|
680 |
+
"""Randomly change the brightness, contrast and saturation of an image.
|
681 |
+
Args:
|
682 |
+
brightness (float or tuple of float (min, max)): How much to jitter brightness.
|
683 |
+
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
|
684 |
+
or the given [min, max]. Should be non negative numbers.
|
685 |
+
contrast (float or tuple of float (min, max)): How much to jitter contrast.
|
686 |
+
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
|
687 |
+
or the given [min, max]. Should be non negative numbers.
|
688 |
+
saturation (float or tuple of float (min, max)): How much to jitter saturation.
|
689 |
+
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
|
690 |
+
or the given [min, max]. Should be non negative numbers.
|
691 |
+
hue (float or tuple of float (min, max)): How much to jitter hue.
|
692 |
+
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
|
693 |
+
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
|
694 |
+
"""
|
695 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
696 |
+
self.brightness = self._check_input(brightness, 'brightness')
|
697 |
+
self.contrast = self._check_input(contrast, 'contrast')
|
698 |
+
self.saturation = self._check_input(saturation, 'saturation')
|
699 |
+
self.hue = self._check_input(hue,
|
700 |
+
'hue',
|
701 |
+
center=0,
|
702 |
+
bound=(-0.5, 0.5),
|
703 |
+
clip_first_on_zero=False)
|
704 |
+
if self.saturation is not None:
|
705 |
+
warnings.warn(
|
706 |
+
'Saturation jitter enabled. Will slow down loading immensely.')
|
707 |
+
if self.hue is not None:
|
708 |
+
warnings.warn(
|
709 |
+
'Hue jitter enabled. Will slow down loading immensely.')
|
710 |
+
|
711 |
+
def _check_input(self,
|
712 |
+
value,
|
713 |
+
name,
|
714 |
+
center=1,
|
715 |
+
bound=(0, float('inf')),
|
716 |
+
clip_first_on_zero=True):
|
717 |
+
if isinstance(value, numbers.Number):
|
718 |
+
if value < 0:
|
719 |
+
raise ValueError(
|
720 |
+
"If {} is a single number, it must be non negative.".
|
721 |
+
format(name))
|
722 |
+
value = [center - value, center + value]
|
723 |
+
if clip_first_on_zero:
|
724 |
+
value[0] = max(value[0], 0)
|
725 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
726 |
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
727 |
+
raise ValueError("{} values should be between {}".format(
|
728 |
+
name, bound))
|
729 |
+
else:
|
730 |
+
raise TypeError(
|
731 |
+
"{} should be a single number or a list/tuple with length 2.".
|
732 |
+
format(name))
|
733 |
+
|
734 |
+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
|
735 |
+
# or (0., 0.) for hue, do nothing
|
736 |
+
if value[0] == value[1] == center:
|
737 |
+
value = None
|
738 |
+
return value
|
739 |
+
|
740 |
+
@staticmethod
|
741 |
+
def get_params(brightness, contrast, saturation, hue):
|
742 |
+
"""Get a randomized transform to be applied on image.
|
743 |
+
Arguments are same as that of __init__.
|
744 |
+
Returns:
|
745 |
+
Transform which randomly adjusts brightness, contrast and
|
746 |
+
saturation in a random order.
|
747 |
+
"""
|
748 |
+
transforms = []
|
749 |
+
|
750 |
+
if brightness is not None:
|
751 |
+
brightness_factor = random.uniform(brightness[0], brightness[1])
|
752 |
+
transforms.append(
|
753 |
+
Lambda(
|
754 |
+
lambda img: F.adjust_brightness(img, brightness_factor)))
|
755 |
+
|
756 |
+
if contrast is not None:
|
757 |
+
contrast_factor = random.uniform(contrast[0], contrast[1])
|
758 |
+
transforms.append(
|
759 |
+
Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
|
760 |
+
|
761 |
+
if saturation is not None:
|
762 |
+
saturation_factor = random.uniform(saturation[0], saturation[1])
|
763 |
+
transforms.append(
|
764 |
+
Lambda(
|
765 |
+
lambda img: F.adjust_saturation(img, saturation_factor)))
|
766 |
+
|
767 |
+
if hue is not None:
|
768 |
+
hue_factor = random.uniform(hue[0], hue[1])
|
769 |
+
transforms.append(
|
770 |
+
Lambda(lambda img: F.adjust_hue(img, hue_factor)))
|
771 |
+
|
772 |
+
random.shuffle(transforms)
|
773 |
+
transform = Compose(transforms)
|
774 |
+
|
775 |
+
return transform
|
776 |
+
|
777 |
+
def __call__(self, img):
|
778 |
+
"""
|
779 |
+
Args:
|
780 |
+
img (numpy ndarray): Input image.
|
781 |
+
Returns:
|
782 |
+
numpy ndarray: Color jittered image.
|
783 |
+
"""
|
784 |
+
transform = self.get_params(self.brightness, self.contrast,
|
785 |
+
self.saturation, self.hue)
|
786 |
+
return transform(img)
|
787 |
+
|
788 |
+
def __repr__(self):
|
789 |
+
format_string = self.__class__.__name__ + '('
|
790 |
+
format_string += 'brightness={0}'.format(self.brightness)
|
791 |
+
format_string += ', contrast={0}'.format(self.contrast)
|
792 |
+
format_string += ', saturation={0}'.format(self.saturation)
|
793 |
+
format_string += ', hue={0})'.format(self.hue)
|
794 |
+
return format_string
|
795 |
+
|
796 |
+
|
797 |
+
class RandomRotation(object):
|
798 |
+
"""Rotate the image by angle.
|
799 |
+
Args:
|
800 |
+
degrees (sequence or float or int): Range of degrees to select from.
|
801 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
802 |
+
will be (-degrees, +degrees).
|
803 |
+
resample ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional):
|
804 |
+
An optional resampling filter. See `filters`_ for more information.
|
805 |
+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
806 |
+
expand (bool, optional): Optional expansion flag.
|
807 |
+
If true, expands the output to make it large enough to hold the entire rotated image.
|
808 |
+
If false or omitted, make the output image the same size as the input image.
|
809 |
+
Note that the expand flag assumes rotation around the center and no translation.
|
810 |
+
center (2-tuple, optional): Optional center of rotation.
|
811 |
+
Origin is the upper left corner.
|
812 |
+
Default is the center of the image.
|
813 |
+
"""
|
814 |
+
def __init__(self, degrees, resample=False, expand=False, center=None):
|
815 |
+
if isinstance(degrees, numbers.Number):
|
816 |
+
if degrees < 0:
|
817 |
+
raise ValueError(
|
818 |
+
"If degrees is a single number, it must be positive.")
|
819 |
+
self.degrees = (-degrees, degrees)
|
820 |
+
else:
|
821 |
+
if len(degrees) != 2:
|
822 |
+
raise ValueError(
|
823 |
+
"If degrees is a sequence, it must be of len 2.")
|
824 |
+
self.degrees = degrees
|
825 |
+
|
826 |
+
self.resample = resample
|
827 |
+
self.expand = expand
|
828 |
+
self.center = center
|
829 |
+
|
830 |
+
@staticmethod
|
831 |
+
def get_params(degrees):
|
832 |
+
"""Get parameters for ``rotate`` for a random rotation.
|
833 |
+
Returns:
|
834 |
+
sequence: params to be passed to ``rotate`` for random rotation.
|
835 |
+
"""
|
836 |
+
angle = random.uniform(degrees[0], degrees[1])
|
837 |
+
|
838 |
+
return angle
|
839 |
+
|
840 |
+
def __call__(self, img):
|
841 |
+
"""
|
842 |
+
img (numpy ndarray): Image to be rotated.
|
843 |
+
Returns:
|
844 |
+
numpy ndarray: Rotated image.
|
845 |
+
"""
|
846 |
+
|
847 |
+
angle = self.get_params(self.degrees)
|
848 |
+
|
849 |
+
return F.rotate(img, angle, self.resample, self.expand, self.center)
|
850 |
+
|
851 |
+
def __repr__(self):
|
852 |
+
format_string = self.__class__.__name__ + '(degrees={0}'.format(
|
853 |
+
self.degrees)
|
854 |
+
format_string += ', resample={0}'.format(self.resample)
|
855 |
+
format_string += ', expand={0}'.format(self.expand)
|
856 |
+
if self.center is not None:
|
857 |
+
format_string += ', center={0}'.format(self.center)
|
858 |
+
format_string += ')'
|
859 |
+
return format_string
|
860 |
+
|
861 |
+
|
862 |
+
class RandomAffine(object):
|
863 |
+
"""Random affine transformation of the image keeping center invariant
|
864 |
+
Args:
|
865 |
+
degrees (sequence or float or int): Range of degrees to select from.
|
866 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
867 |
+
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
|
868 |
+
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
|
869 |
+
and vertical translations. For example translate=(a, b), then horizontal shift
|
870 |
+
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
|
871 |
+
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
|
872 |
+
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
|
873 |
+
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
|
874 |
+
shear (sequence or float or int, optional): Range of degrees to select from.
|
875 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
876 |
+
will be (-degrees, +degrees). Will not apply shear by default
|
877 |
+
resample ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional):
|
878 |
+
An optional resampling filter. See `filters`_ for more information.
|
879 |
+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
880 |
+
fillcolor (int): Optional fill color for the area outside the transform in the output image.
|
881 |
+
"""
|
882 |
+
def __init__(self,
|
883 |
+
degrees,
|
884 |
+
translate=None,
|
885 |
+
scale=None,
|
886 |
+
shear=None,
|
887 |
+
interpolation=cv2.INTER_LINEAR,
|
888 |
+
fillcolor=0):
|
889 |
+
if isinstance(degrees, numbers.Number):
|
890 |
+
if degrees < 0:
|
891 |
+
raise ValueError(
|
892 |
+
"If degrees is a single number, it must be positive.")
|
893 |
+
self.degrees = (-degrees, degrees)
|
894 |
+
else:
|
895 |
+
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
|
896 |
+
"degrees should be a list or tuple and it must be of length 2."
|
897 |
+
self.degrees = degrees
|
898 |
+
|
899 |
+
if translate is not None:
|
900 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
901 |
+
"translate should be a list or tuple and it must be of length 2."
|
902 |
+
for t in translate:
|
903 |
+
if not (0.0 <= t <= 1.0):
|
904 |
+
raise ValueError(
|
905 |
+
"translation values should be between 0 and 1")
|
906 |
+
self.translate = translate
|
907 |
+
|
908 |
+
if scale is not None:
|
909 |
+
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
910 |
+
"scale should be a list or tuple and it must be of length 2."
|
911 |
+
for s in scale:
|
912 |
+
if s <= 0:
|
913 |
+
raise ValueError("scale values should be positive")
|
914 |
+
self.scale = scale
|
915 |
+
|
916 |
+
if shear is not None:
|
917 |
+
if isinstance(shear, numbers.Number):
|
918 |
+
if shear < 0:
|
919 |
+
raise ValueError(
|
920 |
+
"If shear is a single number, it must be positive.")
|
921 |
+
self.shear = (-shear, shear)
|
922 |
+
else:
|
923 |
+
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
|
924 |
+
"shear should be a list or tuple and it must be of length 2."
|
925 |
+
self.shear = shear
|
926 |
+
else:
|
927 |
+
self.shear = shear
|
928 |
+
|
929 |
+
# self.resample = resample
|
930 |
+
self.interpolation = interpolation
|
931 |
+
self.fillcolor = fillcolor
|
932 |
+
|
933 |
+
@staticmethod
|
934 |
+
def get_params(degrees, translate, scale_ranges, shears, img_size):
|
935 |
+
"""Get parameters for affine transformation
|
936 |
+
Returns:
|
937 |
+
sequence: params to be passed to the affine transformation
|
938 |
+
"""
|
939 |
+
angle = random.uniform(degrees[0], degrees[1])
|
940 |
+
if translate is not None:
|
941 |
+
max_dx = translate[0] * img_size[0]
|
942 |
+
max_dy = translate[1] * img_size[1]
|
943 |
+
translations = (np.round(random.uniform(-max_dx, max_dx)),
|
944 |
+
np.round(random.uniform(-max_dy, max_dy)))
|
945 |
+
else:
|
946 |
+
translations = (0, 0)
|
947 |
+
|
948 |
+
if scale_ranges is not None:
|
949 |
+
scale = random.uniform(scale_ranges[0], scale_ranges[1])
|
950 |
+
else:
|
951 |
+
scale = 1.0
|
952 |
+
|
953 |
+
if shears is not None:
|
954 |
+
shear = random.uniform(shears[0], shears[1])
|
955 |
+
else:
|
956 |
+
shear = 0.0
|
957 |
+
|
958 |
+
return angle, translations, scale, shear
|
959 |
+
|
960 |
+
def __call__(self, img):
|
961 |
+
"""
|
962 |
+
img (numpy ndarray): Image to be transformed.
|
963 |
+
Returns:
|
964 |
+
numpy ndarray: Affine transformed image.
|
965 |
+
"""
|
966 |
+
ret = self.get_params(self.degrees, self.translate, self.scale,
|
967 |
+
self.shear, (img.shape[1], img.shape[0]))
|
968 |
+
return F.affine(img,
|
969 |
+
*ret,
|
970 |
+
interpolation=self.interpolation,
|
971 |
+
fillcolor=self.fillcolor)
|
972 |
+
|
973 |
+
def __repr__(self):
|
974 |
+
s = '{name}(degrees={degrees}'
|
975 |
+
if self.translate is not None:
|
976 |
+
s += ', translate={translate}'
|
977 |
+
if self.scale is not None:
|
978 |
+
s += ', scale={scale}'
|
979 |
+
if self.shear is not None:
|
980 |
+
s += ', shear={shear}'
|
981 |
+
if self.resample > 0:
|
982 |
+
s += ', resample={resample}'
|
983 |
+
if self.fillcolor != 0:
|
984 |
+
s += ', fillcolor={fillcolor}'
|
985 |
+
s += ')'
|
986 |
+
d = dict(self.__dict__)
|
987 |
+
d['resample'] = _cv2_interpolation_to_str[d['resample']]
|
988 |
+
return s.format(name=self.__class__.__name__, **d)
|
989 |
+
|
990 |
+
|
991 |
+
class Grayscale(object):
|
992 |
+
"""Convert image to grayscale.
|
993 |
+
Args:
|
994 |
+
num_output_channels (int): (1 or 3) number of channels desired for output image
|
995 |
+
Returns:
|
996 |
+
numpy ndarray: Grayscale version of the input.
|
997 |
+
- If num_output_channels == 1 : returned image is single channel
|
998 |
+
- If num_output_channels == 3 : returned image is 3 channel with r == g == b
|
999 |
+
"""
|
1000 |
+
def __init__(self, num_output_channels=1):
|
1001 |
+
self.num_output_channels = num_output_channels
|
1002 |
+
|
1003 |
+
def __call__(self, img):
|
1004 |
+
"""
|
1005 |
+
Args:
|
1006 |
+
img (numpy ndarray): Image to be converted to grayscale.
|
1007 |
+
Returns:
|
1008 |
+
numpy ndarray: Randomly grayscaled image.
|
1009 |
+
"""
|
1010 |
+
return F.to_grayscale(img,
|
1011 |
+
num_output_channels=self.num_output_channels)
|
1012 |
+
|
1013 |
+
def __repr__(self):
|
1014 |
+
return self.__class__.__name__ + '(num_output_channels={0})'.format(
|
1015 |
+
self.num_output_channels)
|
1016 |
+
|
1017 |
+
|
1018 |
+
class RandomGrayscale(object):
|
1019 |
+
"""Randomly convert image to grayscale with a probability of p (default 0.1).
|
1020 |
+
Args:
|
1021 |
+
p (float): probability that image should be converted to grayscale.
|
1022 |
+
Returns:
|
1023 |
+
numpy ndarray: Grayscale version of the input image with probability p and unchanged
|
1024 |
+
with probability (1-p).
|
1025 |
+
- If input image is 1 channel: grayscale version is 1 channel
|
1026 |
+
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
|
1027 |
+
"""
|
1028 |
+
def __init__(self, p=0.1):
|
1029 |
+
self.p = p
|
1030 |
+
|
1031 |
+
def __call__(self, img):
|
1032 |
+
"""
|
1033 |
+
Args:
|
1034 |
+
img (numpy ndarray): Image to be converted to grayscale.
|
1035 |
+
Returns:
|
1036 |
+
numpy ndarray: Randomly grayscaled image.
|
1037 |
+
"""
|
1038 |
+
num_output_channels = 3
|
1039 |
+
if random.random() < self.p:
|
1040 |
+
return F.to_grayscale(img, num_output_channels=num_output_channels)
|
1041 |
+
return img
|
1042 |
+
|
1043 |
+
def __repr__(self):
|
1044 |
+
return self.__class__.__name__ + '(p={0})'.format(self.p)
|
dataset/setup.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import setuptools
|
2 |
+
|
3 |
+
with open('README.md', 'r') as fh:
|
4 |
+
long_description = fh.read()
|
5 |
+
|
6 |
+
setuptools.setup(
|
7 |
+
name='opencv_transforms',
|
8 |
+
version='0.0.6',
|
9 |
+
author='Jim Bohnslav',
|
10 |
+
author_email='[email protected]',
|
11 |
+
description='A drop-in replacement for Torchvision Transforms using OpenCV',
|
12 |
+
keywords='pytorch image augmentations',
|
13 |
+
long_description=long_description,
|
14 |
+
long_description_content_type='text/markdown',
|
15 |
+
url='https://github.com/jbohnslav/opencv_transforms',
|
16 |
+
packages=setuptools.find_packages(),
|
17 |
+
classifiers=[
|
18 |
+
"Programming Language :: Python :: 3",
|
19 |
+
"License :: OSI Approved :: MIT License",
|
20 |
+
"Operating System :: OS Independent",
|
21 |
+
],
|
22 |
+
python_requires='>=3.6',
|
23 |
+
)
|
dataset/tests/compare_to_pil_for_testing.ipynb
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import glob\n",
|
10 |
+
"import numpy as np\n",
|
11 |
+
"import random\n",
|
12 |
+
"\n",
|
13 |
+
"import cv2\n",
|
14 |
+
"import matplotlib.pyplot as plt\n",
|
15 |
+
"from PIL import Image\n",
|
16 |
+
"\n",
|
17 |
+
"from torchvision import transforms as pil_transforms\n",
|
18 |
+
"from torchvision.transforms import functional as F_pil\n",
|
19 |
+
"\n",
|
20 |
+
"import sys\n",
|
21 |
+
"sys.path.insert(0, '..')\n",
|
22 |
+
"from opencv_transforms import transforms\n",
|
23 |
+
"from opencv_transforms import functional as F\n",
|
24 |
+
"\n",
|
25 |
+
"from setup_testing_directory import get_testing_directory"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": null,
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"datadir = get_testing_directory()\n",
|
35 |
+
"print(datadir)"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": [
|
44 |
+
"train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)\n",
|
45 |
+
"train_images.sort()\n",
|
46 |
+
"print('Number of training images: {:,}'.format(len(train_images)))"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"random.seed(1)\n",
|
56 |
+
"imfile = random.choice(train_images)"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": null,
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"def plot_pil_and_opencv(pil_image, opencv_image, orientation='row'):\n",
|
66 |
+
" if orientation == 'row':\n",
|
67 |
+
" rows, cols = 1,3\n",
|
68 |
+
" size = (8, 4)\n",
|
69 |
+
" else: \n",
|
70 |
+
" rows, cols = 3,1\n",
|
71 |
+
" size = (12, 6)\n",
|
72 |
+
" fig, axes = plt.subplots(rows, cols,figsize=size)\n",
|
73 |
+
" ax = axes[0]\n",
|
74 |
+
" ax.imshow(pil_image)\n",
|
75 |
+
" ax.set_title('PIL')\n",
|
76 |
+
"\n",
|
77 |
+
" ax = axes[1]\n",
|
78 |
+
" ax.imshow(opencv_image)\n",
|
79 |
+
" ax.set_title('opencv')\n",
|
80 |
+
"\n",
|
81 |
+
" ax = axes[2]\n",
|
82 |
+
" l1 = np.abs(pil_image - opencv_image).mean(axis=2)\n",
|
83 |
+
" ax.imshow(l1)\n",
|
84 |
+
" ax.set_title('| PIL - opencv|\\nMAE:{:.4f}'.format(l1.mean()))\n",
|
85 |
+
" plt.tight_layout()\n",
|
86 |
+
" plt.show()"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": null,
|
92 |
+
"metadata": {},
|
93 |
+
"outputs": [],
|
94 |
+
"source": [
|
95 |
+
"pil_image = Image.open(imfile)\n",
|
96 |
+
"image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"plot_pil_and_opencv(pil_image, image)"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": null,
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"pil_resized = pil_transforms.Resize((224, 224))(pil_image)"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": null,
|
120 |
+
"metadata": {},
|
121 |
+
"outputs": [],
|
122 |
+
"source": [
|
123 |
+
"resized = transforms.Resize(224)(image)"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": null,
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [],
|
131 |
+
"source": [
|
132 |
+
"plot_pil_and_opencv(pil_resized, resized)\n",
|
133 |
+
"plt.show()"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": null,
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"def L1(pil: Image, image: np.ndarray) -> float:\n",
|
143 |
+
" return np.mean(np.abs(np.asarray(pil) - image))"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": null,
|
149 |
+
"metadata": {},
|
150 |
+
"outputs": [],
|
151 |
+
"source": [
|
152 |
+
"TOL = 1e-4\n",
|
153 |
+
"\n",
|
154 |
+
"l1 = L1(pil_resized, resized)\n",
|
155 |
+
"assert l1 - 88.9559 < TOL"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": null,
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [],
|
163 |
+
"source": [
|
164 |
+
"random.seed(1)\n",
|
165 |
+
"pil = pil_transforms.RandomRotation(10)(pil_image)\n",
|
166 |
+
"random.seed(1)\n",
|
167 |
+
"np_img = transforms.RandomRotation(10)(image)"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": null,
|
173 |
+
"metadata": {},
|
174 |
+
"outputs": [],
|
175 |
+
"source": [
|
176 |
+
"plot_pil_and_opencv(pil, np_img)"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"pil = pil_transforms.FiveCrop((224, 224))(pil_image)\n",
|
186 |
+
"cv = transforms.FiveCrop((224,224))(image)"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": null,
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"pil_stacked = np.hstack([np.asarray(i) for i in pil])\n",
|
196 |
+
"cv_stacked = np.hstack(cv)\n",
|
197 |
+
"\n",
|
198 |
+
"plot_pil_and_opencv(pil_stacked, cv_stacked, orientation='col')"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": null,
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [],
|
206 |
+
"source": [
|
207 |
+
"pil_stacked.shape"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": null,
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"l1"
|
217 |
+
]
|
218 |
+
}
|
219 |
+
],
|
220 |
+
"metadata": {
|
221 |
+
"kernelspec": {
|
222 |
+
"display_name": "opencv_transforms",
|
223 |
+
"language": "python",
|
224 |
+
"name": "opencv_transforms"
|
225 |
+
},
|
226 |
+
"language_info": {
|
227 |
+
"codemirror_mode": {
|
228 |
+
"name": "ipython",
|
229 |
+
"version": 3
|
230 |
+
},
|
231 |
+
"file_extension": ".py",
|
232 |
+
"mimetype": "text/x-python",
|
233 |
+
"name": "python",
|
234 |
+
"nbconvert_exporter": "python",
|
235 |
+
"pygments_lexer": "ipython3",
|
236 |
+
"version": "3.7.9"
|
237 |
+
}
|
238 |
+
},
|
239 |
+
"nbformat": 4,
|
240 |
+
"nbformat_minor": 4
|
241 |
+
}
|
dataset/tests/setup_testing_directory.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import Union
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
|
7 |
+
def get_testing_directory() -> str:
|
8 |
+
directory_file = 'testing_directory.txt'
|
9 |
+
directory_files = [directory_file, os.path.join('tests', directory_file)]
|
10 |
+
|
11 |
+
for directory_file in directory_files:
|
12 |
+
if os.path.isfile(directory_file):
|
13 |
+
with open(directory_file, 'r') as f:
|
14 |
+
testing_directory = f.read()
|
15 |
+
return testing_directory
|
16 |
+
raise ValueError('please run setup_testing_directory.py before attempting to run unit tests')
|
17 |
+
|
18 |
+
|
19 |
+
def setup_testing_directory(datadir: Union[str, os.PathLike], overwrite: bool = False) -> str:
|
20 |
+
testing_path_file = 'testing_directory.txt'
|
21 |
+
|
22 |
+
should_setup = True
|
23 |
+
if os.path.isfile(testing_path_file):
|
24 |
+
with open(testing_path_file, 'r') as f:
|
25 |
+
testing_directory = f.read()
|
26 |
+
if not os.path.isfile(testing_directory):
|
27 |
+
raise ValueError('saved testing directory {} does not exist, re-run ')
|
28 |
+
warnings.warn(
|
29 |
+
'Saved testing directory {} does not exist, downloading Thumos14...'.format(testing_directory))
|
30 |
+
else:
|
31 |
+
should_setup = False
|
32 |
+
if not should_setup:
|
33 |
+
return testing_directory
|
34 |
+
|
35 |
+
testing_directory = datadir
|
36 |
+
assert os.path.isdir(testing_directory)
|
37 |
+
assert os.path.isdir(os.path.join(testing_directory, 'train'))
|
38 |
+
assert os.path.isdir(os.path.join(testing_directory, 'val'))
|
39 |
+
with open('testing_directory.txt', 'w') as f:
|
40 |
+
f.write(testing_directory)
|
41 |
+
return testing_directory
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
parser = argparse.ArgumentParser('Setting up image directory for opencv transforms testing')
|
46 |
+
parser.add_argument('-d', '--datadir', default=os.getcwd(), help='Imagenet directory')
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
|
50 |
+
setup_testing_directory(args.datadir)
|
dataset/tests/test_color.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from PIL import Image
|
9 |
+
from PIL.Image import Image as PIL_image # for typing
|
10 |
+
import pytest
|
11 |
+
from torchvision import transforms as pil_transforms
|
12 |
+
from torchvision.transforms import functional as F_pil
|
13 |
+
|
14 |
+
from opencv_transforms import transforms
|
15 |
+
from opencv_transforms import functional as F
|
16 |
+
from setup_testing_directory import get_testing_directory
|
17 |
+
|
18 |
+
TOL = 1e-4
|
19 |
+
|
20 |
+
datadir = get_testing_directory()
|
21 |
+
train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
|
22 |
+
train_images.sort()
|
23 |
+
print('Number of training images: {:,}'.format(len(train_images)))
|
24 |
+
|
25 |
+
random.seed(1)
|
26 |
+
imfile = random.choice(train_images)
|
27 |
+
pil_image = Image.open(imfile)
|
28 |
+
image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
|
29 |
+
|
30 |
+
|
31 |
+
class TestContrast:
|
32 |
+
@pytest.mark.parametrize('random_seed', [1, 2, 3, 4])
|
33 |
+
@pytest.mark.parametrize('contrast_factor', [0.0, 0.5, 1.0, 2.0])
|
34 |
+
def test_contrast(self, contrast_factor, random_seed):
|
35 |
+
random.seed(random_seed)
|
36 |
+
imfile = random.choice(train_images)
|
37 |
+
pil_image = Image.open(imfile)
|
38 |
+
image = np.array(pil_image).copy()
|
39 |
+
|
40 |
+
pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
|
41 |
+
np_enhanced = F.adjust_contrast(image, contrast_factor)
|
42 |
+
assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
|
43 |
+
|
44 |
+
@pytest.mark.parametrize('n_images', [1, 11])
|
45 |
+
def test_multichannel_contrast(self, n_images, contrast_factor=0.1):
|
46 |
+
imfile = random.choice(train_images)
|
47 |
+
|
48 |
+
pil_image = Image.open(imfile)
|
49 |
+
image = np.array(pil_image).copy()
|
50 |
+
|
51 |
+
multichannel_image = np.concatenate([image for _ in range(n_images)], axis=-1)
|
52 |
+
# this will raise an exception in version 0.0.5
|
53 |
+
np_enchanced = F.adjust_contrast(multichannel_image, contrast_factor)
|
54 |
+
|
55 |
+
@pytest.mark.parametrize('contrast_factor', [0, 0.5, 1.0])
|
56 |
+
def test_grayscale_contrast(self, contrast_factor):
|
57 |
+
imfile = random.choice(train_images)
|
58 |
+
|
59 |
+
pil_image = Image.open(imfile)
|
60 |
+
image = np.array(pil_image).copy()
|
61 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
62 |
+
|
63 |
+
# make sure grayscale images work
|
64 |
+
pil_image = pil_image.convert('L')
|
65 |
+
|
66 |
+
pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
|
67 |
+
np_enhanced = F.adjust_contrast(image, contrast_factor)
|
68 |
+
assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
|
dataset/tests/test_spatial.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from PIL import Image
|
9 |
+
from PIL.Image import Image as PIL_image # for typing
|
10 |
+
|
11 |
+
from torchvision import transforms as pil_transforms
|
12 |
+
from torchvision.transforms import functional as F_pil
|
13 |
+
from opencv_transforms import transforms
|
14 |
+
from opencv_transforms import functional as F
|
15 |
+
|
16 |
+
from setup_testing_directory import get_testing_directory
|
17 |
+
from utils import L1
|
18 |
+
|
19 |
+
TOL = 1e-4
|
20 |
+
|
21 |
+
datadir = get_testing_directory()
|
22 |
+
train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
|
23 |
+
train_images.sort()
|
24 |
+
print('Number of training images: {:,}'.format(len(train_images)))
|
25 |
+
|
26 |
+
random.seed(1)
|
27 |
+
imfile = random.choice(train_images)
|
28 |
+
pil_image = Image.open(imfile)
|
29 |
+
image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
|
30 |
+
|
31 |
+
|
32 |
+
def test_resize():
|
33 |
+
pil_resized = pil_transforms.Resize((224, 224))(pil_image)
|
34 |
+
resized = transforms.Resize((224, 224))(image)
|
35 |
+
l1 = L1(pil_resized, resized)
|
36 |
+
assert l1 - 88.9559 < TOL
|
37 |
+
|
38 |
+
def test_rotation():
|
39 |
+
random.seed(1)
|
40 |
+
pil = pil_transforms.RandomRotation(10)(pil_image)
|
41 |
+
random.seed(1)
|
42 |
+
np_img = transforms.RandomRotation(10)(image)
|
43 |
+
l1 = L1(pil, np_img)
|
44 |
+
assert l1 - 86.7955 < TOL
|
45 |
+
|
46 |
+
def test_five_crop():
|
47 |
+
pil = pil_transforms.FiveCrop((224, 224))(pil_image)
|
48 |
+
cv = transforms.FiveCrop((224, 224))(image)
|
49 |
+
pil_stacked = np.hstack([np.asarray(i) for i in pil])
|
50 |
+
cv_stacked = np.hstack(cv)
|
51 |
+
l1 = L1(pil_stacked, cv_stacked)
|
52 |
+
assert l1 - 22.0444 < TOL
|
dataset/tests/utils.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL.Image import Image as PIL_image # for typing
|
5 |
+
|
6 |
+
|
7 |
+
def L1(pil: Union[PIL_image, np.ndarray], np_image: np.ndarray) -> float:
|
8 |
+
return np.abs(np.asarray(pil) - np_image).mean()
|
inference.yaml
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: vtdm.vtdm_gen_v01.VideoLDM
|
3 |
+
base_learning_rate: 1.0e-05
|
4 |
+
params:
|
5 |
+
input_key: video
|
6 |
+
scale_factor: 0.18215
|
7 |
+
log_keys: caption
|
8 |
+
num_samples: 25 #frame_rate
|
9 |
+
trained_param_keys:
|
10 |
+
- diffusion_model.label_emb.0.0.weight
|
11 |
+
- .emb_layers.
|
12 |
+
- .time_stack.
|
13 |
+
en_and_decode_n_samples_a_time: 25 #frame_rate
|
14 |
+
disable_first_stage_autocast: true
|
15 |
+
denoiser_config:
|
16 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
17 |
+
params:
|
18 |
+
scaling_config:
|
19 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
20 |
+
network_config:
|
21 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
22 |
+
params:
|
23 |
+
adm_in_channels: 768
|
24 |
+
num_classes: sequential
|
25 |
+
use_checkpoint: true
|
26 |
+
in_channels: 8
|
27 |
+
out_channels: 4
|
28 |
+
model_channels: 320
|
29 |
+
attention_resolutions:
|
30 |
+
- 4
|
31 |
+
- 2
|
32 |
+
- 1
|
33 |
+
num_res_blocks: 2
|
34 |
+
channel_mult:
|
35 |
+
- 1
|
36 |
+
- 2
|
37 |
+
- 4
|
38 |
+
- 4
|
39 |
+
num_head_channels: 64
|
40 |
+
use_linear_in_transformer: true
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 1024
|
43 |
+
spatial_transformer_attn_type: softmax-xformers
|
44 |
+
extra_ff_mix_layer: true
|
45 |
+
use_spatial_context: true
|
46 |
+
merge_strategy: learned_with_images
|
47 |
+
video_kernel_size:
|
48 |
+
- 3
|
49 |
+
- 1
|
50 |
+
- 1
|
51 |
+
conditioner_config:
|
52 |
+
target: sgm.modules.GeneralConditioner
|
53 |
+
params:
|
54 |
+
emb_models:
|
55 |
+
- is_trainable: false
|
56 |
+
input_key: cond_frames_without_noise
|
57 |
+
ucg_rate: 0.1
|
58 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
59 |
+
params:
|
60 |
+
n_cond_frames: 1
|
61 |
+
n_copies: 1
|
62 |
+
open_clip_embedding_config:
|
63 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
64 |
+
params:
|
65 |
+
version: ckpts/open_clip_pytorch_model.bin
|
66 |
+
freeze: true
|
67 |
+
- is_trainable: false
|
68 |
+
input_key: video
|
69 |
+
ucg_rate: 0.0
|
70 |
+
target: vtdm.encoders.AesEmbedder
|
71 |
+
- is_trainable: false
|
72 |
+
input_key: elevation
|
73 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
74 |
+
params:
|
75 |
+
outdim: 256
|
76 |
+
- input_key: cond_frames
|
77 |
+
is_trainable: false
|
78 |
+
ucg_rate: 0.1
|
79 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
80 |
+
params:
|
81 |
+
disable_encoder_autocast: true
|
82 |
+
n_cond_frames: 1
|
83 |
+
n_copies: 25 #frame_rate
|
84 |
+
is_ae: true
|
85 |
+
encoder_config:
|
86 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
87 |
+
params:
|
88 |
+
embed_dim: 4
|
89 |
+
monitor: val/rec_loss
|
90 |
+
ddconfig:
|
91 |
+
attn_type: vanilla-xformers
|
92 |
+
double_z: true
|
93 |
+
z_channels: 4
|
94 |
+
resolution: 256
|
95 |
+
in_channels: 3
|
96 |
+
out_ch: 3
|
97 |
+
ch: 128
|
98 |
+
ch_mult:
|
99 |
+
- 1
|
100 |
+
- 2
|
101 |
+
- 4
|
102 |
+
- 4
|
103 |
+
num_res_blocks: 2
|
104 |
+
attn_resolutions: []
|
105 |
+
dropout: 0.0
|
106 |
+
lossconfig:
|
107 |
+
target: torch.nn.Identity
|
108 |
+
- input_key: cond_aug
|
109 |
+
is_trainable: false
|
110 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
111 |
+
params:
|
112 |
+
outdim: 256
|
113 |
+
first_stage_config:
|
114 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
115 |
+
params:
|
116 |
+
embed_dim: 4
|
117 |
+
monitor: val/rec_loss
|
118 |
+
ddconfig:
|
119 |
+
attn_type: vanilla-xformers
|
120 |
+
double_z: true
|
121 |
+
z_channels: 4
|
122 |
+
resolution: 256
|
123 |
+
in_channels: 3
|
124 |
+
out_ch: 3
|
125 |
+
ch: 128
|
126 |
+
ch_mult:
|
127 |
+
- 1
|
128 |
+
- 2
|
129 |
+
- 4
|
130 |
+
- 4
|
131 |
+
num_res_blocks: 2
|
132 |
+
attn_resolutions: []
|
133 |
+
dropout: 0.0
|
134 |
+
lossconfig:
|
135 |
+
target: torch.nn.Identity
|
136 |
+
loss_fn_config:
|
137 |
+
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
138 |
+
params:
|
139 |
+
num_frames: 25 #frame_rate
|
140 |
+
batch2model_keys:
|
141 |
+
- num_video_frames
|
142 |
+
- image_only_indicator
|
143 |
+
sigma_sampler_config:
|
144 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
145 |
+
params:
|
146 |
+
p_mean: 1.0
|
147 |
+
p_std: 1.6
|
148 |
+
loss_weighting_config:
|
149 |
+
target: sgm.modules.diffusionmodules.loss_weighting.VWeighting
|
150 |
+
sampler_config:
|
151 |
+
target: sgm.modules.diffusionmodules.sampling.LinearMultistepSampler
|
152 |
+
params:
|
153 |
+
num_steps: 50
|
154 |
+
verbose: True
|
155 |
+
|
156 |
+
discretization_config:
|
157 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
158 |
+
params:
|
159 |
+
sigma_max: 700.0
|
160 |
+
|
161 |
+
guider_config:
|
162 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
163 |
+
params:
|
164 |
+
num_frames: 25 #frame_rate
|
165 |
+
max_scale: 2.5
|
166 |
+
min_scale: 1.0
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
libopencv-dev
|
2 |
+
build-essential
|
pipeline.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
import cv2
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import einops
|
11 |
+
from pytorch_lightning import seed_everything
|
12 |
+
from transparent_background import Remover
|
13 |
+
|
14 |
+
from dataset.opencv_transforms.functional import to_tensor, center_crop
|
15 |
+
from vtdm.model import create_model
|
16 |
+
from vtdm.util import tensor2vid
|
17 |
+
|
18 |
+
remover = Remover(jit=False)
|
19 |
+
|
20 |
+
def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
|
21 |
+
return Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
|
22 |
+
|
23 |
+
|
24 |
+
def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
|
25 |
+
cv_image = np.array(pil_image)
|
26 |
+
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
|
27 |
+
return cv_image
|
28 |
+
|
29 |
+
def prepare_white_image(input_image: Image.Image) -> Image.Image:
|
30 |
+
# remove bg
|
31 |
+
output = remover.process(input_image, type='rgba')
|
32 |
+
|
33 |
+
# expand image
|
34 |
+
width, height = output.size
|
35 |
+
max_side = max(width, height)
|
36 |
+
white_image = Image.new('RGBA', (max_side, max_side), (0, 0, 0, 0))
|
37 |
+
x_offset = (max_side - width) // 2
|
38 |
+
y_offset = (max_side - height) // 2
|
39 |
+
white_image.paste(output, (x_offset, y_offset))
|
40 |
+
|
41 |
+
return white_image
|
42 |
+
|
43 |
+
|
44 |
+
class MultiViewGenerator:
|
45 |
+
def __init__(self, checkpoint_path, config_path="inference.yaml"):
|
46 |
+
self.models = {}
|
47 |
+
denoising_model = create_model(config_path).cpu()
|
48 |
+
denoising_model.init_from_ckpt(checkpoint_path)
|
49 |
+
denoising_model = denoising_model.cuda().half()
|
50 |
+
self.models["denoising_model"] = denoising_model
|
51 |
+
|
52 |
+
def denoising(self, frames, args):
|
53 |
+
with torch.no_grad():
|
54 |
+
C, T, H, W = frames.shape
|
55 |
+
batch = {"video": frames.unsqueeze(0)}
|
56 |
+
batch["elevation"] = (
|
57 |
+
torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
|
58 |
+
)
|
59 |
+
batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
|
60 |
+
batch["motion_bucket_id"] = (
|
61 |
+
torch.Tensor([127]).to(torch.int64).to(frames.device)
|
62 |
+
)
|
63 |
+
batch = self.models["denoising_model"].add_custom_cond(batch, infer=True)
|
64 |
+
|
65 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
66 |
+
c, uc = self.models[
|
67 |
+
"denoising_model"
|
68 |
+
].conditioner.get_unconditional_conditioning(
|
69 |
+
batch,
|
70 |
+
force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
|
71 |
+
)
|
72 |
+
|
73 |
+
additional_model_inputs = {
|
74 |
+
"image_only_indicator": torch.zeros(2, T).to(
|
75 |
+
self.models["denoising_model"].device
|
76 |
+
),
|
77 |
+
"num_video_frames": batch["num_video_frames"],
|
78 |
+
}
|
79 |
+
|
80 |
+
def denoiser(input, sigma, c):
|
81 |
+
return self.models["denoising_model"].denoiser(
|
82 |
+
self.models["denoising_model"].model,
|
83 |
+
input,
|
84 |
+
sigma,
|
85 |
+
c,
|
86 |
+
**additional_model_inputs
|
87 |
+
)
|
88 |
+
|
89 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
90 |
+
randn = torch.randn(
|
91 |
+
[T, 4, H // 8, W // 8], device=self.models["denoising_model"].device
|
92 |
+
)
|
93 |
+
samples = self.models["denoising_model"].sampler(denoiser, randn, cond=c, uc=uc)
|
94 |
+
|
95 |
+
samples = self.models["denoising_model"].decode_first_stage(samples.half())
|
96 |
+
samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)
|
97 |
+
|
98 |
+
return tensor2vid(samples)
|
99 |
+
|
100 |
+
def video_pipeline(self, frames, args) -> List[Image.Image]:
|
101 |
+
num_iter = args["num_iter"]
|
102 |
+
out_list = []
|
103 |
+
|
104 |
+
for _ in range(num_iter):
|
105 |
+
with torch.no_grad():
|
106 |
+
results = self.denoising(frames, args)
|
107 |
+
|
108 |
+
if len(out_list) == 0:
|
109 |
+
out_list = out_list + results
|
110 |
+
else:
|
111 |
+
out_list = out_list + results[1:]
|
112 |
+
|
113 |
+
img = out_list[-1]
|
114 |
+
img = to_tensor(img)
|
115 |
+
img = (img - 0.5) * 2.0
|
116 |
+
frames[:, 0] = img
|
117 |
+
|
118 |
+
result = []
|
119 |
+
|
120 |
+
for i, frame in enumerate(out_list):
|
121 |
+
input_image = cv2_to_pil(frame)
|
122 |
+
output_image = remover.process(input_image, type='rgba')
|
123 |
+
result.append(output_image)
|
124 |
+
|
125 |
+
return result
|
126 |
+
|
127 |
+
def process(self, white_image: Image.Image, args) -> List[Image.Image]:
|
128 |
+
img = pil_to_cv2(white_image)
|
129 |
+
frame_list = [img] * args["clip_size"]
|
130 |
+
|
131 |
+
h, w = frame_list[0].shape[0:2]
|
132 |
+
rate = max(
|
133 |
+
args["input_resolution"][0] * 1.0 / h, args["input_resolution"][1] * 1.0 / w
|
134 |
+
)
|
135 |
+
frame_list = [
|
136 |
+
cv2.resize(f, [math.ceil(w * rate), math.ceil(h * rate)]) for f in frame_list
|
137 |
+
]
|
138 |
+
frame_list = [
|
139 |
+
center_crop(f, [args["input_resolution"][0], args["input_resolution"][1]])
|
140 |
+
for f in frame_list
|
141 |
+
]
|
142 |
+
frame_list = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frame_list]
|
143 |
+
|
144 |
+
frame_list = [to_tensor(f) for f in frame_list]
|
145 |
+
frame_list = [(f - 0.5) * 2.0 for f in frame_list]
|
146 |
+
frames = torch.stack(frame_list, 1)
|
147 |
+
frames = frames.cuda()
|
148 |
+
|
149 |
+
self.models["denoising_model"].num_samples = args["clip_size"]
|
150 |
+
self.models["denoising_model"].image_size = args["input_resolution"]
|
151 |
+
|
152 |
+
return self.video_pipeline(frames, args)
|
153 |
+
|
154 |
+
def infer(self, white_image: Image.Image) -> List[Image.Image]:
|
155 |
+
seed = random.randint(0, 65535)
|
156 |
+
seed_everything(seed)
|
157 |
+
|
158 |
+
params = {
|
159 |
+
"clip_size": 25,
|
160 |
+
"input_resolution": [512, 512],
|
161 |
+
"num_iter": 1,
|
162 |
+
"aes": 6.0,
|
163 |
+
"mv": [0.0, 0.0, 0.0, 10.0],
|
164 |
+
"elevation": 0,
|
165 |
+
}
|
166 |
+
|
167 |
+
return self.process(white_image, params)
|
168 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
av
|
2 |
+
black==23.7.0
|
3 |
+
chardet==5.1.0
|
4 |
+
clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
|
5 |
+
cupy-cuda113
|
6 |
+
einops>=0.6.1
|
7 |
+
fairscale>=0.4.13
|
8 |
+
fire>=0.5.0
|
9 |
+
fsspec>=2023.6.0
|
10 |
+
invisible-watermark>=0.2.0
|
11 |
+
kornia==0.6.9
|
12 |
+
matplotlib>=3.7.2
|
13 |
+
natsort>=8.4.0
|
14 |
+
ninja>=1.11.1
|
15 |
+
numpy==1.26.4
|
16 |
+
omegaconf>=2.3.0
|
17 |
+
open-clip-torch>=2.20.0
|
18 |
+
opencv-python==4.6.0.66
|
19 |
+
pandas>=2.0.3
|
20 |
+
pillow>=9.5.0
|
21 |
+
pudb>=2022.1.3
|
22 |
+
pytorch-lightning==1.9
|
23 |
+
pyyaml>=6.0.1
|
24 |
+
transparent_background
|
25 |
+
scipy>=1.10.1
|
26 |
+
streamlit>=0.73.1
|
27 |
+
tensorboardx==2.6
|
28 |
+
timm>=0.9.2
|
29 |
+
tokenizers==0.12.1
|
30 |
+
torch>=2.1.0
|
31 |
+
torchaudio>=2.1.0
|
32 |
+
torchdata>=0.6.1
|
33 |
+
torchmetrics>=1.0.1
|
34 |
+
torchvision>=0.16.0
|
35 |
+
tqdm>=4.65.0
|
36 |
+
transformers==4.19.1
|
37 |
+
triton>=2.0.0
|
38 |
+
urllib3<1.27,>=1.25.4
|
39 |
+
wandb>=0.15.6
|
40 |
+
webdataset>=0.2.33
|
41 |
+
wheel>=0.41.0
|
42 |
+
xformers>=0.0.20
|
43 |
+
gradio
|
44 |
+
streamlit-keyup==0.2.0
|
45 |
+
deepspeed==0.14.5
|
46 |
+
test-tube
|
47 |
+
-e git+https://github.com/Stability-AI/datapipelines.git@8bce77d147033b3a5285b6d45ee85f33866964fc#egg=sdata
|
48 |
+
basicsr
|
49 |
+
pillow-heif
|
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import AutoencodingEngine, DiffusionEngine
|
2 |
+
from .util import get_configs_path, instantiate_from_config
|
3 |
+
|
4 |
+
__version__ = "0.1.0"
|
sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .dataset import StableDataModuleFromConfig
|
sgm/data/dataset.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torchdata.datapipes.iter
|
4 |
+
import webdataset as wds
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from pytorch_lightning import LightningDataModule
|
7 |
+
|
8 |
+
try:
|
9 |
+
from sdata import create_dataset, create_dummy_dataset, create_loader
|
10 |
+
except ImportError as e:
|
11 |
+
print("#" * 100)
|
12 |
+
print("Datasets not yet available")
|
13 |
+
print("to enable, we need to add stable-datasets as a submodule")
|
14 |
+
print("please use ``git submodule update --init --recursive``")
|
15 |
+
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
16 |
+
print("#" * 100)
|
17 |
+
exit(1)
|
18 |
+
|
19 |
+
|
20 |
+
class StableDataModuleFromConfig(LightningDataModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
train: DictConfig,
|
24 |
+
validation: Optional[DictConfig] = None,
|
25 |
+
test: Optional[DictConfig] = None,
|
26 |
+
skip_val_loader: bool = False,
|
27 |
+
dummy: bool = False,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.train_config = train
|
31 |
+
assert (
|
32 |
+
"datapipeline" in self.train_config and "loader" in self.train_config
|
33 |
+
), "train config requires the fields `datapipeline` and `loader`"
|
34 |
+
|
35 |
+
self.val_config = validation
|
36 |
+
if not skip_val_loader:
|
37 |
+
if self.val_config is not None:
|
38 |
+
assert (
|
39 |
+
"datapipeline" in self.val_config and "loader" in self.val_config
|
40 |
+
), "validation config requires the fields `datapipeline` and `loader`"
|
41 |
+
else:
|
42 |
+
print(
|
43 |
+
"Warning: No Validation datapipeline defined, using that one from training"
|
44 |
+
)
|
45 |
+
self.val_config = train
|
46 |
+
|
47 |
+
self.test_config = test
|
48 |
+
if self.test_config is not None:
|
49 |
+
assert (
|
50 |
+
"datapipeline" in self.test_config and "loader" in self.test_config
|
51 |
+
), "test config requires the fields `datapipeline` and `loader`"
|
52 |
+
|
53 |
+
self.dummy = dummy
|
54 |
+
if self.dummy:
|
55 |
+
print("#" * 100)
|
56 |
+
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
57 |
+
print("#" * 100)
|
58 |
+
|
59 |
+
def setup(self, stage: str) -> None:
|
60 |
+
print("Preparing datasets")
|
61 |
+
if self.dummy:
|
62 |
+
data_fn = create_dummy_dataset
|
63 |
+
else:
|
64 |
+
data_fn = create_dataset
|
65 |
+
|
66 |
+
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
|
67 |
+
if self.val_config:
|
68 |
+
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
|
69 |
+
if self.test_config:
|
70 |
+
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
|
71 |
+
|
72 |
+
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
|
73 |
+
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
|
74 |
+
return loader
|
75 |
+
|
76 |
+
def val_dataloader(self) -> wds.DataPipeline:
|
77 |
+
return create_loader(self.val_datapipeline, **self.val_config.loader)
|
78 |
+
|
79 |
+
def test_dataloader(self) -> wds.DataPipeline:
|
80 |
+
return create_loader(self.test_datapipeline, **self.test_config.loader)
|
sgm/data/video_dataset.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import PIL
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from skimage.io import imread
|
8 |
+
import webdataset as wds
|
9 |
+
import PIL.Image as Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
# from ldm.base_utils import read_pickle, pose_inverse
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
import torchvision
|
17 |
+
from einops import rearrange
|
18 |
+
|
19 |
+
def add_margin(pil_img, color=0, size=256):
|
20 |
+
width, height = pil_img.size
|
21 |
+
result = Image.new(pil_img.mode, (size, size), color)
|
22 |
+
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
|
23 |
+
return result
|
24 |
+
|
25 |
+
def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
|
26 |
+
image_input = Image.open(image_path)
|
27 |
+
|
28 |
+
if crop_size!=-1:
|
29 |
+
alpha_np = np.asarray(image_input)[:, :, 3]
|
30 |
+
coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
|
31 |
+
min_x, min_y = np.min(coords, 0)
|
32 |
+
max_x, max_y = np.max(coords, 0)
|
33 |
+
ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
|
34 |
+
h, w = ref_img_.height, ref_img_.width
|
35 |
+
scale = crop_size / max(h, w)
|
36 |
+
h_, w_ = int(scale * h), int(scale * w)
|
37 |
+
ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
|
38 |
+
image_input = add_margin(ref_img_, size=image_size)
|
39 |
+
else:
|
40 |
+
image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
|
41 |
+
image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
|
42 |
+
|
43 |
+
image_input = np.asarray(image_input)
|
44 |
+
image_input = image_input.astype(np.float32) / 255.0
|
45 |
+
ref_mask = image_input[:, :, 3:]
|
46 |
+
image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
|
47 |
+
image_input = image_input[:, :, :3] * 2.0 - 1.0
|
48 |
+
image_input = torch.from_numpy(image_input.astype(np.float32))
|
49 |
+
elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
|
50 |
+
return {"input_image": image_input, "input_elevation": elevation_input}
|
51 |
+
|
52 |
+
|
53 |
+
class VideoTrainDataset(Dataset):
|
54 |
+
def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', width=1024, height=576, sample_frames=25):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
num_samples (int): Number of samples in the dataset.
|
58 |
+
channels (int): Number of channels, default is 3 for RGB.
|
59 |
+
"""
|
60 |
+
# Define the path to the folder containing video frames
|
61 |
+
self.base_folder = base_folder
|
62 |
+
self.folders = os.listdir(self.base_folder)
|
63 |
+
self.num_samples = len(self.folders)
|
64 |
+
self.channels = 3
|
65 |
+
self.width = width
|
66 |
+
self.height = height
|
67 |
+
self.sample_frames = sample_frames
|
68 |
+
self.elevations = [-10, 0, 10, 20, 30, 40]
|
69 |
+
|
70 |
+
def __len__(self):
|
71 |
+
return self.num_samples
|
72 |
+
|
73 |
+
def load_im(self, path):
|
74 |
+
img = imread(path)
|
75 |
+
img = img.astype(np.float32) / 255.0
|
76 |
+
mask = img[:,:,3:]
|
77 |
+
img[:,:,:3] = img[:,:,:3] * mask + 1 - mask # white background
|
78 |
+
img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
|
79 |
+
return img, mask
|
80 |
+
|
81 |
+
def __getitem__(self, idx):
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
idx (int): Index of the sample to return.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512).
|
88 |
+
"""
|
89 |
+
# Randomly select a folder (representing a video) from the base folder
|
90 |
+
chosen_folder = random.choice(self.folders)
|
91 |
+
folder_path = os.path.join(self.base_folder, chosen_folder)
|
92 |
+
frames = os.listdir(folder_path)
|
93 |
+
# Sort the frames by name
|
94 |
+
frames.sort()
|
95 |
+
|
96 |
+
# Ensure the selected folder has at least `sample_frames`` frames
|
97 |
+
if len(frames) < self.sample_frames:
|
98 |
+
raise ValueError(
|
99 |
+
f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")
|
100 |
+
|
101 |
+
# Randomly select a start index for frame sequence. Fixed elevation
|
102 |
+
start_idx = random.randint(0, len(frames) - 1)
|
103 |
+
range_id = int(start_idx / 16) # 0, 1, 2, 3, 4, 5
|
104 |
+
elevation = self.elevations[range_id]
|
105 |
+
selected_frames = []
|
106 |
+
|
107 |
+
for frame_idx in range(start_idx, (range_id + 1) * 16):
|
108 |
+
selected_frames.append(frames[frame_idx])
|
109 |
+
for frame_idx in range((range_id) * 16, start_idx):
|
110 |
+
selected_frames.append(frames[frame_idx])
|
111 |
+
|
112 |
+
# Initialize a tensor to store the pixel values
|
113 |
+
pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
|
114 |
+
|
115 |
+
# Load and process each frame
|
116 |
+
for i, frame_name in enumerate(selected_frames):
|
117 |
+
frame_path = os.path.join(folder_path, frame_name)
|
118 |
+
img, mask = self.load_im(frame_path)
|
119 |
+
# Resize the image and convert it to a tensor
|
120 |
+
img_resized = img.resize((self.width, self.height))
|
121 |
+
img_tensor = torch.from_numpy(np.array(img_resized)).float()
|
122 |
+
|
123 |
+
# Normalize the image by scaling pixel values to [-1, 1]
|
124 |
+
img_normalized = img_tensor / 127.5 - 1
|
125 |
+
|
126 |
+
# Rearrange channels if necessary
|
127 |
+
if self.channels == 3:
|
128 |
+
img_normalized = img_normalized.permute(
|
129 |
+
2, 0, 1) # For RGB images
|
130 |
+
elif self.channels == 1:
|
131 |
+
img_normalized = img_normalized.mean(
|
132 |
+
dim=2, keepdim=True) # For grayscale images
|
133 |
+
|
134 |
+
pixel_values[i] = img_normalized
|
135 |
+
|
136 |
+
pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
|
137 |
+
|
138 |
+
caption = chosen_folder + "_" + str(start_idx)
|
139 |
+
|
140 |
+
return {'video': pixel_values, 'elevation': elevation, 'caption': caption, "fps_id": 7, "motion_bucket_id": 127}
|
141 |
+
|
142 |
+
class SyncDreamerEvalData(Dataset):
|
143 |
+
def __init__(self, image_dir):
|
144 |
+
self.image_size = 512
|
145 |
+
self.image_dir = Path(image_dir)
|
146 |
+
self.crop_size = 20
|
147 |
+
|
148 |
+
self.fns = []
|
149 |
+
for fn in Path(image_dir).iterdir():
|
150 |
+
if fn.suffix=='.png':
|
151 |
+
self.fns.append(fn)
|
152 |
+
print('============= length of dataset %d =============' % len(self.fns))
|
153 |
+
|
154 |
+
def __len__(self):
|
155 |
+
return len(self.fns)
|
156 |
+
|
157 |
+
def get_data_for_index(self, index):
|
158 |
+
input_img_fn = self.fns[index]
|
159 |
+
elevation = 0
|
160 |
+
return prepare_inputs(input_img_fn, elevation, 512)
|
161 |
+
|
162 |
+
def __getitem__(self, index):
|
163 |
+
return self.get_data_for_index(index)
|
164 |
+
|
165 |
+
class VideoDataset(pl.LightningDataModule):
|
166 |
+
def __init__(self, base_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
|
167 |
+
super().__init__()
|
168 |
+
self.base_folder = base_folder
|
169 |
+
self.eval_folder = eval_folder
|
170 |
+
self.width = width
|
171 |
+
self.height = height
|
172 |
+
self.sample_frames = sample_frames
|
173 |
+
self.batch_size = batch_size
|
174 |
+
self.num_workers = num_workers
|
175 |
+
self.seed = seed
|
176 |
+
self.additional_args = kwargs
|
177 |
+
|
178 |
+
def setup(self):
|
179 |
+
self.train_dataset = VideoTrainDataset(self.base_folder, self.width, self.height, self.sample_frames)
|
180 |
+
self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)
|
181 |
+
|
182 |
+
def train_dataloader(self):
|
183 |
+
sampler = DistributedSampler(self.train_dataset, seed=self.seed)
|
184 |
+
return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
185 |
+
|
186 |
+
def val_dataloader(self):
|
187 |
+
loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
188 |
+
return loader
|
189 |
+
|
190 |
+
def test_dataloader(self):
|
191 |
+
return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
sgm/data/video_dataset_stage2_degradeImages.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import PIL
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from skimage.io import imread
|
8 |
+
import webdataset as wds
|
9 |
+
import PIL.Image as Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
# from ldm.base_utils import read_pickle, pose_inverse
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
import torchvision
|
17 |
+
from einops import rearrange
|
18 |
+
|
19 |
+
# for the degraded images
|
20 |
+
import yaml
|
21 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
22 |
+
import math
|
23 |
+
|
24 |
+
def add_margin(pil_img, color=0, size=256):
|
25 |
+
width, height = pil_img.size
|
26 |
+
result = Image.new(pil_img.mode, (size, size), color)
|
27 |
+
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
|
28 |
+
return result
|
29 |
+
|
30 |
+
def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
|
31 |
+
image_input = Image.open(image_path)
|
32 |
+
|
33 |
+
if crop_size!=-1:
|
34 |
+
alpha_np = np.asarray(image_input)[:, :, 3]
|
35 |
+
coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
|
36 |
+
min_x, min_y = np.min(coords, 0)
|
37 |
+
max_x, max_y = np.max(coords, 0)
|
38 |
+
ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
|
39 |
+
h, w = ref_img_.height, ref_img_.width
|
40 |
+
scale = crop_size / max(h, w)
|
41 |
+
h_, w_ = int(scale * h), int(scale * w)
|
42 |
+
ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
|
43 |
+
image_input = add_margin(ref_img_, size=image_size)
|
44 |
+
else:
|
45 |
+
image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
|
46 |
+
image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
|
47 |
+
|
48 |
+
image_input = np.asarray(image_input)
|
49 |
+
image_input = image_input.astype(np.float32) / 255.0
|
50 |
+
ref_mask = image_input[:, :, 3:]
|
51 |
+
image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
|
52 |
+
image_input = image_input[:, :, :3] * 2.0 - 1.0
|
53 |
+
image_input = torch.from_numpy(image_input.astype(np.float32))
|
54 |
+
elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
|
55 |
+
return {"input_image": image_input, "input_elevation": elevation_input}
|
56 |
+
|
57 |
+
|
58 |
+
class VideoTrainDataset(Dataset):
|
59 |
+
def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', depth_folder="/mnt/drive2/3d/OBJAVERSE-DEPTH/depth256", width=1024, height=576, sample_frames=25):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
num_samples (int): Number of samples in the dataset.
|
63 |
+
channels (int): Number of channels, default is 3 for RGB.
|
64 |
+
"""
|
65 |
+
# Define the path to the folder containing video frames
|
66 |
+
self.base_folder = base_folder
|
67 |
+
self.depth_folder = depth_folder
|
68 |
+
# self.folders1 = os.listdir(self.base_folder)
|
69 |
+
# self.folders2 = os.listdir(self.depth_folder)
|
70 |
+
# self.folders = list(set(self.folders1).intersection(set(self.folders2)))
|
71 |
+
self.folders = os.listdir(self.base_folder)
|
72 |
+
self.num_samples = len(self.folders)
|
73 |
+
self.channels = 3
|
74 |
+
self.width = width
|
75 |
+
self.height = height
|
76 |
+
self.sample_frames = sample_frames
|
77 |
+
self.elevations = [-10, 0, 10, 20, 30, 40]
|
78 |
+
|
79 |
+
# for degraded images
|
80 |
+
with open('configs/train_realesrnet_x4plus.yml', mode='r') as f:
|
81 |
+
opt = yaml.load(f, Loader=yaml.FullLoader)
|
82 |
+
self.opt = opt
|
83 |
+
# blur settings for the first degradation
|
84 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
85 |
+
self.kernel_list = opt['kernel_list']
|
86 |
+
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
87 |
+
self.blur_sigma = opt['blur_sigma']
|
88 |
+
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
89 |
+
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
90 |
+
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
91 |
+
|
92 |
+
# blur settings for the second degradation
|
93 |
+
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
94 |
+
self.kernel_list2 = opt['kernel_list2']
|
95 |
+
self.kernel_prob2 = opt['kernel_prob2']
|
96 |
+
self.blur_sigma2 = opt['blur_sigma2']
|
97 |
+
self.betag_range2 = opt['betag_range2']
|
98 |
+
self.betap_range2 = opt['betap_range2']
|
99 |
+
self.sinc_prob2 = opt['sinc_prob2']
|
100 |
+
|
101 |
+
# a final sinc filter
|
102 |
+
self.final_sinc_prob = opt['final_sinc_prob']
|
103 |
+
|
104 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
105 |
+
# TODO: kernel range is now hard-coded, should be in the configure file
|
106 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
107 |
+
self.pulse_tensor[10, 10] = 1
|
108 |
+
|
109 |
+
|
110 |
+
def __len__(self):
|
111 |
+
return self.num_samples
|
112 |
+
|
113 |
+
def load_im(self, path):
|
114 |
+
img = imread(path)
|
115 |
+
img = img.astype(np.float32) / 255.0
|
116 |
+
mask = img[:,:,3:]
|
117 |
+
img[:,:,:3] = img[:,:,:3] * mask + 1 - mask # white background
|
118 |
+
img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
|
119 |
+
return img, mask
|
120 |
+
|
121 |
+
def __getitem__(self, idx):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
idx (int): Index of the sample to return.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512).
|
128 |
+
"""
|
129 |
+
# Randomly select a folder (representing a video) from the base folder
|
130 |
+
chosen_folder = random.choice(self.folders)
|
131 |
+
folder_path = os.path.join(self.base_folder, chosen_folder)
|
132 |
+
frames = os.listdir(folder_path)
|
133 |
+
# Sort the frames by name
|
134 |
+
frames.sort()
|
135 |
+
|
136 |
+
# Ensure the selected folder has at least `sample_frames`` frames
|
137 |
+
if len(frames) < self.sample_frames:
|
138 |
+
raise ValueError(
|
139 |
+
f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")
|
140 |
+
|
141 |
+
# Randomly select a start index for frame sequence. Fixed elevation
|
142 |
+
start_idx = random.randint(0, len(frames) - 1)
|
143 |
+
# start_idx = random.choice([0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92])
|
144 |
+
range_id = int(start_idx / 16) # 0, 1, 2, 3, 4, 5
|
145 |
+
elevation = self.elevations[range_id]
|
146 |
+
selected_frames = []
|
147 |
+
|
148 |
+
for frame_idx in range(start_idx, (range_id + 1) * 16):
|
149 |
+
selected_frames.append(frames[frame_idx])
|
150 |
+
for frame_idx in range((range_id) * 16, start_idx):
|
151 |
+
selected_frames.append(frames[frame_idx])
|
152 |
+
|
153 |
+
# Initialize a tensor to store the pixel values
|
154 |
+
pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
|
155 |
+
masks = []
|
156 |
+
|
157 |
+
# Load and process each frame
|
158 |
+
for i, frame_name in enumerate(selected_frames):
|
159 |
+
frame_path = os.path.join(folder_path, frame_name)
|
160 |
+
img, mask = self.load_im(frame_path)
|
161 |
+
mask = mask.squeeze(-1)
|
162 |
+
masks.append(mask)
|
163 |
+
# Resize the image and convert it to a tensor
|
164 |
+
img_resized = img.resize((self.width, self.height))
|
165 |
+
img_tensor = torch.from_numpy(np.array(img_resized)).float()
|
166 |
+
|
167 |
+
# Normalize the image by scaling pixel values to [-1, 1]
|
168 |
+
img_normalized = img_tensor / 127.5 - 1
|
169 |
+
|
170 |
+
# Rearrange channels if necessary
|
171 |
+
if self.channels == 3:
|
172 |
+
img_normalized = img_normalized.permute(
|
173 |
+
2, 0, 1) # For RGB images
|
174 |
+
elif self.channels == 1:
|
175 |
+
img_normalized = img_normalized.mean(
|
176 |
+
dim=2, keepdim=True) # For grayscale images
|
177 |
+
|
178 |
+
pixel_values[i] = img_normalized
|
179 |
+
|
180 |
+
pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
|
181 |
+
masks = torch.from_numpy(np.array(masks))
|
182 |
+
caption = chosen_folder
|
183 |
+
|
184 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! get the kernels for degraded images
|
185 |
+
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
186 |
+
kernels = []
|
187 |
+
kernel2s = []
|
188 |
+
sinc_kernels = []
|
189 |
+
for i in range(16):
|
190 |
+
kernel_size = random.choice(self.kernel_range)
|
191 |
+
if np.random.uniform() < self.opt['sinc_prob']:
|
192 |
+
# this sinc filter setting is for kernels ranging from [7, 21]
|
193 |
+
if kernel_size < 13:
|
194 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
195 |
+
else:
|
196 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
197 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
198 |
+
else:
|
199 |
+
kernel = random_mixed_kernels(
|
200 |
+
self.kernel_list,
|
201 |
+
self.kernel_prob,
|
202 |
+
kernel_size,
|
203 |
+
self.blur_sigma,
|
204 |
+
self.blur_sigma, [-math.pi, math.pi],
|
205 |
+
self.betag_range,
|
206 |
+
self.betap_range,
|
207 |
+
noise_range=None)
|
208 |
+
# pad kernel
|
209 |
+
pad_size = (21 - kernel_size) // 2
|
210 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
211 |
+
kernels.append(kernel)
|
212 |
+
|
213 |
+
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
214 |
+
kernel_size = random.choice(self.kernel_range)
|
215 |
+
if np.random.uniform() < self.opt['sinc_prob2']:
|
216 |
+
if kernel_size < 13:
|
217 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
218 |
+
else:
|
219 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
220 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
221 |
+
else:
|
222 |
+
kernel2 = random_mixed_kernels(
|
223 |
+
self.kernel_list2,
|
224 |
+
self.kernel_prob2,
|
225 |
+
kernel_size,
|
226 |
+
self.blur_sigma2,
|
227 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
228 |
+
self.betag_range2,
|
229 |
+
self.betap_range2,
|
230 |
+
noise_range=None)
|
231 |
+
|
232 |
+
# pad kernel
|
233 |
+
pad_size = (21 - kernel_size) // 2
|
234 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
235 |
+
kernel2s.append(kernel2)
|
236 |
+
|
237 |
+
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
238 |
+
if np.random.uniform() < self.opt['final_sinc_prob']:
|
239 |
+
kernel_size = random.choice(self.kernel_range)
|
240 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
241 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
242 |
+
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
243 |
+
else:
|
244 |
+
sinc_kernel = self.pulse_tensor
|
245 |
+
sinc_kernels.append(sinc_kernel)
|
246 |
+
kernels = np.array(kernels)
|
247 |
+
kernel2s = np.array(kernel2s)
|
248 |
+
sinc_kernels = torch.stack(sinc_kernels, 0)
|
249 |
+
kernels = torch.FloatTensor(kernels)
|
250 |
+
kernel2s = torch.FloatTensor(kernel2s)
|
251 |
+
return {'video': pixel_values, 'masks': masks, 'elevation': elevation, 'caption': caption, 'kernel1s': kernels, 'kernel2s': kernel2s, 'sinc_kernels': sinc_kernels} # (16, 3, 512, 512)-> (3, 16, 512, 512)
|
252 |
+
|
253 |
+
class SyncDreamerEvalData(Dataset):
|
254 |
+
def __init__(self, image_dir):
|
255 |
+
self.image_size = 512
|
256 |
+
self.image_dir = Path(image_dir)
|
257 |
+
self.crop_size = 20
|
258 |
+
|
259 |
+
self.fns = []
|
260 |
+
for fn in Path(image_dir).iterdir():
|
261 |
+
if fn.suffix=='.png':
|
262 |
+
self.fns.append(fn)
|
263 |
+
print('============= length of dataset %d =============' % len(self.fns))
|
264 |
+
|
265 |
+
def __len__(self):
|
266 |
+
return len(self.fns)
|
267 |
+
|
268 |
+
def get_data_for_index(self, index):
|
269 |
+
input_img_fn = self.fns[index]
|
270 |
+
elevation = 0
|
271 |
+
return prepare_inputs(input_img_fn, elevation, 512)
|
272 |
+
|
273 |
+
def __getitem__(self, index):
|
274 |
+
return self.get_data_for_index(index)
|
275 |
+
|
276 |
+
class VideoDataset(pl.LightningDataModule):
|
277 |
+
def __init__(self, base_folder, depth_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
|
278 |
+
super().__init__()
|
279 |
+
self.base_folder = base_folder
|
280 |
+
self.depth_folder = depth_folder
|
281 |
+
self.eval_folder = eval_folder
|
282 |
+
self.width = width
|
283 |
+
self.height = height
|
284 |
+
self.sample_frames = sample_frames
|
285 |
+
self.batch_size = batch_size
|
286 |
+
self.num_workers = num_workers
|
287 |
+
self.seed = seed
|
288 |
+
self.additional_args = kwargs
|
289 |
+
|
290 |
+
def setup(self):
|
291 |
+
self.train_dataset = VideoTrainDataset(self.base_folder, self.depth_folder, self.width, self.height, self.sample_frames)
|
292 |
+
self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)
|
293 |
+
|
294 |
+
def train_dataloader(self):
|
295 |
+
sampler = DistributedSampler(self.train_dataset, seed=self.seed)
|
296 |
+
return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
297 |
+
|
298 |
+
def val_dataloader(self):
|
299 |
+
loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
300 |
+
return loader
|
301 |
+
|
302 |
+
def test_dataloader(self):
|
303 |
+
return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
sgm/inference/api.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
9 |
+
do_sample)
|
10 |
+
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
11 |
+
DPMPP2SAncestralSampler,
|
12 |
+
EulerAncestralSampler,
|
13 |
+
EulerEDMSampler,
|
14 |
+
HeunEDMSampler,
|
15 |
+
LinearMultistepSampler)
|
16 |
+
from sgm.util import load_model_from_config
|
17 |
+
|
18 |
+
|
19 |
+
class ModelArchitecture(str, Enum):
|
20 |
+
SD_2_1 = "stable-diffusion-v2-1"
|
21 |
+
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
22 |
+
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
23 |
+
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
24 |
+
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
25 |
+
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
26 |
+
|
27 |
+
|
28 |
+
class Sampler(str, Enum):
|
29 |
+
EULER_EDM = "EulerEDMSampler"
|
30 |
+
HEUN_EDM = "HeunEDMSampler"
|
31 |
+
EULER_ANCESTRAL = "EulerAncestralSampler"
|
32 |
+
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
33 |
+
DPMPP2M = "DPMPP2MSampler"
|
34 |
+
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
35 |
+
|
36 |
+
|
37 |
+
class Discretization(str, Enum):
|
38 |
+
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
39 |
+
EDM = "EDMDiscretization"
|
40 |
+
|
41 |
+
|
42 |
+
class Guider(str, Enum):
|
43 |
+
VANILLA = "VanillaCFG"
|
44 |
+
IDENTITY = "IdentityGuider"
|
45 |
+
|
46 |
+
|
47 |
+
class Thresholder(str, Enum):
|
48 |
+
NONE = "None"
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class SamplingParams:
|
53 |
+
width: int = 1024
|
54 |
+
height: int = 1024
|
55 |
+
steps: int = 50
|
56 |
+
sampler: Sampler = Sampler.DPMPP2M
|
57 |
+
discretization: Discretization = Discretization.LEGACY_DDPM
|
58 |
+
guider: Guider = Guider.VANILLA
|
59 |
+
thresholder: Thresholder = Thresholder.NONE
|
60 |
+
scale: float = 6.0
|
61 |
+
aesthetic_score: float = 5.0
|
62 |
+
negative_aesthetic_score: float = 5.0
|
63 |
+
img2img_strength: float = 1.0
|
64 |
+
orig_width: int = 1024
|
65 |
+
orig_height: int = 1024
|
66 |
+
crop_coords_top: int = 0
|
67 |
+
crop_coords_left: int = 0
|
68 |
+
sigma_min: float = 0.0292
|
69 |
+
sigma_max: float = 14.6146
|
70 |
+
rho: float = 3.0
|
71 |
+
s_churn: float = 0.0
|
72 |
+
s_tmin: float = 0.0
|
73 |
+
s_tmax: float = 999.0
|
74 |
+
s_noise: float = 1.0
|
75 |
+
eta: float = 1.0
|
76 |
+
order: int = 4
|
77 |
+
|
78 |
+
|
79 |
+
@dataclass
|
80 |
+
class SamplingSpec:
|
81 |
+
width: int
|
82 |
+
height: int
|
83 |
+
channels: int
|
84 |
+
factor: int
|
85 |
+
is_legacy: bool
|
86 |
+
config: str
|
87 |
+
ckpt: str
|
88 |
+
is_guided: bool
|
89 |
+
|
90 |
+
|
91 |
+
model_specs = {
|
92 |
+
ModelArchitecture.SD_2_1: SamplingSpec(
|
93 |
+
height=512,
|
94 |
+
width=512,
|
95 |
+
channels=4,
|
96 |
+
factor=8,
|
97 |
+
is_legacy=True,
|
98 |
+
config="sd_2_1.yaml",
|
99 |
+
ckpt="v2-1_512-ema-pruned.safetensors",
|
100 |
+
is_guided=True,
|
101 |
+
),
|
102 |
+
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
103 |
+
height=768,
|
104 |
+
width=768,
|
105 |
+
channels=4,
|
106 |
+
factor=8,
|
107 |
+
is_legacy=True,
|
108 |
+
config="sd_2_1_768.yaml",
|
109 |
+
ckpt="v2-1_768-ema-pruned.safetensors",
|
110 |
+
is_guided=True,
|
111 |
+
),
|
112 |
+
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
113 |
+
height=1024,
|
114 |
+
width=1024,
|
115 |
+
channels=4,
|
116 |
+
factor=8,
|
117 |
+
is_legacy=False,
|
118 |
+
config="sd_xl_base.yaml",
|
119 |
+
ckpt="sd_xl_base_0.9.safetensors",
|
120 |
+
is_guided=True,
|
121 |
+
),
|
122 |
+
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
123 |
+
height=1024,
|
124 |
+
width=1024,
|
125 |
+
channels=4,
|
126 |
+
factor=8,
|
127 |
+
is_legacy=True,
|
128 |
+
config="sd_xl_refiner.yaml",
|
129 |
+
ckpt="sd_xl_refiner_0.9.safetensors",
|
130 |
+
is_guided=True,
|
131 |
+
),
|
132 |
+
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
133 |
+
height=1024,
|
134 |
+
width=1024,
|
135 |
+
channels=4,
|
136 |
+
factor=8,
|
137 |
+
is_legacy=False,
|
138 |
+
config="sd_xl_base.yaml",
|
139 |
+
ckpt="sd_xl_base_1.0.safetensors",
|
140 |
+
is_guided=True,
|
141 |
+
),
|
142 |
+
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
143 |
+
height=1024,
|
144 |
+
width=1024,
|
145 |
+
channels=4,
|
146 |
+
factor=8,
|
147 |
+
is_legacy=True,
|
148 |
+
config="sd_xl_refiner.yaml",
|
149 |
+
ckpt="sd_xl_refiner_1.0.safetensors",
|
150 |
+
is_guided=True,
|
151 |
+
),
|
152 |
+
}
|
153 |
+
|
154 |
+
|
155 |
+
class SamplingPipeline:
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
model_id: ModelArchitecture,
|
159 |
+
model_path="checkpoints",
|
160 |
+
config_path="configs/inference",
|
161 |
+
device="cuda",
|
162 |
+
use_fp16=True,
|
163 |
+
) -> None:
|
164 |
+
if model_id not in model_specs:
|
165 |
+
raise ValueError(f"Model {model_id} not supported")
|
166 |
+
self.model_id = model_id
|
167 |
+
self.specs = model_specs[self.model_id]
|
168 |
+
self.config = str(pathlib.Path(config_path, self.specs.config))
|
169 |
+
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
170 |
+
self.device = device
|
171 |
+
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
172 |
+
|
173 |
+
def _load_model(self, device="cuda", use_fp16=True):
|
174 |
+
config = OmegaConf.load(self.config)
|
175 |
+
model = load_model_from_config(config, self.ckpt)
|
176 |
+
if model is None:
|
177 |
+
raise ValueError(f"Model {self.model_id} could not be loaded")
|
178 |
+
model.to(device)
|
179 |
+
if use_fp16:
|
180 |
+
model.conditioner.half()
|
181 |
+
model.model.half()
|
182 |
+
return model
|
183 |
+
|
184 |
+
def text_to_image(
|
185 |
+
self,
|
186 |
+
params: SamplingParams,
|
187 |
+
prompt: str,
|
188 |
+
negative_prompt: str = "",
|
189 |
+
samples: int = 1,
|
190 |
+
return_latents: bool = False,
|
191 |
+
):
|
192 |
+
sampler = get_sampler_config(params)
|
193 |
+
value_dict = asdict(params)
|
194 |
+
value_dict["prompt"] = prompt
|
195 |
+
value_dict["negative_prompt"] = negative_prompt
|
196 |
+
value_dict["target_width"] = params.width
|
197 |
+
value_dict["target_height"] = params.height
|
198 |
+
return do_sample(
|
199 |
+
self.model,
|
200 |
+
sampler,
|
201 |
+
value_dict,
|
202 |
+
samples,
|
203 |
+
params.height,
|
204 |
+
params.width,
|
205 |
+
self.specs.channels,
|
206 |
+
self.specs.factor,
|
207 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
208 |
+
return_latents=return_latents,
|
209 |
+
filter=None,
|
210 |
+
)
|
211 |
+
|
212 |
+
def image_to_image(
|
213 |
+
self,
|
214 |
+
params: SamplingParams,
|
215 |
+
image,
|
216 |
+
prompt: str,
|
217 |
+
negative_prompt: str = "",
|
218 |
+
samples: int = 1,
|
219 |
+
return_latents: bool = False,
|
220 |
+
):
|
221 |
+
sampler = get_sampler_config(params)
|
222 |
+
|
223 |
+
if params.img2img_strength < 1.0:
|
224 |
+
sampler.discretization = Img2ImgDiscretizationWrapper(
|
225 |
+
sampler.discretization,
|
226 |
+
strength=params.img2img_strength,
|
227 |
+
)
|
228 |
+
height, width = image.shape[2], image.shape[3]
|
229 |
+
value_dict = asdict(params)
|
230 |
+
value_dict["prompt"] = prompt
|
231 |
+
value_dict["negative_prompt"] = negative_prompt
|
232 |
+
value_dict["target_width"] = width
|
233 |
+
value_dict["target_height"] = height
|
234 |
+
return do_img2img(
|
235 |
+
image,
|
236 |
+
self.model,
|
237 |
+
sampler,
|
238 |
+
value_dict,
|
239 |
+
samples,
|
240 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
241 |
+
return_latents=return_latents,
|
242 |
+
filter=None,
|
243 |
+
)
|
244 |
+
|
245 |
+
def refiner(
|
246 |
+
self,
|
247 |
+
params: SamplingParams,
|
248 |
+
image,
|
249 |
+
prompt: str,
|
250 |
+
negative_prompt: Optional[str] = None,
|
251 |
+
samples: int = 1,
|
252 |
+
return_latents: bool = False,
|
253 |
+
):
|
254 |
+
sampler = get_sampler_config(params)
|
255 |
+
value_dict = {
|
256 |
+
"orig_width": image.shape[3] * 8,
|
257 |
+
"orig_height": image.shape[2] * 8,
|
258 |
+
"target_width": image.shape[3] * 8,
|
259 |
+
"target_height": image.shape[2] * 8,
|
260 |
+
"prompt": prompt,
|
261 |
+
"negative_prompt": negative_prompt,
|
262 |
+
"crop_coords_top": 0,
|
263 |
+
"crop_coords_left": 0,
|
264 |
+
"aesthetic_score": 6.0,
|
265 |
+
"negative_aesthetic_score": 2.5,
|
266 |
+
}
|
267 |
+
|
268 |
+
return do_img2img(
|
269 |
+
image,
|
270 |
+
self.model,
|
271 |
+
sampler,
|
272 |
+
value_dict,
|
273 |
+
samples,
|
274 |
+
skip_encode=True,
|
275 |
+
return_latents=return_latents,
|
276 |
+
filter=None,
|
277 |
+
)
|
278 |
+
|
279 |
+
|
280 |
+
def get_guider_config(params: SamplingParams):
|
281 |
+
if params.guider == Guider.IDENTITY:
|
282 |
+
guider_config = {
|
283 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
284 |
+
}
|
285 |
+
elif params.guider == Guider.VANILLA:
|
286 |
+
scale = params.scale
|
287 |
+
|
288 |
+
thresholder = params.thresholder
|
289 |
+
|
290 |
+
if thresholder == Thresholder.NONE:
|
291 |
+
dyn_thresh_config = {
|
292 |
+
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
293 |
+
}
|
294 |
+
else:
|
295 |
+
raise NotImplementedError
|
296 |
+
|
297 |
+
guider_config = {
|
298 |
+
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
299 |
+
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
300 |
+
}
|
301 |
+
else:
|
302 |
+
raise NotImplementedError
|
303 |
+
return guider_config
|
304 |
+
|
305 |
+
|
306 |
+
def get_discretization_config(params: SamplingParams):
|
307 |
+
if params.discretization == Discretization.LEGACY_DDPM:
|
308 |
+
discretization_config = {
|
309 |
+
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
310 |
+
}
|
311 |
+
elif params.discretization == Discretization.EDM:
|
312 |
+
discretization_config = {
|
313 |
+
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
314 |
+
"params": {
|
315 |
+
"sigma_min": params.sigma_min,
|
316 |
+
"sigma_max": params.sigma_max,
|
317 |
+
"rho": params.rho,
|
318 |
+
},
|
319 |
+
}
|
320 |
+
else:
|
321 |
+
raise ValueError(f"unknown discretization {params.discretization}")
|
322 |
+
return discretization_config
|
323 |
+
|
324 |
+
|
325 |
+
def get_sampler_config(params: SamplingParams):
|
326 |
+
discretization_config = get_discretization_config(params)
|
327 |
+
guider_config = get_guider_config(params)
|
328 |
+
sampler = None
|
329 |
+
if params.sampler == Sampler.EULER_EDM:
|
330 |
+
return EulerEDMSampler(
|
331 |
+
num_steps=params.steps,
|
332 |
+
discretization_config=discretization_config,
|
333 |
+
guider_config=guider_config,
|
334 |
+
s_churn=params.s_churn,
|
335 |
+
s_tmin=params.s_tmin,
|
336 |
+
s_tmax=params.s_tmax,
|
337 |
+
s_noise=params.s_noise,
|
338 |
+
verbose=True,
|
339 |
+
)
|
340 |
+
if params.sampler == Sampler.HEUN_EDM:
|
341 |
+
return HeunEDMSampler(
|
342 |
+
num_steps=params.steps,
|
343 |
+
discretization_config=discretization_config,
|
344 |
+
guider_config=guider_config,
|
345 |
+
s_churn=params.s_churn,
|
346 |
+
s_tmin=params.s_tmin,
|
347 |
+
s_tmax=params.s_tmax,
|
348 |
+
s_noise=params.s_noise,
|
349 |
+
verbose=True,
|
350 |
+
)
|
351 |
+
if params.sampler == Sampler.EULER_ANCESTRAL:
|
352 |
+
return EulerAncestralSampler(
|
353 |
+
num_steps=params.steps,
|
354 |
+
discretization_config=discretization_config,
|
355 |
+
guider_config=guider_config,
|
356 |
+
eta=params.eta,
|
357 |
+
s_noise=params.s_noise,
|
358 |
+
verbose=True,
|
359 |
+
)
|
360 |
+
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
361 |
+
return DPMPP2SAncestralSampler(
|
362 |
+
num_steps=params.steps,
|
363 |
+
discretization_config=discretization_config,
|
364 |
+
guider_config=guider_config,
|
365 |
+
eta=params.eta,
|
366 |
+
s_noise=params.s_noise,
|
367 |
+
verbose=True,
|
368 |
+
)
|
369 |
+
if params.sampler == Sampler.DPMPP2M:
|
370 |
+
return DPMPP2MSampler(
|
371 |
+
num_steps=params.steps,
|
372 |
+
discretization_config=discretization_config,
|
373 |
+
guider_config=guider_config,
|
374 |
+
verbose=True,
|
375 |
+
)
|
376 |
+
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
377 |
+
return LinearMultistepSampler(
|
378 |
+
num_steps=params.steps,
|
379 |
+
discretization_config=discretization_config,
|
380 |
+
guider_config=guider_config,
|
381 |
+
order=params.order,
|
382 |
+
verbose=True,
|
383 |
+
)
|
384 |
+
|
385 |
+
raise ValueError(f"unknown sampler {params.sampler}!")
|
sgm/inference/helpers.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from imwatermark import WatermarkEncoder
|
9 |
+
from omegaconf import ListConfig
|
10 |
+
from PIL import Image
|
11 |
+
from torch import autocast
|
12 |
+
|
13 |
+
from sgm.util import append_dims
|
14 |
+
|
15 |
+
|
16 |
+
class WatermarkEmbedder:
|
17 |
+
def __init__(self, watermark):
|
18 |
+
self.watermark = watermark
|
19 |
+
self.num_bits = len(WATERMARK_BITS)
|
20 |
+
self.encoder = WatermarkEncoder()
|
21 |
+
self.encoder.set_watermark("bits", self.watermark)
|
22 |
+
|
23 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Adds a predefined watermark to the input image
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
same as input but watermarked
|
32 |
+
"""
|
33 |
+
squeeze = len(image.shape) == 4
|
34 |
+
if squeeze:
|
35 |
+
image = image[None, ...]
|
36 |
+
n = image.shape[0]
|
37 |
+
image_np = rearrange(
|
38 |
+
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
+
).numpy()[:, :, :, ::-1]
|
40 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
+
# watermarking libary expects input as cv2 BGR format
|
42 |
+
for k in range(image_np.shape[0]):
|
43 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
+
image = torch.from_numpy(
|
45 |
+
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
+
).to(image.device)
|
47 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
+
if squeeze:
|
49 |
+
image = image[0]
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
# A fixed 48-bit message that was choosen at random
|
54 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
+
|
60 |
+
|
61 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
+
return list({x.input_key for x in conditioner.embedders})
|
63 |
+
|
64 |
+
|
65 |
+
def perform_save_locally(save_path, samples):
|
66 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
+
samples = embed_watermark(samples)
|
69 |
+
for sample in samples:
|
70 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
+
)
|
74 |
+
base_count += 1
|
75 |
+
|
76 |
+
|
77 |
+
class Img2ImgDiscretizationWrapper:
|
78 |
+
"""
|
79 |
+
wraps a discretizer, and prunes the sigmas
|
80 |
+
params:
|
81 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, discretization, strength: float = 1.0):
|
85 |
+
self.discretization = discretization
|
86 |
+
self.strength = strength
|
87 |
+
assert 0.0 <= self.strength <= 1.0
|
88 |
+
|
89 |
+
def __call__(self, *args, **kwargs):
|
90 |
+
# sigmas start large first, and decrease then
|
91 |
+
sigmas = self.discretization(*args, **kwargs)
|
92 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
+
sigmas = torch.flip(sigmas, (0,))
|
94 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
+
sigmas = torch.flip(sigmas, (0,))
|
97 |
+
print(f"sigmas after pruning: ", sigmas)
|
98 |
+
return sigmas
|
99 |
+
|
100 |
+
|
101 |
+
def do_sample(
|
102 |
+
model,
|
103 |
+
sampler,
|
104 |
+
value_dict,
|
105 |
+
num_samples,
|
106 |
+
H,
|
107 |
+
W,
|
108 |
+
C,
|
109 |
+
F,
|
110 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
+
batch2model_input: Optional[List] = None,
|
112 |
+
return_latents=False,
|
113 |
+
filter=None,
|
114 |
+
device="cuda",
|
115 |
+
):
|
116 |
+
if force_uc_zero_embeddings is None:
|
117 |
+
force_uc_zero_embeddings = []
|
118 |
+
if batch2model_input is None:
|
119 |
+
batch2model_input = []
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
with autocast(device) as precision_scope:
|
123 |
+
with model.ema_scope():
|
124 |
+
num_samples = [num_samples]
|
125 |
+
batch, batch_uc = get_batch(
|
126 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
+
value_dict,
|
128 |
+
num_samples,
|
129 |
+
)
|
130 |
+
for key in batch:
|
131 |
+
if isinstance(batch[key], torch.Tensor):
|
132 |
+
print(key, batch[key].shape)
|
133 |
+
elif isinstance(batch[key], list):
|
134 |
+
print(key, [len(l) for l in batch[key]])
|
135 |
+
else:
|
136 |
+
print(key, batch[key])
|
137 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
+
batch,
|
139 |
+
batch_uc=batch_uc,
|
140 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
+
)
|
142 |
+
|
143 |
+
for k in c:
|
144 |
+
if not k == "crossattn":
|
145 |
+
c[k], uc[k] = map(
|
146 |
+
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
+
)
|
148 |
+
|
149 |
+
additional_model_inputs = {}
|
150 |
+
for k in batch2model_input:
|
151 |
+
additional_model_inputs[k] = batch[k]
|
152 |
+
|
153 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
+
randn = torch.randn(shape).to(device)
|
155 |
+
|
156 |
+
def denoiser(input, sigma, c):
|
157 |
+
return model.denoiser(
|
158 |
+
model.model, input, sigma, c, **additional_model_inputs
|
159 |
+
)
|
160 |
+
|
161 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
+
samples_x = model.decode_first_stage(samples_z)
|
163 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
|
165 |
+
if filter is not None:
|
166 |
+
samples = filter(samples)
|
167 |
+
|
168 |
+
if return_latents:
|
169 |
+
return samples, samples_z
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
+
|
176 |
+
batch = {}
|
177 |
+
batch_uc = {}
|
178 |
+
|
179 |
+
for key in keys:
|
180 |
+
if key == "txt":
|
181 |
+
batch["txt"] = (
|
182 |
+
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
+
.reshape(N)
|
184 |
+
.tolist()
|
185 |
+
)
|
186 |
+
batch_uc["txt"] = (
|
187 |
+
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
+
.reshape(N)
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
elif key == "original_size_as_tuple":
|
192 |
+
batch["original_size_as_tuple"] = (
|
193 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
+
.to(device)
|
195 |
+
.repeat(*N, 1)
|
196 |
+
)
|
197 |
+
elif key == "crop_coords_top_left":
|
198 |
+
batch["crop_coords_top_left"] = (
|
199 |
+
torch.tensor(
|
200 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
+
)
|
202 |
+
.to(device)
|
203 |
+
.repeat(*N, 1)
|
204 |
+
)
|
205 |
+
elif key == "aesthetic_score":
|
206 |
+
batch["aesthetic_score"] = (
|
207 |
+
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
+
)
|
209 |
+
batch_uc["aesthetic_score"] = (
|
210 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
+
.to(device)
|
212 |
+
.repeat(*N, 1)
|
213 |
+
)
|
214 |
+
|
215 |
+
elif key == "target_size_as_tuple":
|
216 |
+
batch["target_size_as_tuple"] = (
|
217 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(*N, 1)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
batch[key] = value_dict[key]
|
223 |
+
|
224 |
+
for key in batch.keys():
|
225 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
+
batch_uc[key] = torch.clone(batch[key])
|
227 |
+
return batch, batch_uc
|
228 |
+
|
229 |
+
|
230 |
+
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
+
w, h = image.size
|
232 |
+
print(f"loaded input image of size ({w}, {h})")
|
233 |
+
width, height = map(
|
234 |
+
lambda x: x - x % 64, (w, h)
|
235 |
+
) # resize to integer multiple of 64
|
236 |
+
image = image.resize((width, height))
|
237 |
+
image_array = np.array(image.convert("RGB"))
|
238 |
+
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
+
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
+
return image_tensor.to(device)
|
241 |
+
|
242 |
+
|
243 |
+
def do_img2img(
|
244 |
+
img,
|
245 |
+
model,
|
246 |
+
sampler,
|
247 |
+
value_dict,
|
248 |
+
num_samples,
|
249 |
+
force_uc_zero_embeddings=[],
|
250 |
+
additional_kwargs={},
|
251 |
+
offset_noise_level: float = 0.0,
|
252 |
+
return_latents=False,
|
253 |
+
skip_encode=False,
|
254 |
+
filter=None,
|
255 |
+
device="cuda",
|
256 |
+
):
|
257 |
+
with torch.no_grad():
|
258 |
+
with autocast(device) as precision_scope:
|
259 |
+
with model.ema_scope():
|
260 |
+
batch, batch_uc = get_batch(
|
261 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
+
value_dict,
|
263 |
+
[num_samples],
|
264 |
+
)
|
265 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
+
batch,
|
267 |
+
batch_uc=batch_uc,
|
268 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
+
)
|
270 |
+
|
271 |
+
for k in c:
|
272 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
+
|
274 |
+
for k in additional_kwargs:
|
275 |
+
c[k] = uc[k] = additional_kwargs[k]
|
276 |
+
if skip_encode:
|
277 |
+
z = img
|
278 |
+
else:
|
279 |
+
z = model.encode_first_stage(img)
|
280 |
+
noise = torch.randn_like(z)
|
281 |
+
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
+
sigma = sigmas[0].to(z.device)
|
283 |
+
|
284 |
+
if offset_noise_level > 0.0:
|
285 |
+
noise = noise + offset_noise_level * append_dims(
|
286 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
+
)
|
288 |
+
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
+
noised_z = noised_z / torch.sqrt(
|
290 |
+
1.0 + sigmas[0] ** 2.0
|
291 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
+
|
293 |
+
def denoiser(x, sigma, c):
|
294 |
+
return model.denoiser(model.model, x, sigma, c)
|
295 |
+
|
296 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
+
samples_x = model.decode_first_stage(samples_z)
|
298 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
|
300 |
+
if filter is not None:
|
301 |
+
samples = filter(samples)
|
302 |
+
|
303 |
+
if return_latents:
|
304 |
+
return samples, samples_z
|
305 |
+
return samples
|