denoise model update

Files changed:
- app.py +215 -3
- draco.yaml +19 -0
- draco/__init__.py +2 -0
- draco/configuration/__init__.py +4 -0
- draco/configuration/base.yaml +57 -0
- draco/configuration/config.py +52 -0
- draco/configuration/configurable.py +138 -0
- draco/configuration/draco2d-b_triplet_pretrain.yaml +20 -0
- draco/configuration/draco2d-h_triplet_pretrain.yaml +5 -0
- draco/configuration/draco2d-l_triplet_pretrain.yaml +4 -0
- draco/model/__init__.py +6 -0
- draco/model/build.py +22 -0
- draco/model/checkpoint.py +61 -0
- draco/model/draco2d.py +663 -0
- draco/model/draco_base.py +35 -0
- draco/model/layer/__init__.py +3 -0
- draco/model/layer/normalization.py +22 -0
- draco/model/utils/constant.py +24 -0
- requirements.txt +11 -0
app.py CHANGED
@@ -1,7 +1,219 @@
 import gradio as gr
+import h5py
+import mrcfile
+import numpy as np
+from PIL import Image
+from omegaconf import DictConfig
+import torch
+from pathlib import Path
+from torchvision.transforms import functional as F
+import torchvision.transforms.v2 as v2
 
-def greet(name):
-    return "Hello " + name + "!!"
 
+from draco.configuration import CfgNode
+from draco.model import (
+    build_model,
+    load_pretrained
+)
+
+class DRACODenoiser(object):
+    def __init__(self,
+        cfg: DictConfig,
+        ckpt_path: Path,
+    ) -> None:
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.transform = self.build_transform()
+        self.model = build_model(cfg).to(self.device).eval()
+        self.model = load_pretrained(self.model, ckpt_path, self.device)
+        self.patch_size = cfg.MODEL.PATCH_SIZE
+
+    def patchify(self, image: torch.Tensor) -> torch.Tensor:
+        B, C, H, W = image.shape
+        P = self.patch_size
+        if H % P != 0 or W % P != 0:
+            image = torch.nn.functional.pad(image, (0, (P - W % P) % P, 0, (P - H % P) % P), mode='constant', value=0)
+
+        patches = image.unfold(2, P, P).unfold(3, P, P)
+        patches = patches.permute(0, 2, 3, 4, 5, 1)
+        patches = patches.reshape(B, -1, P * P * C)
+        return patches
+
+    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        B = patches.shape[0]
+        P = self.patch_size
+
+        images = patches.reshape(B, (H + P - 1) // P, (W + P - 1) // P, P, P, -1)
+        images = images.permute(0, 5, 1, 3, 2, 4)
+        images = images.reshape(B, -1, (H + P - 1) // P * P, (W + P - 1) // P * P)
+        images = images[..., :H, :W]
+        return images
+
+    @classmethod
+    def build_transform(cls) -> v2.Compose:
+        return v2.Compose([
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True)
+        ])
+
+    @torch.inference_mode()
+    def inference(self, image: Image.Image) -> np.ndarray:
+        W, H = image.size
+
+        x = self.transform(image).unsqueeze(0).to(self.device)
+        y = self.model(x)
+
+        x = self.patchify(x).detach().cpu().numpy()
+        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
+
+        return denoised
+
+# Model initialization
+cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
+CfgNode.merge_with_dotlist(cfg, [])
+ckpt_path = Path("denoise.ckpt")
+denoiser = DRACODenoiser(cfg, ckpt_path)
+
+def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
+    image = (image - image.min()) / (image.max() - image.min())
+    mean = image.mean()
+    std = image.std()
+
+    f = std / t_sd
+
+    black = mean - t_mean * f
+    white = mean + (1 - t_mean) * f
+
+    new_image = np.clip(image, black, white)
+    new_image = (new_image - black) / (white - black)
+    return new_image
+
+
+def load_data(file_path) -> np.ndarray:
+    if file_path.endswith('.h5'):
+        with h5py.File(file_path, "r") as f:
+            full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
+            full_mean = full_micrograph.attrs["mean"] if "mean" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).mean()
+            full_std = full_micrograph.attrs["std"] if "std" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).std()
+            data = full_micrograph[:].astype(np.float32)
+    elif file_path.endswith('.mrc'):
+        with mrcfile.open(file_path, "r") as f:
+            data = f.data[:].astype(np.float32)
+            full_mean = data.mean()
+            full_std = data.std()
+    else:
+        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
+    data = (data - full_mean) / full_std
+    return data
+
+def display_crop(data, x_offset, y_offset, auto_contrast) -> Image.Image:
+    crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
+    original_image_normalized = Auto_contrast(crop) if auto_contrast else (crop - crop.min()) / (crop.max() - crop.min())
+    input_image = Image.fromarray((original_image_normalized * 255).astype(np.uint8))
+
+    return input_image
+
+def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> Image.Image:
+    crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
+    denoised_data = denoiser.inference(Image.fromarray(crop))
+
+    denoised_data = denoised_data.squeeze()
+    denoised_image_normalized = Auto_contrast(denoised_data) if auto_contrast else (denoised_data - denoised_data.min()) / (denoised_data.max() - denoised_data.min())
+    denoised_image = Image.fromarray((denoised_image_normalized * 255).astype(np.uint8))
+
+    return denoised_image
+
+def clear_images() -> tuple:
+    return None, None, None, gr.update(maximum=512), gr.update(maximum=512)
+
+with gr.Blocks(css="""
+    .gradio-container {
+        background-color: #f7f9fc;
+        font-family: Arial, sans-serif;
+    }
+    .title-text {
+        text-align: center;
+        font-size: 30px;
+        font-weight: bold;
+        margin-bottom: 10px;
+    }
+    .description-text {
+        text-align: center;
+        font-size: 18px;
+        margin-bottom: 20px;
+    }
+""") as demo:
+    # Centered title and description
+    with gr.Column():
+        gr.Markdown(
+            """
+            <div style="text-align: center; font-size: 30px; font-weight: bold; margin-bottom: 10px;">
+                Denoising Demo
+            </div>
+            <div style="text-align: center; font-size: 18px;">
+                Upload a raw file or select an example to view the original and denoised images
+            </div>
+            """
+        )
+
+    file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")
+    auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False)
+
+    x_slider = gr.Slider(0, 512, step=10, label="X Offset")
+    y_slider = gr.Slider(0, 512, step=10, label="Y Offset")
+
+    with gr.Row():
+        denoise_button = gr.Button("Denoise")
+        clear_button = gr.Button("Clear")
+
+    with gr.Row():
+        with gr.Column():
+            original_image = gr.Image(type="pil", label="Original Image")
+        with gr.Column():
+            denoised_image = gr.Image(type="pil", label="Denoised Image")
+
+    active_data = gr.State()
+
+    def load_image_and_update_sliders(file_path) -> tuple:
+        data = load_data(file_path)
+        h, w = data.shape[:2]
+        return data, gr.update(maximum=w-1024), gr.update(maximum=h-1024)
+
+
+    file_input.clear(
+        clear_images,
+        inputs=None,
+        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
+    )
+
+    file_input.change(
+        lambda file: load_image_and_update_sliders(file.name) if file else (None, gr.update(maximum=512), gr.update(maximum=512)),
+        inputs=file_input,
+        outputs=[active_data, x_slider, y_slider]
+    )
+
+    x_slider.change(
+        display_crop,
+        inputs=[active_data, x_slider, y_slider, auto_contrast],
+        outputs=original_image
+    )
+
+    y_slider.change(
+        display_crop,
+        inputs=[active_data, x_slider, y_slider, auto_contrast],
+        outputs=original_image
+    )
+
+    denoise_button.click(
+        process_and_denoise,
+        inputs=[active_data, x_slider, y_slider, auto_contrast],
+        outputs=denoised_image
+    )
+
+    clear_button.click(clear_images, inputs=None, outputs=[original_image, denoised_image, active_data, x_slider, y_slider])
+
 demo.launch()
+
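To exercise the same pipeline without the UI, a minimal sketch, assuming the definitions from app.py above (DRACODenoiser, load_data, Auto_contrast) are in scope, that draco.yaml and denoise.ckpt sit next to the script, and that a micrograph exists at the placeholder path example.mrc:

import numpy as np
from pathlib import Path
from PIL import Image
from draco.configuration import CfgNode

cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
denoiser = DRACODenoiser(cfg, Path("denoise.ckpt"))    # class defined in app.py above

data = load_data("example.mrc")                        # standardized float32 micrograph
crop = data[:1024, :1024]                              # the demo always denoises 1024x1024 crops
denoised = denoiser.inference(Image.fromarray(crop)).squeeze()
Image.fromarray((Auto_contrast(denoised) * 255).astype(np.uint8)).save("denoised.png")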
draco.yaml ADDED
@@ -0,0 +1,19 @@
+MODEL:
+  NAME: DracoDenoiseAutoencoder
+  DEVICE: cuda
+
+  IMG_SIZE: 1024
+  PATCH_SIZE: 32
+  IN_CHANS: 1
+  VIT_SCALE: base
+  DYNAMIC_IMG_SIZE: True
+  DYNAMIC_IMG_PAD: True
+  DECODER_EMBED_DIM: 512
+  DECODER_DEPTH: 8
+  DECODER_NUM_HEADS: 16
+  DECODER_USE_NECK: True
+  DECODER_NECK_DIM: 256
+  USE_ABS_POS: True
+  USE_DECODER_NECK: True
+  WINDOW_SIZE: 28
+  DECODER_GLOBAL_ATTN_INDEXES: [3, 7]
draco/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+import draco.model
draco/configuration/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .config import CfgNode
+from .configurable import configurable
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/configuration/base.yaml ADDED
@@ -0,0 +1,57 @@
+DATALOADER:
+  BATCH_SIZE: 0
+  NUM_WORKERS: 0
+  PIN_MEMORY: False
+  DROP_LAST: False
+  PERSISTENT_WORKERS: False
+
+DATASET:
+  NAME: null
+
+TRANSFORM:
+  NAME: null
+
+MODEL:
+  NAME: null
+  DEVICE: cuda
+
+METRIC:
+  NAME: null
+  TYPE: null
+
+MODULE:
+  NAME: null
+  COMPILE: False
+
+OPTIMIZER:
+  NAME: null
+
+SCHEDULER:
+  NAME: null
+
+TRAINER:
+  STRATEGY: auto  # Set to `auto`, `ddp`, `deepspeed_stage_2`, `deepspeed_stage_3` ...
+  MIXED_PRECISION: False
+  CHECKPOINT:
+    EVERY_N_EPOCHS: 10
+
+    SAVE_BEST: False  # If True, MONITOR will be required
+    MONITOR: null
+    MONITOR_MODE: min  # Set to `min` or `max`
+
+  MAX_EPOCHS: -1  # If the profiler is enabled, this will be *automatically* set to 1
+  LOG_EVERY_N_STEPS: 1
+  ACCUMULATE_GRAD_BATCHES: 1
+
+  CLIP_GRAD:
+    ALGORITHM: null
+    VALUE: null
+
+  DETERMINISTIC: False  # Set to True to enable cudnn.deterministic
+  BENCHMARK: False      # Set to True to enable cudnn.benchmark
+  PROFILER: null        # Set to `advanced` or `pytorch` to enable profiling
+  DETECT_ANOMALY: False # Set to True to enable anomaly detection
+  SYNC_BATCHNORM: False # Set to True to enable sync batchnorm
+
+SEED: null
+OUTPUT_DIR: null
draco/configuration/config.py ADDED
@@ -0,0 +1,52 @@
+import os.path
+from typing import Any
+
+from omegaconf import DictConfig, OmegaConf
+
+BASE_KEY = "_BASE_"
+ROOT_KEY = "cfg"
+
+class CfgNode(OmegaConf):
+    """
+    A wrapper around OmegaConf that provides some additional functionality.
+    """
+
+    @staticmethod
+    def load_yaml_with_base(filename: str) -> DictConfig:
+        cfg = OmegaConf.load(filename)
+
+        def _load_with_base(base_cfg_file: str) -> dict[str, Any]:
+            if base_cfg_file.startswith("~"):
+                base_cfg_file = os.path.expanduser(base_cfg_file)
+            if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])):
+                # The path to the base cfg is relative to the config file itself.
+                base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file)
+            return CfgNode.load_yaml_with_base(base_cfg_file)
+
+        if BASE_KEY in cfg:
+            if isinstance(cfg[BASE_KEY], list):
+                base_cfg: dict[str, Any] = {}
+                base_cfg_files = cfg[BASE_KEY]
+                for base_cfg_file in base_cfg_files:
+                    base_cfg = CfgNode.merge(base_cfg, _load_with_base(base_cfg_file))
+            else:
+                base_cfg_file = cfg[BASE_KEY]
+                base_cfg = _load_with_base(base_cfg_file)
+            del cfg[BASE_KEY]
+
+            base_cfg = CfgNode.merge(base_cfg, cfg)
+            return base_cfg
+
+        if ROOT_KEY in cfg:
+            return cfg[ROOT_KEY]
+        return cfg
+
+    @staticmethod
+    def merge_with_dotlist(cfg: DictConfig, dotlist: list[str]) -> None:
+        if len(dotlist) == 0:
+            return
+
+        new_dotlist = []
+        for key, value in zip(dotlist[::2], dotlist[1::2]):
+            new_dotlist.append(f"{key}={value}")
+        cfg.merge_with_dotlist(new_dotlist)
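A quick sketch of how the `_BASE_` inheritance behaves. The two-file layout here is invented for illustration; only CfgNode itself comes from the diff:

# Hypothetical layout (paths are illustrative, not from the repo):
#   configs/base.yaml : MODEL: {NAME: null, DEVICE: cuda}
#   configs/exp.yaml  : _BASE_: base.yaml   plus   MODEL: {NAME: MyModel}
from draco.configuration import CfgNode

cfg = CfgNode.load_yaml_with_base("configs/exp.yaml")
print(cfg.MODEL.NAME)    # MyModel -- the child file wins in the merge
print(cfg.MODEL.DEVICE)  # cuda    -- inherited from base.yaml

# merge_with_dotlist pairs up a flat ["KEY", "value", ...] list into
# key=value overrides, which is why app.py can pass [] as a no-op:
CfgNode.merge_with_dotlist(cfg, ["MODEL.DEVICE", "cpu"])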
draco/configuration/configurable.py ADDED
@@ -0,0 +1,138 @@
+import functools
+import inspect
+from typing import Any, Callable
+
+from omegaconf import DictConfig
+
+__all__ = ["configurable"]
+
+
+def _called_with_cfg(*args, **kwargs) -> bool:
+    """
+    Check if the function is called with a `DictConfig` as the first argument,
+    or with a `cfg` keyword argument that is a `DictConfig`.
+
+    Returns:
+        (bool): whether either form of `DictConfig` call was used.
+    """
+
+    if len(args) > 0 and isinstance(args[0], DictConfig):
+        return True
+    if isinstance(kwargs.get("cfg", None), DictConfig):
+        return True
+    return False
+
+
+def _get_args_from_cfg(from_config_func: Callable[[Any], dict[str, Any]], *args, **kwargs) -> dict[str, Any]:
+    """
+    Get the input arguments of the decorated function from a `DictConfig` object.
+
+    Returns:
+        (dict): The input arguments of the class `__init__` method.
+    """
+
+    signature = inspect.signature(from_config_func)
+    if list(signature.parameters.keys())[0] != "cfg":
+        raise ValueError("The first argument of `{}` must be named `cfg`.".format(from_config_func.__name__))
+
+    # Forward all arguments to `from_config` if it accepts `*args` or `**kwargs`.
+    if any(param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD) for param in signature.parameters.values()):
+        result = from_config_func(*args, **kwargs)
+
+    # Otherwise, split off keyword arguments that `from_config` does not accept.
+    else:
+        positional_args_name = set(signature.parameters.keys())
+        extra_kwargs = {}
+        for name in list(kwargs.keys()):
+            if name not in positional_args_name:
+                extra_kwargs[name] = kwargs.pop(name)
+        result = from_config_func(*args, **kwargs)
+        # These args are forwarded directly to the `__init__` method.
+        result.update(extra_kwargs)
+
+    return result
+
+
+def configurable(init_func: Callable = None, *, from_config: Callable[[Any], dict[str, Any]] | None = None) -> Callable:
+    """
+    A decorator for a function or a class `__init__` method,
+    to make it configurable by a `DictConfig` object.
+
+    Example:
+    ```python
+    # 1. Decorate a function.
+    @configurable(from_config=lambda cfg: { "x": cfg.x })
+    def func(x, y=2, z=3):
+        pass
+
+    a1 = func(x=1, y=2)       # Call with regular args.
+    a2 = func(cfg)            # Call with a `DictConfig` object.
+    a3 = func(cfg, y=2, z=3)  # Call with a `DictConfig` object and regular arguments.
+
+    # 2. Decorate a class `__init__` method.
+    class A:
+        @configurable
+        def __init__(self, *args, **kwargs) -> None:
+            pass
+
+        @classmethod
+        def from_config(cls, cfg) -> dict:
+            pass
+
+    a1 = A(x, y)       # Call with the regular constructor.
+    a2 = A(cfg)        # Call with a `DictConfig` object.
+    a3 = A(cfg, x, y)  # Call with a `DictConfig` object and regular arguments.
+    ```
+
+    Args:
+        `init_func` (callable): a function or a class `__init__` method.
+        `from_config` (callable): a function that converts a `DictConfig` to the
+            input arguments of the decorated function. It is required when
+            decorating a plain function; when decorating `__init__`, the class
+            provides a `from_config` classmethod instead.
+    """
+
+    # Decorating a function
+    if init_func is None:
+        # Prevent common misuse: `@configurable()`.
+        if from_config is None:
+            return configurable
+
+        assert inspect.isfunction(from_config), "`from_config` must be a function."
+
+        def wrapper(func):
+            @functools.wraps(func)
+            def wrapped(*args, **kwargs):
+                if _called_with_cfg(*args, **kwargs):
+                    explicit_args = _get_args_from_cfg(from_config, *args, **kwargs)
+                    return func(**explicit_args)
+                else:
+                    return func(*args, **kwargs)
+
+            wrapped.from_config = from_config
+            return wrapped
+
+        return wrapper
+
+    # Decorating a class `__init__` method
+    else:
+        assert (
+            inspect.isfunction(init_func) and from_config is None and init_func.__name__ == "__init__"
+        ), "Invalid usage of @configurable."
+
+        @functools.wraps(init_func)
+        def wrapped(self, *args, **kwargs):
+            try:
+                from_config_func = getattr(self, "from_config")
+            except AttributeError as e:
+                raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.") from e
+
+            if not inspect.ismethod(from_config_func):
+                raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.")
+
+            if _called_with_cfg(*args, **kwargs):
+                explicit_args = _get_args_from_cfg(from_config_func, *args, **kwargs)
+                init_func(self, **explicit_args)
+            else:
+                init_func(self, *args, **kwargs)
+
+        return wrapped
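As a sanity check of the two call paths, a minimal hypothetical class (Scaler is invented for illustration and is not part of the commit):

from omegaconf import OmegaConf
from draco.configuration import configurable

class Scaler:
    @configurable
    def __init__(self, factor: int = 1) -> None:
        self.factor = factor

    @classmethod
    def from_config(cls, cfg):  # maps the config tree onto __init__ kwargs
        return {"factor": cfg.FACTOR}

cfg = OmegaConf.create({"FACTOR": 3})
assert Scaler(cfg).factor == 3       # routed through from_config
assert Scaler(factor=5).factor == 5  # plain constructor path is untouched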
draco/configuration/draco2d-b_triplet_pretrain.yaml ADDED
@@ -0,0 +1,20 @@
+_BASE_: base.yaml
+
+MODEL:
+  NAME: DenoisingReconstructionAutoencoderVisionTransformer2d
+
+  IMG_SIZE: 256
+  PATCH_SIZE: 16
+  IN_CHANS: 1
+  VIT_SCALE: base
+  DYNAMIC_IMG_SIZE: False
+  DYNAMIC_IMG_PAD: False
+  USE_ABS_POS: True
+  DECODER_EMBED_DIM: 512
+  DECODER_DEPTH: 8
+  DECODER_NUM_HEADS: 16
+  DECODER_USE_NECK: True
+  DECODER_NECK_DIM: 256
+
+SEED: 0
+OUTPUT_DIR: null
draco/configuration/draco2d-h_triplet_pretrain.yaml ADDED
@@ -0,0 +1,5 @@
+_BASE_: draco-b_imagenet_pretrain.yaml
+
+MODEL:
+  PATCH_SIZE: 14
+  VIT_SCALE: huge
draco/configuration/draco2d-l_triplet_pretrain.yaml ADDED
@@ -0,0 +1,4 @@
+_BASE_: draco-b_imagenet_pretrain.yaml
+
+MODEL:
+  VIT_SCALE: large
draco/model/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .build import MODEL_REGISTRY, build_model
+from .checkpoint import load_pretrained
+
+from .draco2d import DenoisingReconstructionAutoencoderVisionTransformer2d, DracoDenoiseAutoencoder
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/model/build.py ADDED
@@ -0,0 +1,22 @@
+from fvcore.common.registry import Registry
+from omegaconf import DictConfig
+import torch
+
+__all__ = ["MODEL_REGISTRY", "build_model"]
+
+
+MODEL_REGISTRY = Registry("MODEL")
+MODEL_REGISTRY.__doc__ = "Registry for the model."
+
+def build_model(cfg: DictConfig) -> torch.nn.Module:
+    """
+    Build the model defined by `cfg.MODEL.NAME`.
+    It does not move the model to `cfg.MODEL.DEVICE`
+    and does not load checkpoints from `cfg`; both are left to the caller.
+    """
+    model_name = cfg.MODEL.NAME
+    try:
+        model = MODEL_REGISTRY.get(model_name)(cfg)
+    except KeyError as e:
+        raise KeyError(MODEL_REGISTRY) from e
+    return model
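For orientation, a tiny registration round trip. TinyModel is a made-up stand-in, not part of the commit; the pattern mirrors how draco2d.py registers its two classes:

import torch.nn as nn
from omegaconf import OmegaConf
from draco.model import MODEL_REGISTRY, build_model

@MODEL_REGISTRY.register()  # registered under the class name "TinyModel"
class TinyModel(nn.Module):
    def __init__(self, cfg):  # build_model hands the whole cfg to the constructor
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

cfg = OmegaConf.create({"MODEL": {"NAME": "TinyModel"}})
model = build_model(cfg)  # cfg.MODEL.NAME selects the registered class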
draco/model/checkpoint.py ADDED
@@ -0,0 +1,61 @@
+from pathlib import Path
+from typing import Any
+
+import torch
+
+
+def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
+    """
+    Strip `prefix` from all keys of the state dict (and its metadata), in place.
+    Does nothing unless every non-empty key carries the prefix.
+
+    Args:
+        state_dict (OrderedDict): a state dict to be loaded into the model.
+        prefix (str): the prefix to strip.
+    """
+    keys = sorted(state_dict.keys())
+    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
+        return
+
+    for key in keys:
+        newkey = key[len(prefix):]
+        state_dict[newkey] = state_dict.pop(key)
+
+    # Also strip the prefix in the metadata, if any.
+    try:
+        metadata = state_dict._metadata  # pyre-ignore
+    except AttributeError:
+        pass
+    else:
+        for key in list(metadata.keys()):
+            # For the metadata dict, the key can be:
+            # '': for the DDP module, which we want to remove.
+            # 'module': for the actual model.
+            # 'module.xx.xx': for the rest.
+
+            if len(key) == 0:
+                continue
+            newkey = key[len(prefix):]
+            metadata[newkey] = metadata.pop(key)
+
+
+def load_pretrained(model: torch.nn.Module, ckpt_path: Path, device: torch.device | str = "cuda") -> torch.nn.Module:
+    """
+    Load the pre-trained model from the checkpoint file.
+    """
+    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
+
+    if "state_dict" in ckpt:
+        checkpoint_state_dict = ckpt["state_dict"]
+    elif "model" in ckpt:
+        checkpoint_state_dict = ckpt["model"]
+    else:
+        checkpoint_state_dict = ckpt
+
+    _strip_prefix_if_present(checkpoint_state_dict, "module.")     # for DistributedDataParallel
+    _strip_prefix_if_present(checkpoint_state_dict, "model.")      # for a PyTorch Lightning module
+    _strip_prefix_if_present(checkpoint_state_dict, "_orig_mod.")  # for torch.compile
+
+    msg = model.load_state_dict(checkpoint_state_dict, strict=False)
+    print(f"Loaded pre-trained model from {ckpt_path} with message: {msg}")
+
+    return model
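A small hypothetical round trip showing why the prefix stripping matters (the file name and module are invented):

import torch
import torch.nn as nn
from draco.model import load_pretrained

# Save a DataParallel-style checkpoint where every key carries "module.",
# then reload it into a bare module; the prefix is stripped because the
# whole state dict shares it.
net = nn.Linear(4, 2)
wrapped_sd = {f"module.{k}": v for k, v in net.state_dict().items()}
torch.save({"state_dict": wrapped_sd}, "tiny.ckpt")

reloaded = load_pretrained(nn.Linear(4, 2), "tiny.ckpt", device="cpu")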
draco/model/draco2d.py ADDED
@@ -0,0 +1,663 @@
+from functools import partial
+from typing import Any, Callable
+
+from omegaconf import DictConfig
+from timm.layers import build_sincos2d_pos_embed, resample_abs_pos_embed_nhwc, PatchEmbed, Mlp, LayerType
+from timm.models.vision_transformer import Block
+from timm.models.vision_transformer_sam import Block as SAMBlock
+import torch
+import torch.nn as nn
+
+from draco.configuration import configurable
+from .build import MODEL_REGISTRY
+from .layer import LayerNorm2d
+from .draco_base import DenoisingReconstructionAutoencoderVisionTransformerBase
+from .utils.constant import get_vit_scale, get_global_attn_indexes
+
+__all__ = ["DenoisingReconstructionAutoencoderVisionTransformer2d", "DracoDenoiseAutoencoder"]
+
+
+@MODEL_REGISTRY.register()
+class DenoisingReconstructionAutoencoderVisionTransformer2d(DenoisingReconstructionAutoencoderVisionTransformerBase):
+    @configurable
+    def __init__(self, *,
+        img_size: int = 224,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_layer: Callable = PatchEmbed,
+        dynamic_img_size: bool = False,
+        dynamic_img_pad: bool = False,
+        use_abs_pos: bool = True,
+        block_fn: nn.Module = Block,
+        norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
+        act_layer: LayerType = nn.GELU,
+        mlp_layer: nn.Module = Mlp,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        decoder_block_fn: nn.Module = Block,
+        decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
+        decoder_act_layer: LayerType = nn.GELU,
+        decoder_mlp_layer: nn.Module = Mlp,
+        decoder_embed_dim: int = 512,
+        decoder_depth: int = 8,
+        decoder_num_heads: int = 16,
+        decoder_use_neck: bool = True,
+        decoder_neck_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.dynamic_img_size = dynamic_img_size
+        self.decoder_use_neck = decoder_use_neck
+
+        self.init_encoder(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_layer=embed_layer,
+            dynamic_img_size=dynamic_img_size,
+            dynamic_img_pad=dynamic_img_pad,
+            use_abs_pos=use_abs_pos,
+            block_fn=block_fn,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            mlp_layer=mlp_layer,
+            embed_dim=embed_dim,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+        )
+        self.init_decoder(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            use_abs_pos=use_abs_pos,
+            decoder_block_fn=decoder_block_fn,
+            decoder_norm_layer=decoder_norm_layer,
+            decoder_act_layer=decoder_act_layer,
+            decoder_mlp_layer=decoder_mlp_layer,
+            decoder_embed_dim=decoder_embed_dim,
+            decoder_depth=decoder_depth,
+            decoder_num_heads=decoder_num_heads,
+            decoder_use_neck=decoder_use_neck,
+            decoder_neck_dim=decoder_neck_dim,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+        )
+        self.init_weights(
+            grid_size=self.patch_embed.grid_size,
+            embed_dim=embed_dim,
+            decoder_embed_dim=decoder_embed_dim,
+        )
+
+    @classmethod
+    def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
+        embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
+        return {
+            "img_size": cfg.MODEL.IMG_SIZE,
+            "patch_size": cfg.MODEL.PATCH_SIZE,
+            "in_chans": cfg.MODEL.IN_CHANS,
+            "dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
+            "dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
+            "use_abs_pos": cfg.MODEL.USE_ABS_POS,
+            "embed_dim": embed_dim,
+            "depth": depth,
+            "num_heads": num_heads,
+            "decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
+            "decoder_depth": cfg.MODEL.DECODER_DEPTH,
+            "decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
+            "decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
+            "decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
+        }
+
+    def init_encoder(self, *,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_layer: Callable,
+        dynamic_img_size: bool,
+        dynamic_img_pad: bool,
+        use_abs_pos: bool,
+        block_fn: nn.Module,
+        norm_layer: LayerType | None,
+        act_layer: LayerType | None,
+        mlp_layer: nn.Module,
+        embed_dim: int,
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        qkv_bias: bool,
+        qk_norm: bool,
+    ) -> None:
+        embed_args = {}
+        if dynamic_img_size:
+            embed_args.update(dict(strict_img_size=False))
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
+            output_fmt="NHWC",
+            **embed_args
+        )
+
+        self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
+        self.blocks = nn.ModuleList([
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                mlp_layer=mlp_layer,
+            ) for _ in range(depth)
+        ])
+        self.norm = norm_layer(embed_dim)
+
+    def init_decoder(self, *,
+        patch_size: int,
+        in_chans: int,
+        embed_dim: int,
+        use_abs_pos: bool,
+        decoder_block_fn: nn.Module,
+        decoder_norm_layer: LayerType | None,
+        decoder_act_layer: LayerType | None,
+        decoder_mlp_layer: nn.Module,
+        decoder_embed_dim: int,
+        decoder_depth: int,
+        decoder_num_heads: int,
+        decoder_use_neck: bool,
+        decoder_neck_dim: int,
+        mlp_ratio: float,
+        qkv_bias: bool,
+        qk_norm: bool,
+    ) -> None:
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
+        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
+        self.decoder_blocks = nn.ModuleList([
+            decoder_block_fn(
+                dim=decoder_embed_dim,
+                num_heads=decoder_num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                norm_layer=decoder_norm_layer,
+                act_layer=decoder_act_layer,
+                mlp_layer=decoder_mlp_layer,
+            ) for _ in range(decoder_depth)
+        ])
+        self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
+        if decoder_use_neck:
+            self.decoder_neck = nn.Sequential(
+                nn.Conv2d(
+                    in_channels=decoder_embed_dim,
+                    out_channels=decoder_neck_dim,
+                    kernel_size=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_neck_dim),
+                decoder_act_layer(),
+                nn.Conv2d(
+                    in_channels=decoder_neck_dim,
+                    out_channels=decoder_neck_dim,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_neck_dim),
+                decoder_act_layer(),
+                nn.Conv2d(
+                    in_channels=decoder_neck_dim,
+                    out_channels=decoder_embed_dim,
+                    kernel_size=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_embed_dim),
+            )
+        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
+
+    def init_weights(self, *,
+        grid_size: tuple[int, int],
+        embed_dim: int,
+        decoder_embed_dim: int
+    ) -> None:
+        w = self.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
+
+        torch.nn.init.normal_(self.mask_token, std=0.02)
+
+        if self.pos_embed is not None:
+            self.pos_embed.data.copy_(build_sincos2d_pos_embed(
+                feat_shape=grid_size,
+                dim=embed_dim,
+                interleave_sin_cos=True
+            ).reshape(1, *grid_size, -1).transpose(1, 2))
+
+        if self.decoder_pos_embed is not None:
+            self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
+                feat_shape=grid_size,
+                dim=decoder_embed_dim,
+                interleave_sin_cos=True
+            ).reshape(1, *grid_size, -1).transpose(1, 2))
+
+        if self.decoder_use_neck:
+            for m in self.decoder_neck.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                    if m.bias is not None:
+                        nn.init.zeros_(m.bias)
+            nn.init.zeros_(self.decoder_neck[-1].weight)
+            nn.init.zeros_(self.decoder_neck[-1].bias)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module: nn.Module) -> None:
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.constant_(module.bias, 0.0)
+
+    def forward_encoder(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor, int, int]:
+        x = self.patch_embed(x)
+        B, H, W, E = x.shape
+        if self.pos_embed is not None:
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
+        x = x.view(B, -1, E)
+
+        mask = super().random_masking(x, mask_ratio)
+        x = x[~mask].reshape(B, -1, E)
+
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm(x)
+
+        return x, mask, H, W
+
+    def forward_decoder(self, x: torch.Tensor, mask: torch.BoolTensor, H: int, W: int) -> torch.Tensor:
+        x = self.decoder_embed(x)
+
+        B, L = mask.shape
+        E = x.shape[-1]
+        mask_tokens = self.mask_token.repeat(B, L, 1).to(x.dtype)
+        mask_tokens[~mask] = x.reshape(-1, E)
+        x = mask_tokens
+
+        if self.decoder_pos_embed is not None:
+            x = x.view(B, H, W, E)
+            x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
+            x = x.view(B, -1, E)
+
+        for block in self.decoder_blocks:
+            x = block(x)
+        x = self.decoder_norm(x)
+        if self.decoder_use_neck:
+            x = x + self.decoder_neck(
+                x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
+            ).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
+        x = self.decoder_pred(x)
+
+        return x
+
+    def forward(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor]:
+        x, mask, H, W = self.forward_encoder(x, mask_ratio)
+        x = self.forward_decoder(x, mask, H, W)
+        return x, mask
+
+
+@MODEL_REGISTRY.register()
+class DracoDenoiseAutoencoder(DenoisingReconstructionAutoencoderVisionTransformerBase):
+    """
+    Masked Autoencoder (MAE) with Vision Transformer backbone.
+    Note that `cls_token` is discarded.
+    """
+
+    @configurable
+    def __init__(self, *,
+        img_size: int = 224,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_layer: Callable = PatchEmbed,
+        dynamic_img_size: bool = False,
+        dynamic_img_pad: bool = False,
+        use_abs_pos: bool = True,
+        block_fn: nn.Module = SAMBlock,
+        norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
+        act_layer: LayerType = nn.GELU,
+        mlp_layer: nn.Module = Mlp,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        window_size: int = 16,
+        global_attn_indexes: list[int] = [2, 5, 8, 11],
+        decoder_block_fn: nn.Module = SAMBlock,
+        decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
+        decoder_act_layer: LayerType = nn.GELU,
+        decoder_mlp_layer: nn.Module = Mlp,
+        decoder_embed_dim: int = 512,
+        decoder_depth: int = 8,
+        decoder_num_heads: int = 16,
+        decoder_use_neck: bool = True,
+        decoder_neck_dim: int = 256,
+        decoder_global_attn_indexes: list[int] = [3, 7],
+    ) -> None:
+        super().__init__()
+
+        self.dynamic_img_size = dynamic_img_size
+        self.decoder_use_neck = decoder_use_neck
+
+        self.init_encoder(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_layer=embed_layer,
+            dynamic_img_size=dynamic_img_size,
+            dynamic_img_pad=dynamic_img_pad,
+            use_abs_pos=use_abs_pos,
+            block_fn=block_fn,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            mlp_layer=mlp_layer,
+            embed_dim=embed_dim,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            window_size=window_size,
+            global_attn_indexes=global_attn_indexes
+        )
+        self.init_decoder(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            use_abs_pos=use_abs_pos,
+            decoder_block_fn=decoder_block_fn,
+            decoder_norm_layer=decoder_norm_layer,
+            decoder_act_layer=decoder_act_layer,
+            decoder_mlp_layer=decoder_mlp_layer,
+            decoder_embed_dim=decoder_embed_dim,
+            decoder_depth=decoder_depth,
+            decoder_num_heads=decoder_num_heads,
+            decoder_use_neck=decoder_use_neck,
+            decoder_neck_dim=decoder_neck_dim,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            window_size=window_size,
+            decoder_global_attn_indexes=decoder_global_attn_indexes
+        )
+        self.init_weights(
+            grid_size=self.patch_embed.grid_size,
+            embed_dim=embed_dim,
+            decoder_embed_dim=decoder_embed_dim,
+        )
+
+    @classmethod
+    def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
+        embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
+        global_attn_indexes = get_global_attn_indexes(depth)
+        return {
+            "img_size": cfg.MODEL.IMG_SIZE,
+            "patch_size": cfg.MODEL.PATCH_SIZE,
+            "in_chans": cfg.MODEL.IN_CHANS,
+            "dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
+            "dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
+            "use_abs_pos": cfg.MODEL.USE_ABS_POS,
+            "embed_dim": embed_dim,
+            "depth": depth,
+            "num_heads": num_heads,
+            "window_size": cfg.MODEL.WINDOW_SIZE,
+            "global_attn_indexes": global_attn_indexes,
+            "decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
+            "decoder_depth": cfg.MODEL.DECODER_DEPTH,
+            "decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
+            "decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
+            "decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
+            "decoder_global_attn_indexes": cfg.MODEL.DECODER_GLOBAL_ATTN_INDEXES,
+        }
+
+    def init_encoder(self, *,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_layer: Callable,
+        dynamic_img_size: bool,
+        dynamic_img_pad: bool,
+        use_abs_pos: bool,
+        block_fn: nn.Module,
+        norm_layer: LayerType | None,
+        act_layer: LayerType | None,
+        mlp_layer: nn.Module,
+        embed_dim: int,
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        qkv_bias: bool,
+        qk_norm: bool,
+        window_size: int,
+        global_attn_indexes: list,
+    ) -> None:
+        embed_args = {}
+        if dynamic_img_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False))
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
+            output_fmt="NHWC",
+            **embed_args
+        )
+
+        self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
+        self.blocks = nn.ModuleList(
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                mlp_layer=mlp_layer,
+                use_rel_pos=True,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            ) for i in range(depth)
+        )
+
+        self.norm = norm_layer(embed_dim)
+
+    def init_decoder(self, *,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_dim: int,
+        use_abs_pos: bool,
+        decoder_block_fn: nn.Module,
+        decoder_norm_layer: LayerType | None,
+        decoder_act_layer: LayerType | None,
+        decoder_mlp_layer: nn.Module,
+        decoder_embed_dim: int,
+        decoder_depth: int,
+        decoder_num_heads: int,
+        decoder_use_neck: bool,
+        decoder_neck_dim: int,
+        mlp_ratio: float,
+        qkv_bias: bool,
+        qk_norm: bool,
+        window_size: int,
+        decoder_global_attn_indexes: list[int]
+    ) -> None:
+        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
+        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
+        self.decoder_blocks = nn.ModuleList(
+            decoder_block_fn(
+                dim=decoder_embed_dim,
+                num_heads=decoder_num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                norm_layer=decoder_norm_layer,
+                act_layer=decoder_act_layer,
+                mlp_layer=decoder_mlp_layer,
+                use_rel_pos=True,
+                window_size=window_size if i not in decoder_global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            ) for i in range(decoder_depth)
+        )
+        self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
+        if decoder_use_neck:
+            self.decoder_neck = nn.Sequential(
+                nn.Conv2d(
+                    in_channels=decoder_embed_dim,
+                    out_channels=decoder_neck_dim,
+                    kernel_size=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_neck_dim),
+                decoder_act_layer(),
+                nn.Conv2d(
+                    in_channels=decoder_neck_dim,
+                    out_channels=decoder_neck_dim,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_neck_dim),
+                decoder_act_layer(),
+                nn.Conv2d(
+                    in_channels=decoder_neck_dim,
+                    out_channels=decoder_embed_dim,
+                    kernel_size=1,
+                    bias=False,
+                ),
+                LayerNorm2d(decoder_embed_dim),
+            )
+        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
+
+    def init_weights(self, *,
+        grid_size: tuple[int, int],
+        embed_dim: int,
+        decoder_embed_dim: int
+    ) -> None:
+        w = self.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
+
+        if self.pos_embed is not None:
+            self.pos_embed.data.copy_(build_sincos2d_pos_embed(
+                feat_shape=grid_size,
+                dim=embed_dim,
+                interleave_sin_cos=True
+            ).reshape(1, *grid_size, -1).transpose(1, 2))
+
+        if self.decoder_pos_embed is not None:
+            self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
+                feat_shape=grid_size,
+                dim=decoder_embed_dim,
+                interleave_sin_cos=True
+            ).reshape(1, *grid_size, -1).transpose(1, 2))
+
+        # Zero-initialize the neck
+        if self.decoder_use_neck:
+            for m in self.decoder_neck.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                    if m.bias is not None:
+                        nn.init.zeros_(m.bias)
+            nn.init.zeros_(self.decoder_neck[-1].weight)
+            nn.init.zeros_(self.decoder_neck[-1].bias)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module: nn.Module) -> None:
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.constant_(module.bias, 0.0)
+
+    def forward_encoder(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
+        """
+        Forward pass of the encoder.
+
+        Args:
+            `x` (torch.Tensor): Image of shape [B, C, H, W].
+
+        Returns:
+            (torch.Tensor): Encoded image of shape [B, num_kept, E].
+            (int): Height of the encoded tokens.
+            (int): Width of the encoded tokens.
+        """
+        x = self.patch_embed(x)
+        B, H, W, E = x.shape
+
+        if self.pos_embed is not None:
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
+
+        for block in self.blocks:
+            x = block(x)
+
+        x = x.view(B, -1, E)
+        x = self.norm(x)
+
+        return x, H, W
+
+    def forward_decoder(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        """
+        Forward pass of the decoder.
+
+        Args:
+            `x` (torch.Tensor): Encoded image of shape [B, num_kept, E].
+            `H` (int): Height of the encoded tokens.
+            `W` (int): Width of the encoded tokens.
+
+        Returns:
+            (torch.Tensor): Decoded image of shape [B, L, E].
+        """
+        x = self.decoder_embed(x)  # [B, num_kept, E]
+        B, L, E = x.shape
+
+        if self.decoder_pos_embed is not None:
+            x = x.view(B, H, W, E)
+            x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
+
+        for block in self.decoder_blocks:
+            x = block(x)
+        x = x.view(B, -1, E)
+
+        x = self.decoder_norm(x)
+        if self.decoder_use_neck:
+            x = x + self.decoder_neck(
+                x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
+            ).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
+        x = self.decoder_pred(x)
+
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            `x` (torch.Tensor): Image of shape [B, C, H, W].
+
+        Returns:
+            (torch.Tensor): The prediction of shape [B, L, E].
+        """
+        x, H, W = self.forward_encoder(x)
+        x = self.forward_decoder(x, H, W)
+        return x
+
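As a shape sketch, not from the repo: with the draco.yaml settings above, a 1024-px single-channel image becomes a 32x32 grid of tokens, and the model returns per-patch pixel predictions that app.py's unpatchify folds back into an image. Whether this runs as-is depends on the installed timm version matching the imports above.

import torch
from draco.configuration import CfgNode
from draco.model import build_model

cfg = CfgNode.load_yaml_with_base("draco.yaml")
model = build_model(cfg).eval()

x = torch.randn(1, 1, 1024, 1024)  # [B, C, H, W]
with torch.inference_mode():
    y = model(x)                   # [1, 1024, 1024] = [B, L, P*P*C]
# L = (1024 // 32) ** 2 = 1024 tokens, each predicting a flattened 32x32x1 patch.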
draco/model/draco_base.py ADDED
@@ -0,0 +1,35 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+
+
+class DenoisingReconstructionAutoencoderVisionTransformerBase(nn.Module, metaclass=ABCMeta):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @torch.jit.ignore
+    def no_weight_decay(self) -> set:
+        return {"cls_token"}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> dict:
+        return dict(
+            stem=r'^(?:_orig_mod\.)?cls_token|^(?:_orig_mod\.)?pos_embed|^(?:_orig_mod\.)?patch_embed',
+            blocks=[(r'^(?:_orig_mod\.)?blocks\.(\d+)', None), (r'^(?:_orig_mod\.)?norm', (99999,))]
+        )
+
+    @classmethod
+    def random_masking(cls, x: torch.Tensor, mask_ratio: float) -> torch.BoolTensor:
+        B, L = x.shape[:2]
+        num_masked = int(L * mask_ratio)
+
+        noise = torch.rand(B, L, device=x.device)
+        rank = noise.argsort(dim=1)
+        mask = rank < num_masked
+
+        return mask
+
+    @abstractmethod
+    def forward(self) -> None:
+        raise NotImplementedError
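A quick check of the masking helper, with shapes invented for the demo: each row of argsort over i.i.d. noise is a uniform random permutation, so comparing it against num_masked marks exactly that many positions at random.

import torch
from draco.model.draco_base import DenoisingReconstructionAutoencoderVisionTransformerBase as Base

tokens = torch.randn(2, 8, 16)            # [B, L, E]
mask = Base.random_masking(tokens, 0.75)  # [B, L] boolean, int(8 * 0.75) = 6 True per row
assert mask.sum(dim=1).tolist() == [6, 6]
kept = tokens[~mask].reshape(2, -1, 16)   # [2, 2, 16] -- what the MAE encoder actually sees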
draco/model/layer/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .normalization import LayerNorm2d
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/model/layer/normalization.py ADDED
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+__all__ = [
+    "LayerNorm2d",
+]
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-6) -> None:
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).square().mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
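What distinguishes this from nn.LayerNorm is the channels-first layout: statistics are taken over dim 1 only, so every spatial position is normalized across its channels, which is what the Conv2d-based decoder neck needs. A minimal check (sizes are arbitrary):

import torch
from draco.model.layer import LayerNorm2d

x = torch.randn(2, 256, 8, 8)  # [B, C, H, W], as the decoder neck expects
y = LayerNorm2d(256)(x)
assert y.shape == x.shape
assert torch.allclose(y.mean(dim=1), torch.zeros(2, 8, 8), atol=1e-5)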
draco/model/utils/constant.py ADDED
@@ -0,0 +1,24 @@
+def get_vit_scale(scale: str) -> tuple[int, int, int]:
+    if scale == "tiny":
+        return 192, 12, 3
+    elif scale == "small":
+        return 384, 12, 6
+    elif scale == "base":
+        return 768, 12, 12
+    elif scale == "large":
+        return 1024, 24, 16
+    elif scale == "huge":
+        return 1280, 32, 16
+    else:
+        raise KeyError(f"Unknown Vision Transformer scale: {scale}")
+
+def get_global_attn_indexes(num_layers: int) -> list[int]:
+    """
+    Args:
+        num_layers (int): The number of layers.
+
+    Returns:
+        List[int]: The global attention indexes.
+    """
+
+    return list(range(num_layers // 4 - 1, num_layers, num_layers // 4))
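Worked through for ViT-Base: depth 12 gives a stride of 12 // 4 = 3 starting at index 2, which reproduces the SAM-style [2, 5, 8, 11] used as DracoDenoiseAutoencoder's default global_attn_indexes.

from draco.model.utils.constant import get_vit_scale, get_global_attn_indexes

assert get_vit_scale("base") == (768, 12, 12)          # (embed_dim, depth, num_heads)
assert get_global_attn_indexes(12) == [2, 5, 8, 11]    # range(2, 12, 3)
assert get_global_attn_indexes(32) == [7, 15, 23, 31]  # ViT-Huge, depth 32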
requirements.txt ADDED
@@ -0,0 +1,11 @@
+torch==2.5.1
+torchvision==0.20.1
+h5py==3.12.1
+numpy==1.26.4
+pandas==2.2.2
+mrcfile==1.5.3
+scipy==1.13.1
+pycocotools==2.0.8
+omegaconf==2.3.0
+pillow
+fvcore