|
import os |
|
from pathlib import Path |
|
|
|
import yaml |
|
from timm.models.helpers import load_pretrained, load_custom_pretrained |
|
from timm.models.registry import register_model |
|
from timm.models.vision_transformer import _create_vision_transformer |
|
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn |
|
|
|
import segmenter_model.torch as ptu |
|
import torch |
|
from segmenter_model.decoder import MaskTransformer |
|
from segmenter_model.segmenter import Segmenter |
|
from segmenter_model.vit_dino import vit_small, VisionTransformer |
|
|
|
|
|
@register_model |
|
def vit_base_patch8_384(pretrained=False, **kwargs): |
|
"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). |
|
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. |
|
""" |
|
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) |
|
model = _create_vision_transformer( |
|
"vit_base_patch8_384", |
|
pretrained=pretrained, |
|
default_cfg=dict( |
|
url="", |
|
input_size=(3, 384, 384), |
|
mean=(0.5, 0.5, 0.5), |
|
std=(0.5, 0.5, 0.5), |
|
num_classes=1000, |
|
), |
|
**model_kwargs, |
|
) |
|
return model |
|
|
|
|
|
def create_vit(model_cfg): |
|
model_cfg = model_cfg.copy() |
|
backbone = model_cfg.pop("backbone") |
|
if 'pretrained_weights' in model_cfg: |
|
pretrained_weights = model_cfg.pop('pretrained_weights') |
|
|
|
if 'dino' in backbone: |
|
if backbone.lower() == 'dino_vits16': |
|
model_cfg['drop_rate'] = model_cfg['dropout'] |
|
model = vit_small(**model_cfg) |
|
|
|
pretrained_weights = 'dino_deitsmall16_pretrain.pth' |
|
if not os.path.exists(pretrained_weights): |
|
import urllib.request |
|
urllib.request.urlretrieve( |
|
"https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", |
|
pretrained_weights) |
|
model.load_state_dict(torch.load(pretrained_weights), strict=True) |
|
else: |
|
model = torch.hub.load('facebookresearch/dino:main', backbone) |
|
setattr(model, 'd_model', model.num_features) |
|
setattr(model, 'patch_size', model.patch_embed.patch_size) |
|
setattr(model, 'distilled', False) |
|
model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0] |
|
else: |
|
normalization = model_cfg.pop("normalization") |
|
model_cfg["n_cls"] = 1000 |
|
mlp_expansion_ratio = 4 |
|
model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"] |
|
|
|
if backbone in default_cfgs: |
|
default_cfg = default_cfgs[backbone] |
|
else: |
|
default_cfg = dict( |
|
pretrained=False, |
|
num_classes=1000, |
|
drop_rate=0.0, |
|
drop_path_rate=0.0, |
|
drop_block_rate=None, |
|
) |
|
|
|
default_cfg["input_size"] = ( |
|
3, |
|
model_cfg["image_size"][0], |
|
model_cfg["image_size"][1], |
|
) |
|
model = VisionTransformer(**model_cfg) |
|
if backbone == "vit_base_patch8_384": |
|
path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth") |
|
state_dict = torch.load(path, map_location="cpu") |
|
filtered_dict = checkpoint_filter_fn(state_dict, model) |
|
model.load_state_dict(filtered_dict, strict=True) |
|
elif "deit" in backbone: |
|
load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn) |
|
else: |
|
load_custom_pretrained(model, default_cfg) |
|
|
|
return model |
|
|
|
|
|
def create_decoder(encoder, decoder_cfg): |
|
decoder_cfg = decoder_cfg.copy() |
|
name = decoder_cfg.pop("name") |
|
decoder_cfg["d_encoder"] = encoder.d_model |
|
decoder_cfg["patch_size"] = encoder.patch_size |
|
|
|
if "linear" in name: |
|
decoder = DecoderLinear(**decoder_cfg) |
|
elif name == "mask_transformer": |
|
dim = encoder.d_model |
|
n_heads = dim // 64 |
|
decoder_cfg["n_heads"] = n_heads |
|
decoder_cfg["d_model"] = dim |
|
decoder_cfg["d_ff"] = 4 * dim |
|
decoder = MaskTransformer(**decoder_cfg) |
|
elif 'deeplab' in name: |
|
decoder = DeepLabHead(in_channels=encoder.d_model, num_classes=decoder_cfg["n_cls"], |
|
patch_size=decoder_cfg["patch_size"]) |
|
else: |
|
raise ValueError(f"Unknown decoder: {name}") |
|
return decoder |
|
|
|
|
|
def create_segmenter(model_cfg): |
|
model_cfg = model_cfg.copy() |
|
decoder_cfg = model_cfg.pop("decoder") |
|
decoder_cfg["n_cls"] = model_cfg["n_cls"] |
|
|
|
if 'weights_path' in model_cfg.keys(): |
|
weights_path = model_cfg.pop('weights_path') |
|
else: |
|
weights_path = None |
|
|
|
encoder = create_vit(model_cfg) |
|
decoder = create_decoder(encoder, decoder_cfg) |
|
model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"]) |
|
|
|
if weights_path is not None: |
|
raise Exception('Wants to load weights to the complete segmenter insice create_segmenter method!') |
|
state_dict = torch.load(weights_path, map_location="cpu") |
|
if 'model' in state_dict: |
|
state_dict = state_dict['model'] |
|
msg = model.load_state_dict(state_dict, strict=False) |
|
print(msg) |
|
|
|
return model |
|
|
|
|
|
def load_model(model_path, decoder_only=False, variant_path=None): |
|
variant_path = Path(model_path).parent / "variant.yml" if variant_path is None else variant_path |
|
with open(variant_path, "r") as f: |
|
variant = yaml.load(f, Loader=yaml.FullLoader) |
|
net_kwargs = variant["net_kwargs"] |
|
|
|
model = create_segmenter(net_kwargs) |
|
data = torch.load(model_path, map_location=ptu.device) |
|
checkpoint = data["model"] |
|
|
|
if decoder_only: |
|
model.decoder.load_state_dict(checkpoint, strict=True) |
|
else: |
|
model.load_state_dict(checkpoint, strict=True) |
|
|
|
return model, variant |
|
|