Felix-Xu committed
Commit 3bf7d18 · 1 Parent(s): fe94f9f

denoise model update
app.py CHANGED
@@ -1,7 +1,219 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
 
1
  import gradio as gr
2
+ import h5py
3
+ import mrcfile
4
+ import numpy as np
5
+ from PIL import Image
6
+ from omegaconf import DictConfig
7
+ import torch
8
+ from pathlib import Path
9
+ from torchvision.transforms import functional as F
10
+ import torchvision.transforms.v2 as v2
11
 
 
 
12
 
13
+ from draco.configuration import CfgNode
14
+ from draco.model import (
15
+ build_model,
16
+ load_pretrained
17
+ )
18
+
19
+ class DRACODenoiser(object):
20
+ def __init__(self,
21
+ cfg: DictConfig,
22
+ ckpt_path: Path,
23
+ ) -> None:
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ self.transform = self.build_transform()
27
+ self.model = build_model(cfg).to(self.device).eval()
28
+ self.model = load_pretrained(self.model, ckpt_path, self.device)
29
+ self.patch_size = cfg.MODEL.PATCH_SIZE
30
+
31
+ def patchify(self, image: torch.Tensor) -> torch.Tensor:
32
+ B, C, H, W = image.shape
33
+ P = self.patch_size
34
+ if H % P != 0 or W % P != 0:
35
+ image = torch.nn.functional.pad(image, (0, (P - W % P) % P, 0, (P - H % P) % P), mode='constant', value=0)
36
+
37
+ patches = image.unfold(2, P, P).unfold(3, P, P)
38
+ patches = patches.permute(0, 2, 3, 4, 5, 1)
39
+ patches = patches.reshape(B, -1, P * P * C)
40
+ return patches
41
+
42
+ def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
43
+ B = patches.shape[0]
44
+ P = self.patch_size
45
+
46
+ images = patches.reshape(B, (H + P - 1) // P, (W + P - 1) // P, P, P, -1)
47
+ images = images.permute(0, 5, 1, 3, 2, 4)
48
+ images = images.reshape(B, -1, (H + P - 1) // P * P, (W + P - 1) // P * P)
49
+ images = images[..., :H, :W]
50
+ return images
51
+
52
+ @classmethod
53
+ def build_transform(cls) -> v2.Compose:
54
+ return v2.Compose([
55
+ v2.ToImage(),
56
+ v2.ToDtype(torch.float32, scale=True)
57
+ ])
58
+
59
+ @torch.inference_mode()
60
+ def inference(self, image: Image.Image) -> np.ndarray:
61
+ W, H = image.size
62
+
63
+ x = self.transform(image).unsqueeze(0).to(self.device)
64
+ y = self.model(x)
65
+
66
+ x = self.patchify(x).detach().cpu().numpy()
67
+ denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
68
+
69
+ return denoised
70
+
71
+ # Model Initialization
72
+ cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
73
+ CfgNode.merge_with_dotlist(cfg, [])
74
+ ckpt_path = Path("denoise.ckpt")
75
+ denoiser = DRACODenoiser(cfg, ckpt_path)
76
+
77
+ def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
78
+
79
+ image = (image - image.min()) / (image.max() - image.min())
80
+ mean = image.mean()
81
+ std = image.std()
82
+
83
+ f = std / t_sd
84
+
85
+ black = mean - t_mean * f
86
+ white = mean + (1 - t_mean) * f
87
+
88
+ new_image = np.clip(image, black, white)
89
+ new_image = (new_image - black) / (white - black)
90
+ return new_image
91
+
92
+
93
+ def load_data(file_path) -> np.ndarray:
94
+ if file_path.endswith('.h5'):
95
+ with h5py.File(file_path, "r") as f:
96
+ full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
97
+ full_mean = full_micrograph.attrs["mean"] if "mean" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).mean()
98
+ full_std = full_micrograph.attrs["std"] if "std" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).std()
99
+ data = full_micrograph[:].astype(np.float32)
100
+ elif file_path.endswith('.mrc'):
101
+ with mrcfile.open(file_path, "r") as f:
102
+ data = f.data[:].astype(np.float32)
103
+ full_mean = data.mean()
104
+ full_std = data.std()
105
+ else:
106
+ raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
107
+ data = (data - full_mean) / full_std
108
+ return data
109
+
110
+ def display_crop(data, x_offset, y_offset, auto_contrast) -> Image.Image:
111
+
112
+ crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
113
+ original_image_normalized = Auto_contrast(crop) if auto_contrast else (crop - crop.min()) / (crop.max() - crop.min())
114
+ input_image = Image.fromarray((original_image_normalized * 255).astype(np.uint8))
115
+
116
+ return input_image
117
+
118
+ def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> Image.Image:
119
+
120
+ crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
121
+ denoised_data = denoiser.inference(Image.fromarray(crop))
122
+
123
+ denoised_data = denoised_data.squeeze()
124
+ denoised_image_normalized = Auto_contrast(denoised_data) if auto_contrast else (denoised_data - denoised_data.min()) / (denoised_data.max() - denoised_data.min())
125
+ denoised_image = Image.fromarray((denoised_image_normalized * 255).astype(np.uint8))
126
+
127
+ return denoised_image
128
+
129
+ def clear_images() -> tuple:
130
+ return None, None, None, gr.update(maximum=512), gr.update(maximum=512)
131
+
132
+ with gr.Blocks(css="""
133
+ .gradio-container {
134
+ background-color: #f7f9fc;
135
+ font-family: Arial, sans-serif;
136
+ }
137
+ .title-text {
138
+ text-align: center;
139
+ font-size: 30px;
140
+ font-weight: bold;
141
+ margin-bottom: 10px;
142
+ }
143
+ .description-text {
144
+ text-align: center;
145
+ font-size: 18px;
146
+ margin-bottom: 20px;
147
+ }
148
+ """) as demo:
149
+ # Centered Title and Description
150
+ with gr.Column():
151
+ gr.Markdown(
152
+ """
153
+ <div style="text-align: center; font-size: 30px; font-weight: bold; margin-bottom: 10px;">
154
+ Denoising Demo
155
+ </div>
156
+ <div style="text-align: center; font-size: 18px;">
157
Upload a raw micrograph file (.mrc or .h5) to view the original and denoised images
158
+ </div>
159
+ """
160
+ )
161
+
162
+ file_input = gr.File(label="Upload a micrograph file in .h5 or .mrc format")
163
+ auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False)
164
+
165
+ x_slider = gr.Slider(0, 512, step=10, label="X Offset")
166
+ y_slider = gr.Slider(0, 512, step=10, label="Y Offset")
167
+
168
+ with gr.Row():
169
+ denoise_button = gr.Button("Denoise")
170
+ clear_button = gr.Button("Clear")
171
+
172
+ with gr.Row():
173
+ with gr.Column():
174
+ original_image = gr.Image(type="pil", label="Original Image")
175
+ with gr.Column():
176
+ denoised_image = gr.Image(type="pil", label="Denoised Image")
177
+
178
+ active_data = gr.State()
179
+
180
+ def load_image_and_update_sliders(file_path) -> tuple:
181
+ data = load_data(file_path)
182
+ h, w = data.shape[:2]
183
+ return data, gr.update(maximum=w-1024), gr.update(maximum=h-1024)
184
+
185
+
186
+ file_input.clear(
187
+ clear_images,
188
+ inputs=None,
189
+ outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
190
+ )
191
+
192
+ file_input.change(
193
+ lambda file: load_image_and_update_sliders(file.name) if file else (None, gr.update(maximum=512), gr.update(maximum=512)),
194
+ inputs=file_input,
195
+ outputs=[active_data, x_slider, y_slider]
196
+ )
197
+
198
+ x_slider.change(
199
+ display_crop,
200
+ inputs=[active_data, x_slider, y_slider, auto_contrast],
201
+ outputs=original_image
202
+ )
203
+
204
+ y_slider.change(
205
+ display_crop,
206
+ inputs=[active_data, x_slider, y_slider, auto_contrast],
207
+ outputs=original_image
208
+ )
209
+
210
+ denoise_button.click(
211
+ process_and_denoise,
212
+ inputs=[active_data, x_slider, y_slider, auto_contrast],
213
+ outputs=denoised_image
214
+ )
215
+
216
+ clear_button.click(clear_images, inputs=None, outputs=[original_image, denoised_image, active_data, x_slider, y_slider])
217
+
218
  demo.launch()
219
+
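For reference, a minimal sketch (not part of the commit) of the same call sequence the Gradio callbacks run, usable outside the UI. It assumes `load_data`, `Auto_contrast`, and the module-level `denoiser` defined above are in scope (e.g. the helpers are copied into a small script), since importing `app.py` as-is would also start the Gradio server; `example.mrc` is a placeholder filename.

```python
# Sketch only: mirrors the display_crop / process_and_denoise callbacks above,
# assuming load_data, Auto_contrast and `denoiser` from app.py are in scope.
import numpy as np
from PIL import Image

data = load_data("example.mrc")            # mean/std-normalized float32 micrograph
crop = data[0:1024, 0:1024]                # the 1024x1024 window the sliders select

denoised = denoiser.inference(Image.fromarray(crop)).squeeze()

out = Auto_contrast(denoised)              # optional, as with the "Enable Auto Contrast" checkbox
Image.fromarray((out * 255).astype(np.uint8)).save("denoised.png")
```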
draco.yaml ADDED
@@ -0,0 +1,19 @@
1
+ MODEL:
2
+ NAME: DracoDenoiseAutoencoder
3
+ DEVICE: cuda
4
+
5
+ IMG_SIZE: 1024
6
+ PATCH_SIZE: 32
7
+ IN_CHANS: 1
8
+ VIT_SCALE: base
9
+ DYNAMIC_IMG_SIZE: True
10
+ DYNAMIC_IMG_PAD: True
11
+ DECODER_EMBED_DIM: 512
12
+ DECODER_DEPTH: 8
13
+ DECODER_NUM_HEADS: 16
14
+ DECODER_USE_NECK: True
15
+ DECODER_NECK_DIM: 256
16
+ USE_ABS_POS: True
17
+ USE_DECODER_NECK: True
18
+ WINDOW_SIZE: 28
19
+ DECODER_GLOBAL_ATTN_INDEXES: [3, 7]
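A short sketch of how this config is consumed; it mirrors the model-initialization block at the top of the new app.py (`denoise.ckpt` is the checkpoint filename app.py expects).

```python
from pathlib import Path

from draco.configuration import CfgNode
from draco.model import build_model, load_pretrained

cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
print(cfg.MODEL.NAME)        # "DracoDenoiseAutoencoder" -- looked up in MODEL_REGISTRY
print(cfg.MODEL.PATCH_SIZE)  # 32 -- also used by app.py's patchify/unpatchify helpers

model = build_model(cfg)                                      # constructed via from_config(cfg)
model = load_pretrained(model, Path("denoise.ckpt"), "cpu")   # strips DDP/Lightning prefixes
```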
draco/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+ import draco.model
draco/configuration/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .config import CfgNode
2
+ from .configurable import configurable
3
+
4
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/configuration/base.yaml ADDED
@@ -0,0 +1,57 @@
1
+ DATALOADER:
2
+ BATCH_SIZE: 0
3
+ NUM_WORKERS: 0
4
+ PIN_MEMORY: False
5
+ DROP_LAST: False
6
+ PERSISTENT_WORKERS: False
7
+
8
+ DATASET:
9
+ NAME: null
10
+
11
+ TRANSFORM:
12
+ NAME: null
13
+
14
+ MODEL:
15
+ NAME: null
16
+ DEVICE: cuda
17
+
18
+ METRIC:
19
+ NAME: null
20
+ TYPE: null
21
+
22
+ MODULE:
23
+ NAME: null
24
+ COMPILE: False
25
+
26
+ OPTIMIZER:
27
+ NAME: null
28
+
29
+ SCHEDULER:
30
+ NAME: null
31
+
32
+ TRAINER:
33
+ STRATEGY: auto # Set to `auto`, `ddp`, `deepspeed_stage_2`, `deepspeed_stage_3` ...
34
+ MIXED_PRECISION: False
35
+ CHECKPOINT:
36
+ EVERY_N_EPOCHS: 10
37
+
38
+ SAVE_BEST: False # If True, monitor will be required
39
+ MONITOR: null
40
+ MONITOR_MODE: min # Set to `min` or `max`
41
+
42
+ MAX_EPOCHS: -1 # If profiler is enabled, this will be *automatically* set to 1
43
+ LOG_EVERY_N_STEPS: 1
44
+ ACCUMULATE_GRAD_BATCHES: 1
45
+
46
+ CLIP_GRAD:
47
+ ALGORITHM: null
48
+ VALUE: null
49
+
50
+ DETERMINISTIC: False # Set to True to enable cudnn.deterministic
51
+ BENCHMARK: False # Set to True to enable cudnn.benchmark
52
+ PROFILER: null # Set to `advanced` or `pytorch` to enable profiling
53
+ DETECT_ANOMALY: False # Set to True to enable anomaly detection
54
+ SYNC_BATCHNORM: False # Set to True to enable sync batchnorm
55
+
56
+ SEED: null
57
+ OUTPUT_DIR: null
draco/configuration/config.py ADDED
@@ -0,0 +1,52 @@
1
+ import os.path
2
+ from typing import Any
3
+
4
+ from omegaconf import DictConfig, OmegaConf
5
+
6
+ BASE_KEY = "_BASE_"
7
+ ROOT_KEY = "cfg"
8
+
9
+ class CfgNode(OmegaConf):
10
+ """
11
+ A wrapper around OmegaConf that provides some additional functionality.
12
+ """
13
+
14
+ @staticmethod
15
+ def load_yaml_with_base(filename: str) -> DictConfig:
16
+ cfg = OmegaConf.load(filename)
17
+
18
+ def _load_with_base(base_cfg_file: str) -> dict[str, Any]:
19
+ if base_cfg_file.startswith("~"):
20
+ base_cfg_file = os.path.expanduser(base_cfg_file)
21
+ if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])):
22
+ # the path to base cfg is relative to the config file itself.
23
+ base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file)
24
+ return CfgNode.load_yaml_with_base(base_cfg_file)
25
+
26
+ if BASE_KEY in cfg:
27
+ if isinstance(cfg[BASE_KEY], list):
28
+ base_cfg: dict[str, Any] = {}
29
+ base_cfg_files = cfg[BASE_KEY]
30
+ for base_cfg_file in base_cfg_files:
31
+ base_cfg = CfgNode.merge(base_cfg, _load_with_base(base_cfg_file))
32
+ else:
33
+ base_cfg_file = cfg[BASE_KEY]
34
+ base_cfg = _load_with_base(base_cfg_file)
35
+ del cfg[BASE_KEY]
36
+
37
+ base_cfg = CfgNode.merge(base_cfg, cfg)
38
+ return base_cfg
39
+
40
+ if ROOT_KEY in cfg:
41
+ return cfg[ROOT_KEY]
42
+ return cfg
43
+
44
+ @staticmethod
45
+ def merge_with_dotlist(cfg: DictConfig, dotlist: list[str]) -> None:
46
+ if len(dotlist) == 0:
47
+ return
48
+
49
+ new_dotlist = []
50
+ for key, value in zip(dotlist[::2], dotlist[1::2]):
51
+ new_dotlist.append(f"{key}={value}")
52
+ cfg.merge_with_dotlist(new_dotlist)
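To illustrate the `_BASE_` inheritance and the key/value dotlist format implemented above, a small self-contained sketch (the two YAML files are written inline purely for the example):

```python
from pathlib import Path

from draco.configuration import CfgNode

Path("base_example.yaml").write_text("MODEL:\n  NAME: null\n  DEVICE: cuda\n")
Path("child_example.yaml").write_text("_BASE_: base_example.yaml\nMODEL:\n  NAME: DracoDenoiseAutoencoder\n")

cfg = CfgNode.load_yaml_with_base("child_example.yaml")
# `_BASE_` is resolved relative to the child file, merged underneath it, then dropped:
assert cfg.MODEL.NAME == "DracoDenoiseAutoencoder" and cfg.MODEL.DEVICE == "cuda"

# merge_with_dotlist takes alternating key/value tokens (not "key=value" strings):
CfgNode.merge_with_dotlist(cfg, ["MODEL.DEVICE", "cpu"])
assert cfg.MODEL.DEVICE == "cpu"
```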
draco/configuration/configurable.py ADDED
@@ -0,0 +1,138 @@
1
+ import functools
2
+ import inspect
3
+ from typing import Any, Callable
4
+
5
+ from omegaconf import DictConfig
6
+
7
+ __all__ = ["configurable"]
8
+
9
+
10
+ def _called_with_cfg(*args, **kwargs) -> bool:
11
+ """
12
+ Check if the function is called with a `DictConfig` as the first argument.
13
+
14
+ Returns:
15
+ (bool): whether the function is called with a `DictConfig` as the first argument.
16
+ Or the `cfg` keyword argument is a `DictConfig`.
17
+ """
18
+
19
+ if len(args) > 0 and isinstance(args[0], DictConfig):
20
+ return True
21
+ if isinstance(kwargs.get("cfg", None), DictConfig):
22
+ return True
23
+ return False
24
+
25
+
26
+ def _get_args_from_cfg(from_config_func: Callable[[Any], dict[str, Any]], *args, **kwargs) -> dict[str, Any]:
27
+ """
28
+ Get the input arguments of the decorated function from a `DictConfig` object.
29
+
30
+ Returns:
31
+ (dict): The input arguments of the class `__init__` method.
32
+ """
33
+
34
+ signature = inspect.signature(from_config_func)
35
+ if list(signature.parameters.keys())[0] != "cfg":
36
+ raise ValueError("The first argument of `{}` must be named as `cfg`.".format(from_config_func.__name__))
37
+
38
+ # Forwarding all arguments to `from_config`, if the arguments of `from_config` are only `*args` or `*kwargs`.
39
+ if any(param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD) for param in signature.parameters.values()):
40
+ result = from_config_func(*args, **kwargs)
41
+
42
+ # If there is any positional arguments.
43
+ else:
44
+ positional_args_name = set(signature.parameters.keys())
45
+ extra_kwargs = {}
46
+ for name in list(kwargs.keys()):
47
+ if name not in positional_args_name:
48
+ extra_kwargs[name] = kwargs.pop(name)
49
+ result = from_config_func(*args, **kwargs)
50
+ # These args are forwarded directly to `__init__` method.
51
+ result.update(extra_kwargs)
52
+
53
+ return result
54
+
55
+
56
+ def configurable(init_func: Callable = None, *, from_config: Callable[[Any], dict[str, Any]] | None = None) -> Callable:
57
+ """
58
+ A decorator of a function or a class `__init__` method,
59
+ to make it configurable by a `DictConfig` object.
60
+
61
+ Example:
62
+ ```python
63
+ # 1. Decorate a function.
64
+ @configurable(from_config=lambda cfg: { "x": cfg.x })
65
+ def func(x, y=2, z=3):
66
+ pass
67
+
68
+ a1 = func(x=1, y=2) # Call with regular args.
69
+ a2 = func(cfg) # Call with a `DictConfig` object.
70
+ a3 = func(cfg, y=2, z=3) # Call with a `DictConfig` object and regular arguments.
71
+
72
+ # 2. Decorate a class `__init__` method.
73
+ class A:
74
+ @configurable
75
+ def __init__(self, *args, **kwargs) -> None:
76
+ pass
77
+
78
+ @classmethod
79
+ def from_config(cls, cfg) -> dict:
80
+ pass
81
+
82
+ a1 = A(x, y) # Call with regular constructor.
83
+ a2 = A(cfg) # Call with a `DictConfig` object.
84
+ a3 = A(cfg, x, y) # Call with a `DictConfig` object and regular arguments.
85
+ ```
86
+
87
+ Args:
88
+ `init_func` (callable): a function or a class method.
89
+ `from_config` (callable): a function that converts a `DictConfig` to the
90
+ input arguments of the decorated function.
91
+ It is always required.
92
+ """
93
+
94
+ # Decorating a function
95
+ if init_func is None:
96
+ # Prevent common misuse: `@configurable()`.
97
+ if from_config is None:
98
+ return configurable
99
+
100
+ assert inspect.isfunction(from_config), "`from_config` must be a function."
101
+
102
+ def wrapper(func):
103
+ @functools.wraps(func)
104
+ def wrapped(*args, **kwargs):
105
+ if _called_with_cfg(*args, **kwargs):
106
+ explicit_args = _get_args_from_cfg(from_config, *args, **kwargs)
107
+ return func(**explicit_args)
108
+ else:
109
+ return func(*args, **kwargs)
110
+
111
+ wrapped.from_config = from_config
112
+ return wrapped
113
+
114
+ return wrapper
115
+
116
+ # Decorating a class `__init__` method
117
+ else:
118
+ assert(
119
+ inspect.isfunction(init_func) and from_config is None and init_func.__name__ == "__init__"
120
+ ), "Invalid usage of @configurable."
121
+
122
+ @functools.wraps(init_func)
123
+ def wrapped(self, *args, **kwargs):
124
+ try:
125
+ from_config_func = getattr(self, "from_config")
126
+ except AttributeError as e:
127
+ raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.") from e
128
+
129
+ if not inspect.ismethod(from_config_func):
130
+ raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.")
131
+
132
+ if _called_with_cfg(*args, **kwargs):
133
+ explicit_args = _get_args_from_cfg(from_config_func, *args, **kwargs)
134
+ init_func(self, **explicit_args)
135
+ else:
136
+ init_func(self, *args, **kwargs)
137
+
138
+ return wrapped
draco/configuration/draco2d-b_triplet_pretrain.yaml ADDED
@@ -0,0 +1,20 @@
1
+ _BASE_: base.yaml
2
+
3
+ MODEL:
4
+ NAME: DenoisingReconstructionAutoencoderVisionTransformer2d
5
+
6
+ IMG_SIZE: 256
7
+ PATCH_SIZE: 16
8
+ IN_CHANS: 1
9
+ VIT_SCALE: base
10
+ DYNAMIC_IMG_SIZE: False
11
+ DYNAMIC_IMG_PAD: False
12
+ USE_ABS_POS: True
13
+ DECODER_EMBED_DIM: 512
14
+ DECODER_DEPTH: 8
15
+ DECODER_NUM_HEADS: 16
16
+ DECODER_USE_NECK: True
17
+ DECODER_NECK_DIM: 256
18
+
19
+ SEED: 0
20
+ OUTPUT_DIR: null
draco/configuration/draco2d-h_triplet_pretrain.yaml ADDED
@@ -0,0 +1,5 @@
1
+ _BASE_: draco2d-b_triplet_pretrain.yaml
2
+
3
+ MODEL:
4
+ PATCH_SIZE: 14
5
+ VIT_SCALE: huge
draco/configuration/draco2d-l_triplet_pretrain.yaml ADDED
@@ -0,0 +1,4 @@
1
+ _BASE_: draco2d-b_triplet_pretrain.yaml
2
+
3
+ MODEL:
4
+ VIT_SCALE: large
draco/model/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .build import MODEL_REGISTRY, build_model
2
+ from .checkpoint import load_pretrained
3
+
4
+ from .draco2d import DenoisingReconstructionAutoencoderVisionTransformer2d, DracoDenoiseAutoencoder
5
+
6
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/model/build.py ADDED
@@ -0,0 +1,22 @@
1
+ from fvcore.common.registry import Registry
2
+ from omegaconf import DictConfig
3
+ import torch
4
+
5
+ __all__ = ["MODEL_REGISTRY", "build_model"]
6
+
7
+
8
+ MODEL_REGISTRY = Registry("MODEL")
9
+ MODEL_REGISTRY.__doc__ = "Registry for the model."
10
+
11
+ def build_model(cfg: DictConfig) -> torch.nn.Module:
12
+ """
13
+ Build the model defined by `cfg.MODEL.NAME`.
14
+ It moves the model to the device defined by `cfg.MODEL.DEVICE`.
15
+ It does not load checkpoints from `cfg`.
16
+ """
17
+ model_name = cfg.MODEL.NAME
18
+ try:
19
+ model = MODEL_REGISTRY.get(model_name)(cfg)
20
+ except KeyError as e:
21
+ raise KeyError(MODEL_REGISTRY) from e
22
+ return model
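A minimal sketch of the registry flow: the decorator adds a class under its own name, and `build_model` looks it up by `cfg.MODEL.NAME` and calls it with the config. `TinyModel` is a made-up example, not part of the commit.

```python
import torch.nn as nn
from omegaconf import OmegaConf

from draco.model.build import MODEL_REGISTRY, build_model

@MODEL_REGISTRY.register()
class TinyModel(nn.Module):              # hypothetical example model
    def __init__(self, cfg):
        super().__init__()
        self.proj = nn.Linear(4, 4)

cfg = OmegaConf.create({"MODEL": {"NAME": "TinyModel"}})
model = build_model(cfg)                 # MODEL_REGISTRY.get("TinyModel")(cfg)
print(type(model).__name__)              # "TinyModel"
```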
draco/model/checkpoint.py ADDED
@@ -0,0 +1,61 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ import torch
5
+
6
+
7
+ def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
8
+ """
9
+ Strip the prefix in metadata, if any.
10
+
11
+ Args:
12
+ state_dict (OrderedDict): a state-dict to be loaded to the model.
13
+ prefix (str): prefix.
14
+ """
15
+ keys = sorted(state_dict.keys())
16
+ if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
17
+ return
18
+
19
+ for key in keys:
20
+ newkey = key[len(prefix) :]
21
+ state_dict[newkey] = state_dict.pop(key)
22
+
23
+ # Also strip the prefix in metadata, if any.
24
+ try:
25
+ metadata = state_dict._metadata # pyre-ignore
26
+ except AttributeError:
27
+ pass
28
+ else:
29
+ for key in list(metadata.keys()):
30
+ # for the metadata dict, the key can be:
31
+ # '': for the DDP module, which we want to remove.
32
+ # 'module': for the actual model.
33
+ # 'module.xx.xx': for the rest.
34
+
35
+ if len(key) == 0:
36
+ continue
37
+ newkey = key[len(prefix) :]
38
+ metadata[newkey] = metadata.pop(key)
39
+
40
+
41
+ def load_pretrained(model: torch.nn.Module, ckpt_path: Path, device: torch.device = "cuda") -> torch.nn.Module:
42
+ """
43
+ Load the pre-trained model from the checkpoint file.
44
+ """
45
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
46
+
47
+ if "state_dict" in ckpt:
48
+ checkpoint_state_dict = ckpt["state_dict"]
49
+ elif "model" in ckpt:
50
+ checkpoint_state_dict = ckpt["model"]
51
+ else:
52
+ checkpoint_state_dict = ckpt
53
+
54
+ _strip_prefix_if_present(checkpoint_state_dict, "module.") # for DistributedDataParallel
55
+ _strip_prefix_if_present(checkpoint_state_dict, "model.") # for PyTorch Lightning Module
56
+ _strip_prefix_if_present(checkpoint_state_dict, "_orig_mod.") # for torch.compile
57
+
58
+ msg = model.load_state_dict(checkpoint_state_dict, strict=False)
59
+ print(f"Loaded pre-trained model from {ckpt_path} with message: {msg}")
60
+
61
+ return model
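A small sketch of the prefix handling: given a Lightning-style checkpoint whose keys all carry a `model.` prefix, `load_pretrained` strips it so the keys line up with the module's own state dict. The tiny checkpoint below is fabricated purely for illustration.

```python
import torch
import torch.nn as nn

from draco.model.checkpoint import load_pretrained

net = nn.Sequential(nn.Linear(2, 2))
ckpt = {"state_dict": {"model.0.weight": torch.zeros(2, 2), "model.0.bias": torch.zeros(2)}}
torch.save(ckpt, "tiny.ckpt")

net = load_pretrained(net, "tiny.ckpt", device="cpu")
# "model." is stripped, the keys become "0.weight"/"0.bias" and load cleanly;
# the printed message lists any missing or unexpected keys (strict=False).
```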
draco/model/draco2d.py ADDED
@@ -0,0 +1,663 @@
1
+ from functools import partial
2
+ from typing import Any, Callable
3
+
4
+ from omegaconf import DictConfig
5
+ from timm.layers import build_sincos2d_pos_embed, resample_abs_pos_embed_nhwc, PatchEmbed, Mlp, LayerType
6
+ from timm.models.vision_transformer import Block
7
+ from timm.models.vision_transformer_sam import Block as SAMBlock
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from draco.configuration import configurable
12
+ from .build import MODEL_REGISTRY
13
+ from .layer import LayerNorm2d
14
+ from .draco_base import DenoisingReconstructionAutoencoderVisionTransformerBase
15
+ from .utils.constant import get_vit_scale, get_global_attn_indexes
16
+
17
+ __all__ = ["DenoisingReconstructionAutoencoderVisionTransformer2d", "DracoDenoiseAutoencoder"]
18
+
19
+
20
+ @MODEL_REGISTRY.register()
21
+ class DenoisingReconstructionAutoencoderVisionTransformer2d(DenoisingReconstructionAutoencoderVisionTransformerBase):
22
+ @configurable
23
+ def __init__(self, *,
24
+ img_size: int = 224,
25
+ patch_size: int = 16,
26
+ in_chans: int = 3,
27
+ embed_layer: Callable = PatchEmbed,
28
+ dynamic_img_size: bool = False,
29
+ dynamic_img_pad: bool = False,
30
+ use_abs_pos: bool = True,
31
+ block_fn: nn.Module = Block,
32
+ norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
33
+ act_layer: LayerType = nn.GELU,
34
+ mlp_layer: nn.Module = Mlp,
35
+ embed_dim: int = 768,
36
+ depth: int = 12,
37
+ num_heads: int = 12,
38
+ mlp_ratio: float = 4.0,
39
+ qkv_bias: bool = True,
40
+ qk_norm: bool = False,
41
+ decoder_block_fn: nn.Module = Block,
42
+ decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
43
+ decoder_act_layer: LayerType = nn.GELU,
44
+ decoder_mlp_layer: nn.Module = Mlp,
45
+ decoder_embed_dim: int = 512,
46
+ decoder_depth: int = 8,
47
+ decoder_num_heads: int = 16,
48
+ decoder_use_neck: bool = True,
49
+ decoder_neck_dim: int = 256,
50
+ ) -> None:
51
+ super().__init__()
52
+
53
+ self.dynamic_img_size = dynamic_img_size
54
+ self.decoder_use_neck = decoder_use_neck
55
+
56
+ self.init_encoder(
57
+ img_size=img_size,
58
+ patch_size=patch_size,
59
+ in_chans=in_chans,
60
+ embed_layer=embed_layer,
61
+ dynamic_img_size=dynamic_img_size,
62
+ dynamic_img_pad=dynamic_img_pad,
63
+ use_abs_pos=use_abs_pos,
64
+ block_fn=block_fn,
65
+ norm_layer=norm_layer,
66
+ act_layer=act_layer,
67
+ mlp_layer=mlp_layer,
68
+ embed_dim=embed_dim,
69
+ depth=depth,
70
+ num_heads=num_heads,
71
+ mlp_ratio=mlp_ratio,
72
+ qkv_bias=qkv_bias,
73
+ qk_norm=qk_norm,
74
+ )
75
+ self.init_decoder(
76
+ patch_size=patch_size,
77
+ in_chans=in_chans,
78
+ embed_dim=embed_dim,
79
+ use_abs_pos=use_abs_pos,
80
+ decoder_block_fn=decoder_block_fn,
81
+ decoder_norm_layer=decoder_norm_layer,
82
+ decoder_act_layer=decoder_act_layer,
83
+ decoder_mlp_layer=decoder_mlp_layer,
84
+ decoder_embed_dim=decoder_embed_dim,
85
+ decoder_depth=decoder_depth,
86
+ decoder_num_heads=decoder_num_heads,
87
+ decoder_use_neck=decoder_use_neck,
88
+ decoder_neck_dim=decoder_neck_dim,
89
+ mlp_ratio=mlp_ratio,
90
+ qkv_bias=qkv_bias,
91
+ qk_norm=qk_norm,
92
+ )
93
+ self.init_weights(
94
+ grid_size=self.patch_embed.grid_size,
95
+ embed_dim=embed_dim,
96
+ decoder_embed_dim=decoder_embed_dim,
97
+ )
98
+
99
+ @classmethod
100
+ def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
101
+ embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
102
+ return {
103
+ "img_size": cfg.MODEL.IMG_SIZE,
104
+ "patch_size": cfg.MODEL.PATCH_SIZE,
105
+ "in_chans": cfg.MODEL.IN_CHANS,
106
+ "dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
107
+ "dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
108
+ "use_abs_pos": cfg.MODEL.USE_ABS_POS,
109
+ "embed_dim": embed_dim,
110
+ "depth": depth,
111
+ "num_heads": num_heads,
112
+ "decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
113
+ "decoder_depth": cfg.MODEL.DECODER_DEPTH,
114
+ "decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
115
+ "decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
116
+ "decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
117
+ }
118
+
119
+ def init_encoder(self, *,
120
+ img_size: int,
121
+ patch_size: int,
122
+ in_chans: int,
123
+ embed_layer: Callable,
124
+ dynamic_img_size: bool,
125
+ dynamic_img_pad: bool,
126
+ use_abs_pos: bool,
127
+ block_fn: nn.Module,
128
+ norm_layer: LayerType | None,
129
+ act_layer: LayerType | None,
130
+ mlp_layer: nn.Module,
131
+ embed_dim: int,
132
+ depth: int,
133
+ num_heads: int,
134
+ mlp_ratio: float,
135
+ qkv_bias: bool,
136
+ qk_norm: bool,
137
+ ) -> None:
138
+ embed_args = {}
139
+ if dynamic_img_size:
140
+ embed_args.update(dict(strict_img_size=False))
141
+ self.patch_embed = embed_layer(
142
+ img_size=img_size,
143
+ patch_size=patch_size,
144
+ in_chans=in_chans,
145
+ embed_dim=embed_dim,
146
+ dynamic_img_pad=dynamic_img_pad,
147
+ output_fmt="NHWC",
148
+ **embed_args
149
+ )
150
+
151
+ self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
152
+ self.blocks = nn.ModuleList([
153
+ block_fn(
154
+ dim=embed_dim,
155
+ num_heads=num_heads,
156
+ mlp_ratio=mlp_ratio,
157
+ qkv_bias=qkv_bias,
158
+ qk_norm=qk_norm,
159
+ norm_layer=norm_layer,
160
+ act_layer=act_layer,
161
+ mlp_layer=mlp_layer,
162
+ ) for _ in range(depth)
163
+ ])
164
+ self.norm = norm_layer(embed_dim)
165
+
166
+ def init_decoder(self, *,
167
+ patch_size: int,
168
+ in_chans: int,
169
+ embed_dim: int,
170
+ use_abs_pos: bool,
171
+ decoder_block_fn: nn.Module,
172
+ decoder_norm_layer: LayerType | None,
173
+ decoder_act_layer: LayerType | None,
174
+ decoder_mlp_layer: nn.Module,
175
+ decoder_embed_dim: int,
176
+ decoder_depth: int,
177
+ decoder_num_heads: int,
178
+ decoder_use_neck: bool,
179
+ decoder_neck_dim: int,
180
+ mlp_ratio: float,
181
+ qkv_bias: bool,
182
+ qk_norm: bool,
183
+ ) -> None:
184
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
185
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
186
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
187
+ self.decoder_blocks = nn.ModuleList([
188
+ decoder_block_fn(
189
+ dim=decoder_embed_dim,
190
+ num_heads=decoder_num_heads,
191
+ mlp_ratio=mlp_ratio,
192
+ qkv_bias=qkv_bias,
193
+ qk_norm=qk_norm,
194
+ norm_layer=decoder_norm_layer,
195
+ act_layer=decoder_act_layer,
196
+ mlp_layer=decoder_mlp_layer,
197
+ ) for _ in range(decoder_depth)
198
+ ])
199
+ self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
200
+ if decoder_use_neck:
201
+ self.decoder_neck = nn.Sequential(
202
+ nn.Conv2d(
203
+ in_channels=decoder_embed_dim,
204
+ out_channels=decoder_neck_dim,
205
+ kernel_size=1,
206
+ bias=False,
207
+ ),
208
+ LayerNorm2d(decoder_neck_dim),
209
+ decoder_act_layer(),
210
+ nn.Conv2d(
211
+ in_channels=decoder_neck_dim,
212
+ out_channels=decoder_neck_dim,
213
+ kernel_size=3,
214
+ padding=1,
215
+ bias=False,
216
+ ),
217
+ LayerNorm2d(decoder_neck_dim),
218
+ decoder_act_layer(),
219
+ nn.Conv2d(
220
+ in_channels=decoder_neck_dim,
221
+ out_channels=decoder_embed_dim,
222
+ kernel_size=1,
223
+ bias=False,
224
+ ),
225
+ LayerNorm2d(decoder_embed_dim),
226
+ )
227
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
228
+
229
+ def init_weights(self, *,
230
+ grid_size: tuple[int, int],
231
+ embed_dim: int,
232
+ decoder_embed_dim: int
233
+ ) -> None:
234
+ w = self.patch_embed.proj.weight.data
235
+ torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
236
+
237
+ torch.nn.init.normal_(self.mask_token, std=0.02)
238
+
239
+ if self.pos_embed is not None:
240
+ self.pos_embed.data.copy_(build_sincos2d_pos_embed(
241
+ feat_shape=grid_size,
242
+ dim=embed_dim,
243
+ interleave_sin_cos=True
244
+ ).reshape(1, *grid_size, -1).transpose(1, 2))
245
+
246
+ if self.decoder_pos_embed is not None:
247
+ self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
248
+ feat_shape=grid_size,
249
+ dim=decoder_embed_dim,
250
+ interleave_sin_cos=True
251
+ ).reshape(1, *grid_size, -1).transpose(1, 2))
252
+
253
+ if self.decoder_use_neck:
254
+ for m in self.decoder_neck.modules():
255
+ if isinstance(m, nn.Conv2d):
256
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
257
+ if m.bias is not None:
258
+ nn.init.zeros_(m.bias)
259
+ nn.init.zeros_(self.decoder_neck[-1].weight)
260
+ nn.init.zeros_(self.decoder_neck[-1].bias)
261
+
262
+ self.apply(self._init_weights)
263
+
264
+ def _init_weights(self, module: nn.Module) -> None:
265
+ if isinstance(module, nn.Linear):
266
+ nn.init.xavier_uniform_(module.weight)
267
+ if module.bias is not None:
268
+ nn.init.constant_(module.bias, 0.0)
269
+
270
+ def forward_encoder(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor, int, int]:
271
+ x = self.patch_embed(x)
272
+ B, H, W, E = x.shape
273
+ if self.pos_embed is not None:
274
+ x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
275
+ x = x.view(B, -1, E)
276
+
277
+ mask = super().random_masking(x, mask_ratio)
278
+ x = x[~mask].reshape(B, -1, E)
279
+
280
+ for block in self.blocks:
281
+ x = block(x)
282
+ x = self.norm(x)
283
+
284
+ return x, mask, H, W
285
+
286
+ def forward_decoder(self, x: torch.Tensor, mask: torch.BoolTensor, H: int, W: int) -> torch.Tensor:
287
+ x = self.decoder_embed(x)
288
+
289
+ B, L = mask.shape
290
+ E = x.shape[-1]
291
+ mask_tokens = self.mask_token.repeat(B, L, 1).to(x.dtype)
292
+ mask_tokens[~mask] = x.reshape(-1, E)
293
+ x = mask_tokens
294
+
295
+ if self.decoder_pos_embed is not None:
296
+ x = x.view(B, H, W, E)
297
+ x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
298
+ x = x.view(B, -1, E)
299
+
300
+ for block in self.decoder_blocks:
301
+ x = block(x)
302
+ x = self.decoder_norm(x)
303
+ if self.decoder_use_neck:
304
+ x = x + self.decoder_neck(
305
+ x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
306
+ ).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
307
+ x = self.decoder_pred(x)
308
+
309
+ return x
310
+
311
+ def forward(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor]:
312
+ x, mask, H, W = self.forward_encoder(x, mask_ratio)
313
+ x = self.forward_decoder(x, mask, H, W)
314
+ return x, mask
315
+
316
+ @MODEL_REGISTRY.register()
317
+ class DracoDenoiseAutoencoder(DenoisingReconstructionAutoencoderVisionTransformerBase):
318
+ """
319
+ Denoising autoencoder with a SAM-style (windowed-attention) Vision Transformer backbone; no token masking is applied.
320
+ Note that `cls_token` is discarded.
321
+ """
322
+
323
+ @configurable
324
+ def __init__(self, *,
325
+ img_size: int = 224,
326
+ patch_size: int = 16,
327
+ in_chans: int = 3,
328
+ embed_layer: Callable = PatchEmbed,
329
+ dynamic_img_size: bool = False,
330
+ dynamic_img_pad: bool = False,
331
+ use_abs_pos: bool = True,
332
+ block_fn: nn.Module = SAMBlock,
333
+ norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
334
+ act_layer: LayerType = nn.GELU,
335
+ mlp_layer: nn.Module = Mlp,
336
+ embed_dim: int = 768,
337
+ depth: int = 12,
338
+ num_heads: int = 12,
339
+ mlp_ratio: float = 4.0,
340
+ qkv_bias: bool = True,
341
+ qk_norm: bool = False,
342
+ window_size: int = 16,
343
+ global_attn_indexes: list[int] = [2, 5, 8, 11],
344
+ decoder_block_fn: nn.Module = SAMBlock,
345
+ decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
346
+ decoder_act_layer: LayerType = nn.GELU,
347
+ decoder_mlp_layer: nn.Module = Mlp,
348
+ decoder_embed_dim: int = 512,
349
+ decoder_depth: int = 8,
350
+ decoder_num_heads: int = 16,
351
+ decoder_use_neck: bool = True,
352
+ decoder_neck_dim: int = 256,
353
+ decoder_global_attn_indexes: list[int] = [3, 7],
354
+ ) -> None:
355
+ super().__init__()
356
+
357
+ self.dynamic_img_size = dynamic_img_size
358
+ self.decoder_use_neck = decoder_use_neck
359
+
360
+ self.init_encoder(
361
+ img_size=img_size,
362
+ patch_size=patch_size,
363
+ in_chans=in_chans,
364
+ embed_layer=embed_layer,
365
+ dynamic_img_size=dynamic_img_size,
366
+ dynamic_img_pad=dynamic_img_pad,
367
+ use_abs_pos=use_abs_pos,
368
+ block_fn=block_fn,
369
+ norm_layer=norm_layer,
370
+ act_layer=act_layer,
371
+ mlp_layer=mlp_layer,
372
+ embed_dim=embed_dim,
373
+ depth=depth,
374
+ num_heads=num_heads,
375
+ mlp_ratio=mlp_ratio,
376
+ qkv_bias=qkv_bias,
377
+ qk_norm=qk_norm,
378
+ window_size=window_size,
379
+ global_attn_indexes=global_attn_indexes
380
+ )
381
+ self.init_decoder(
382
+ img_size=img_size,
383
+ patch_size=patch_size,
384
+ in_chans=in_chans,
385
+ embed_dim=embed_dim,
386
+ use_abs_pos=use_abs_pos,
387
+ decoder_block_fn=decoder_block_fn,
388
+ decoder_norm_layer=decoder_norm_layer,
389
+ decoder_act_layer=decoder_act_layer,
390
+ decoder_mlp_layer=decoder_mlp_layer,
391
+ decoder_embed_dim=decoder_embed_dim,
392
+ decoder_depth=decoder_depth,
393
+ decoder_num_heads=decoder_num_heads,
394
+ decoder_use_neck=decoder_use_neck,
395
+ decoder_neck_dim=decoder_neck_dim,
396
+ mlp_ratio=mlp_ratio,
397
+ qkv_bias=qkv_bias,
398
+ qk_norm=qk_norm,
399
+ window_size=window_size,
400
+ decoder_global_attn_indexes=decoder_global_attn_indexes
401
+ )
402
+ self.init_weights(
403
+ grid_size=self.patch_embed.grid_size,
404
+ embed_dim=embed_dim,
405
+ decoder_embed_dim=decoder_embed_dim,
406
+ )
407
+
408
+ @classmethod
409
+ def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
410
+ embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
411
+ global_attn_indexes = get_global_attn_indexes(depth)
412
+ return {
413
+ "img_size": cfg.MODEL.IMG_SIZE,
414
+ "patch_size": cfg.MODEL.PATCH_SIZE,
415
+ "in_chans": cfg.MODEL.IN_CHANS,
416
+ "dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
417
+ "dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
418
+ "use_abs_pos": cfg.MODEL.USE_ABS_POS,
419
+ "embed_dim": embed_dim,
420
+ "depth": depth,
421
+ "num_heads": num_heads,
422
+ "window_size": cfg.MODEL.WINDOW_SIZE,
423
+ "global_attn_indexes": global_attn_indexes,
424
+ "decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
425
+ "decoder_depth": cfg.MODEL.DECODER_DEPTH,
426
+ "decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
427
+ "decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
428
+ "decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
429
+ "decoder_global_attn_indexes": cfg.MODEL.DECODER_GLOBAL_ATTN_INDEXES,
430
+ }
431
+
432
+ def init_encoder(self, *,
433
+ img_size: int,
434
+ patch_size: int,
435
+ in_chans: int,
436
+ embed_layer: Callable,
437
+ dynamic_img_size: bool,
438
+ dynamic_img_pad: bool,
439
+ use_abs_pos: bool,
440
+ block_fn: nn.Module,
441
+ norm_layer: LayerType | None,
442
+ act_layer: LayerType | None,
443
+ mlp_layer: nn.Module,
444
+ embed_dim: int,
445
+ depth: int,
446
+ num_heads: int,
447
+ mlp_ratio: float,
448
+ qkv_bias: bool,
449
+ qk_norm: bool,
450
+ window_size: int,
451
+ global_attn_indexes: list,
452
+ ) -> None:
453
+ embed_args = {}
454
+ if dynamic_img_size:
455
+ # flatten deferred until after pos embed
456
+ embed_args.update(dict(strict_img_size=False))
457
+ self.patch_embed = embed_layer(
458
+ img_size=img_size,
459
+ patch_size=patch_size,
460
+ in_chans=in_chans,
461
+ embed_dim=embed_dim,
462
+ dynamic_img_pad=dynamic_img_pad,
463
+ output_fmt="NHWC",
464
+ **embed_args
465
+ )
466
+
467
+ self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
468
+ self.blocks = nn.ModuleList(
469
+ block_fn(
470
+ dim=embed_dim,
471
+ num_heads=num_heads,
472
+ mlp_ratio=mlp_ratio,
473
+ qkv_bias=qkv_bias,
474
+ qk_norm=qk_norm,
475
+ norm_layer=norm_layer,
476
+ act_layer=act_layer,
477
+ mlp_layer=mlp_layer,
478
+ use_rel_pos=True,
479
+ window_size=window_size if i not in global_attn_indexes else 0,
480
+ input_size=(img_size // patch_size, img_size // patch_size),
481
+ ) for i in range(depth)
482
+ )
483
+
484
+ self.norm = norm_layer(embed_dim)
485
+
486
+ def init_decoder(self, *,
487
+ img_size: int,
488
+ patch_size: int,
489
+ in_chans: int,
490
+ embed_dim: int,
491
+ use_abs_pos: bool,
492
+ decoder_block_fn: nn.Module,
493
+ decoder_norm_layer: LayerType | None,
494
+ decoder_act_layer: LayerType | None,
495
+ decoder_mlp_layer: nn.Module,
496
+ decoder_embed_dim: int,
497
+ decoder_depth: int,
498
+ decoder_num_heads: int,
499
+ decoder_use_neck: bool,
500
+ decoder_neck_dim: int,
501
+ mlp_ratio: float,
502
+ qkv_bias: bool,
503
+ qk_norm: bool,
504
+ window_size: int,
505
+ decoder_global_attn_indexes: list[int]
506
+ ) -> None:
507
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
508
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
509
+ self.decoder_blocks = nn.ModuleList(
510
+ decoder_block_fn(
511
+ dim=decoder_embed_dim,
512
+ num_heads=decoder_num_heads,
513
+ mlp_ratio=mlp_ratio,
514
+ qkv_bias=qkv_bias,
515
+ qk_norm=qk_norm,
516
+ norm_layer=decoder_norm_layer,
517
+ act_layer=decoder_act_layer,
518
+ mlp_layer=decoder_mlp_layer,
519
+ use_rel_pos=True,
520
+ window_size=window_size if i not in decoder_global_attn_indexes else 0,
521
+ input_size=(img_size // patch_size, img_size // patch_size),
522
+ ) for i in range(decoder_depth)
523
+ )
524
+ self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
525
+ if decoder_use_neck:
526
+ self.decoder_neck = nn.Sequential(
527
+ nn.Conv2d(
528
+ in_channels=decoder_embed_dim,
529
+ out_channels=decoder_neck_dim,
530
+ kernel_size=1,
531
+ bias=False,
532
+ ),
533
+ LayerNorm2d(decoder_neck_dim),
534
+ decoder_act_layer(),
535
+ nn.Conv2d(
536
+ in_channels=decoder_neck_dim,
537
+ out_channels=decoder_neck_dim,
538
+ kernel_size=3,
539
+ padding=1,
540
+ bias=False,
541
+ ),
542
+ LayerNorm2d(decoder_neck_dim),
543
+ decoder_act_layer(),
544
+ nn.Conv2d(
545
+ in_channels=decoder_neck_dim,
546
+ out_channels=decoder_embed_dim,
547
+ kernel_size=1,
548
+ bias=False,
549
+ ),
550
+ LayerNorm2d(decoder_embed_dim),
551
+ )
552
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
553
+
554
+ def init_weights(self, *,
555
+ grid_size: tuple[int, int],
556
+ embed_dim: int,
557
+ decoder_embed_dim: int
558
+ ) -> None:
559
+ w = self.patch_embed.proj.weight.data
560
+ torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
561
+
562
+ if self.pos_embed is not None:
563
+ self.pos_embed.data.copy_(build_sincos2d_pos_embed(
564
+ feat_shape=grid_size,
565
+ dim=embed_dim,
566
+ interleave_sin_cos=True
567
+ ).reshape(1, *grid_size, -1).transpose(1, 2))
568
+
569
+ if self.decoder_pos_embed is not None:
570
+ self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
571
+ feat_shape=grid_size,
572
+ dim=decoder_embed_dim,
573
+ interleave_sin_cos=True
574
+ ).reshape(1, *grid_size, -1).transpose(1, 2))
575
+
576
+ # Zero-initialize the neck
577
+ if self.decoder_use_neck:
578
+ for m in self.decoder_neck.modules():
579
+ if isinstance(m, nn.Conv2d):
580
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
581
+ if m.bias is not None:
582
+ nn.init.zeros_(m.bias)
583
+ nn.init.zeros_(self.decoder_neck[-1].weight)
584
+ nn.init.zeros_(self.decoder_neck[-1].bias)
585
+
586
+ self.apply(self._init_weights)
587
+
588
+ def _init_weights(self, module: nn.Module) -> None:
589
+ if isinstance(module, nn.Linear):
590
+ nn.init.xavier_uniform_(module.weight)
591
+ if module.bias is not None:
592
+ nn.init.constant_(module.bias, 0.0)
593
+
594
+ def forward_encoder(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
595
+ """
596
+ Forward pass of the encoder.
597
+
598
+ Args:
599
+ `x` (torch.Tensor): Image of shape [B, C, H, W].
600
+
601
+ Returns:
602
+ (torch.Tensor): Encoded image of shape [B, num_kept, E].
603
+ (int): Height of the encoded tokens.
604
+ (int): Width of the encoded tokens.
605
+ """
606
+ x = self.patch_embed(x)
607
+ B, H, W, E = x.shape
608
+
609
+ if self.pos_embed is not None:
610
+ x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
611
+
612
+ for block in self.blocks:
613
+ x = block(x)
614
+
615
+ x = x.view(B, -1, E)
616
+ x = self.norm(x)
617
+
618
+ return x, H, W
619
+
620
+ def forward_decoder(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
621
+ """
622
+ Forward pass of the decoder.
623
+
624
+ Args:
625
+ `x` (torch.Tensor): Encoded image of shape [B, num_kept, E].
626
+ `H` (int): Height of the encoded tokens.
627
+ `W` (int): Width of the encoded tokens.
628
+
629
+ Returns:
630
+ (torch.Tensor): Decoded image of shape [B, L, E].
631
+ """
632
+ x = self.decoder_embed(x) # [B, num_kept, E]
633
+ B, L, E = x.shape
634
+
635
+ if self.decoder_pos_embed is not None:
636
+ x = x.view(B, H, W, E)
637
+ x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
638
+
639
+ for block in self.decoder_blocks:
640
+ x = block(x)
641
+ x = x.view(B, -1, E)
642
+
643
+ x = self.decoder_norm(x)
644
+ if self.decoder_use_neck:
645
+ x = x + self.decoder_neck(
646
+ x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
647
+ ).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
648
+ x = self.decoder_pred(x)
649
+
650
+ return x
651
+
652
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
653
+ """
654
+ Args:
655
+ `x` (torch.Tensor): Image of shape [B, C, H, W].
656
+
657
+ Returns:
658
+ (torch.Tensor): The prediction of shape [B, L, E].
659
+ """
660
+ x, H, W = self.forward_encoder(x)
661
+ x = self.forward_decoder(x, H, W)
662
+ return x
663
+
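The patch bookkeeping around `DracoDenoiseAutoencoder` can be sanity-checked without instantiating the model: `decoder_pred` emits tokens of shape `[B, L, PATCH_SIZE**2 * IN_CHANS]`, which app.py's `unpatchify` folds back into an image. A self-contained sketch with the draco.yaml values (pure tensor reshaping, no weights involved):

```python
# Round-trip check of the patchify/unpatchify layout used by app.py around the model.
import torch

B, C, H, W, P = 1, 1, 1024, 1024, 32           # values from draco.yaml

x = torch.randn(B, C, H, W)
patches = x.unfold(2, P, P).unfold(3, P, P)    # [B, C, H/P, W/P, P, P]
patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(B, -1, P * P * C)
print(patches.shape)                           # torch.Size([1, 1024, 1024]) = [B, L, P*P*C]

# decoder_pred produces exactly this [B, L, P*P*C] layout; unpatchify folds it back:
images = patches.reshape(B, H // P, W // P, P, P, C).permute(0, 5, 1, 3, 2, 4)
images = images.reshape(B, C, H, W)
print(torch.equal(images, x))                  # True -- the round trip is lossless
```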
draco/model/draco_base.py ADDED
@@ -0,0 +1,35 @@
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class DenoisingReconstructionAutoencoderVisionTransformerBase(nn.Module, metaclass=ABCMeta):
8
+ def __init__(self) -> None:
9
+ super().__init__()
10
+
11
+ @torch.jit.ignore
12
+ def no_weight_decay(self) -> set:
13
+ return {"cls_token"}
14
+
15
+ @torch.jit.ignore
16
+ def group_matcher(self, coarse: bool = False) -> dict:
17
+ return dict(
18
+ stem=r'^(?:_orig_mod\.)?cls_token|^(?:_orig_mod\.)?pos_embed|^(?:_orig_mod\.)?patch_embed',
19
+ blocks=[(r'^(?:_orig_mod\.)?blocks\.(\d+)', None), (r'^(?:_orig_mod\.)?norm', (99999,))]
20
+ )
21
+
22
+ @classmethod
23
+ def random_masking(cls, x: torch.Tensor, mask_ratio: float) -> torch.BoolTensor:
24
+ B, L = x.shape[:2]
25
+ num_masked = int(L * mask_ratio)
26
+
27
+ noise = torch.rand(B, L, device=x.device)
28
+ rank = noise.argsort(dim=1)
29
+ mask = rank < num_masked
30
+
31
+ return mask
32
+
33
+ @abstractmethod
34
+ def forward(self) -> None:
35
+ raise NotImplementedError
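`random_masking` returns a boolean mask over the token axis with exactly `int(L * mask_ratio)` masked (True) entries per sample; a quick sketch of that contract:

```python
import torch

from draco.model.draco_base import DenoisingReconstructionAutoencoderVisionTransformerBase as Base

tokens = torch.randn(2, 8, 16)                         # [B, L, E]
mask = Base.random_masking(tokens, mask_ratio=0.75)    # BoolTensor, True = masked
print(mask.shape)                                      # torch.Size([2, 8])
print(mask.sum(dim=1))                                 # tensor([6, 6]) -- int(8 * 0.75) per row
```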
draco/model/layer/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .normalization import LayerNorm2d
2
+
3
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
draco/model/layer/normalization.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ __all__ = [
5
+ "LayerNorm2d",
6
+ ]
7
+
8
+
9
+ class LayerNorm2d(nn.Module):
10
+ def __init__(self, num_features: int, eps: float = 1e-6) -> None:
11
+ super().__init__()
12
+
13
+ self.weight = nn.Parameter(torch.ones(num_features))
14
+ self.bias = nn.Parameter(torch.zeros(num_features))
15
+ self.eps = eps
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ u = x.mean(1, keepdim=True)
19
+ s = (x - u).square().mean(1, keepdim=True)
20
+ x = (x - u) / torch.sqrt(s + self.eps)
21
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
22
+ return x
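`LayerNorm2d` normalizes each spatial position over the channel dimension only, i.e. it matches `F.layer_norm` applied to an NHWC view of the tensor, which is how it is used in the decoder neck. A quick numerical check (sketch):

```python
import torch
import torch.nn.functional as F

from draco.model.layer import LayerNorm2d

x = torch.randn(2, 8, 5, 5)                            # [B, C, H, W]
ln2d = LayerNorm2d(8, eps=1e-6)

ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln2d.weight, ln2d.bias, eps=1e-6)
print(torch.allclose(ln2d(x), ref.permute(0, 3, 1, 2), atol=1e-5))   # True
```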
draco/model/utils/constant.py ADDED
@@ -0,0 +1,24 @@
1
+ def get_vit_scale(scale: str) -> tuple[int, int, int]:
2
+ if scale == "tiny":
3
+ return 192, 12, 3
4
+ elif scale == "small":
5
+ return 384, 12, 6
6
+ elif scale == "base":
7
+ return 768, 12, 12
8
+ elif scale == "large":
9
+ return 1024, 24, 16
10
+ elif scale == "huge":
11
+ return 1280, 32, 16
12
+ else:
13
+ raise KeyError(f"Unknown Vision Transformer scale: {scale}")
14
+
15
+ def get_global_attn_indexes(num_layers: int) -> list[int]:
16
+ """
17
+ Args:
18
+ num_layers (int): The number of layers.
19
+
20
+ Returns:
21
+ List[int]: The global attention indexes.
22
+ """
23
+
24
+ return list(range(num_layers // 4 - 1, num_layers, num_layers // 4))
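Worked values for the two helpers, following directly from the code above:

```python
from draco.model.utils.constant import get_vit_scale, get_global_attn_indexes

print(get_vit_scale("base"))          # (768, 12, 12) -> embed_dim, depth, num_heads
print(get_global_attn_indexes(12))    # [2, 5, 8, 11] -- one global-attention block per quarter
print(get_global_attn_indexes(32))    # [7, 15, 23, 31] (ViT-huge depth)
```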
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ h5py==3.12.1
4
+ numpy==1.26.4
5
+ pandas==2.2.2
6
+ mrcfile==1.5.3
7
+ scipy==1.13.1
8
+ pycocotools==2.0.8
9
+ omegaconf==2.3.0
10
+ pillow
11
+ fvcore