Last commit not found
import os | |
from pathlib import Path | |
import yaml | |
from timm.models.helpers import load_pretrained, load_custom_pretrained | |
from timm.models.registry import register_model | |
from timm.models.vision_transformer import _create_vision_transformer | |
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn | |
import segmenter_model.torch as ptu | |
import torch | |
from segmenter_model.decoder import MaskTransformer | |
from segmenter_model.segmenter import Segmenter | |
from segmenter_model.vit_dino import vit_small, VisionTransformer | |
def vit_base_patch8_384(pretrained=False, **kwargs):
    """Build a ViT-Base model with 8x8 patches and a 384x384 default input.

    The architecture arguments (patch size 8, embedding dim 768, depth 12,
    12 attention heads) are combined with ``kwargs`` and forwarded to timm's
    ``_create_vision_transformer`` together with an inline default config.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
            Note that the inline default config has an empty ``url``.
        **kwargs: Extra keyword arguments for the model constructor. Passing
            one of the fixed architecture arguments again raises ``TypeError``.

    Returns:
        The model returned by ``_create_vision_transformer``.
    """
    # Built with dict(...) on purpose: a duplicate key coming from ``kwargs``
    # is rejected instead of silently overriding the architecture.
    arch_kwargs = dict(
        patch_size=8,
        embed_dim=768,
        depth=12,
        num_heads=12,
        **kwargs,
    )
    cfg = dict(
        url="",
        input_size=(3, 384, 384),
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_classes=1000,
    )
    return _create_vision_transformer(
        "vit_base_patch8_384",
        pretrained=pretrained,
        default_cfg=cfg,
        **arch_kwargs,
    )
def create_vit(model_cfg):
    """Build the ViT encoder described by ``model_cfg``.

    ``model_cfg`` is shallow-copied first, so keys popped or added here do not
    show up in the caller's dict.

    Two families of backbones are handled:

    * names containing ``"dino"``: ``"dino_vits16"`` (case-insensitive) is
      built locally with ``vit_small`` and initialised from the checkpoint
      file ``dino_deitsmall16_pretrain.pth`` in the current working directory
      (downloaded on first use); any other DINO name is fetched through
      ``torch.hub`` from ``facebookresearch/dino:main``. The resulting model
      gets ``d_model``, ``patch_size`` and ``distilled`` attributes and its
      ``forward`` is replaced by a wrapper around
      ``get_intermediate_layers(x, n=1)[0]``.
    * every other name: a ``VisionTransformer`` is built from ``model_cfg``
      and weights are loaded from a hard-coded local file
      (``"vit_base_patch8_384"``), via ``load_pretrained`` (names containing
      ``"deit"``) or via ``load_custom_pretrained`` (everything else).

    Args:
        model_cfg: Mapping with at least ``"backbone"``. ``"dino_vits16"``
            additionally needs ``"dropout"``; non-DINO backbones need
            ``"normalization"``, ``"d_model"`` and ``"image_size"``. The
            remaining entries are passed to the model constructor as keyword
            arguments. An optional ``"pretrained_weights"`` entry is removed
            and ignored.

    Returns:
        The constructed encoder with its pretrained weights loaded.

    Raises:
        KeyError: If a key required by the selected branch is missing.
    """
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    # The value is not used (the DINO checkpoint name is hard-coded below),
    # but the key must not be forwarded to the model constructors.
    model_cfg.pop("pretrained_weights", None)

    if "dino" in backbone:
        if backbone.lower() == "dino_vits16":
            # vit_small is given the rate as ``drop_rate``; the original
            # ``dropout`` entry is left in place and forwarded as well.
            model_cfg["drop_rate"] = model_cfg["dropout"]
            model = vit_small(**model_cfg)
            # Hard-coded checkpoint name, resolved against the current
            # working directory and downloaded there when absent.
            pretrained_weights = "dino_deitsmall16_pretrain.pth"
            if not os.path.exists(pretrained_weights):
                import urllib.request

                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
                    pretrained_weights,
                )
            # Deserialize onto the CPU so loading does not depend on the
            # device the checkpoint was saved from; load_state_dict copies
            # the values into the model's existing parameters.
            state_dict = torch.load(pretrained_weights, map_location="cpu")
            model.load_state_dict(state_dict, strict=True)
        else:
            model = torch.hub.load("facebookresearch/dino:main", backbone)

        # Attributes read elsewhere in this file (create_decoder uses
        # ``d_model`` and ``patch_size``).
        model.d_model = model.num_features
        model.patch_size = model.patch_embed.patch_size
        model.distilled = False
        # ``return_features`` is accepted but ignored by this wrapper.
        model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0]
    else:
        # Only removed so it is not forwarded to VisionTransformer; a missing
        # key still raises KeyError.
        model_cfg.pop("normalization")
        model_cfg["n_cls"] = 1000
        mlp_expansion_ratio = 4
        model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

        if backbone in default_cfgs:
            # Work on a copy: ``default_cfgs`` is timm's shared module-level
            # registry and must not be modified by the assignment below.
            default_cfg = dict(default_cfgs[backbone])
        else:
            default_cfg = dict(
                pretrained=False,
                num_classes=1000,
                drop_rate=0.0,
                drop_path_rate=0.0,
                drop_block_rate=None,
            )
        default_cfg["input_size"] = (
            3,
            model_cfg["image_size"][0],
            model_cfg["image_size"][1],
        )

        model = VisionTransformer(**model_cfg)
        if backbone == "vit_base_patch8_384":
            # NOTE(review): absolute, machine-specific path; this backbone
            # cannot be created on a machine without that file.
            path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth")
            state_dict = torch.load(path, map_location="cpu")
            filtered_dict = checkpoint_filter_fn(state_dict, model)
            model.load_state_dict(filtered_dict, strict=True)
        elif "deit" in backbone:
            load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
        else:
            load_custom_pretrained(model, default_cfg)

    return model
def create_decoder(encoder, decoder_cfg):
    """Instantiate the decoder selected by ``decoder_cfg["name"]``.

    Works on a shallow copy of ``decoder_cfg``; the caller's dict is left
    untouched. The encoder's ``d_model`` and ``patch_size`` are injected into
    the config as ``d_encoder`` and ``patch_size`` before dispatching.

    Args:
        encoder: Object exposing ``d_model`` and ``patch_size`` attributes.
        decoder_cfg: Mapping with a ``"name"`` entry. The other entries are
            forwarded as keyword arguments to the linear and mask-transformer
            decoders; the DeepLab branch reads only ``"n_cls"``.

    Returns:
        The constructed decoder.

    Raises:
        KeyError: If ``"name"`` is missing (or ``"n_cls"`` for DeepLab).
        ValueError: If the name matches none of the known decoders.
    """
    decoder_cfg = decoder_cfg.copy()
    name = decoder_cfg.pop("name")
    width = encoder.d_model
    decoder_cfg["d_encoder"] = width
    decoder_cfg["patch_size"] = encoder.patch_size

    # NOTE(review): DecoderLinear and DeepLabHead are not among the imports
    # visible at the top of this file - confirm they are in scope, otherwise
    # the two branches that use them fail with NameError.
    if "linear" in name:
        return DecoderLinear(**decoder_cfg)

    if name == "mask_transformer":
        # One attention head per 64 channels, feed-forward width 4x the model.
        decoder_cfg["n_heads"] = width // 64
        decoder_cfg["d_model"] = width
        decoder_cfg["d_ff"] = 4 * width
        return MaskTransformer(**decoder_cfg)

    if "deeplab" in name:
        return DeepLabHead(
            in_channels=width,
            num_classes=decoder_cfg["n_cls"],
            patch_size=decoder_cfg["patch_size"],
        )

    raise ValueError(f"Unknown decoder: {name}")
def create_segmenter(model_cfg):
    """Assemble a ``Segmenter`` (encoder + decoder) from ``model_cfg``.

    Neither ``model_cfg`` nor its nested ``"decoder"`` dict is modified: both
    are copied before ``n_cls`` is written into the decoder config.

    Loading weights for the complete segmenter is intentionally not supported
    here. A non-``None`` ``"weights_path"`` entry is rejected up front, before
    any model is built, so no encoder weights are downloaded or loaded first.

    Args:
        model_cfg: Mapping with ``"decoder"`` (the decoder config),
            ``"n_cls"`` and the encoder settings consumed by ``create_vit``.
            An optional ``"weights_path"`` entry must be absent or ``None``.

    Returns:
        The ``Segmenter`` wrapping the created encoder and decoder.

    Raises:
        KeyError: If ``"decoder"`` or ``"n_cls"`` is missing.
        Exception: If ``"weights_path"`` is present and not ``None``.
    """
    model_cfg = model_cfg.copy()
    # ``model_cfg.copy()`` is shallow, so the nested decoder dict is still the
    # caller's object; copy it before adding ``n_cls``.
    decoder_cfg = model_cfg.pop("decoder").copy()
    decoder_cfg["n_cls"] = model_cfg["n_cls"]

    weights_path = model_cfg.pop("weights_path", None)
    if weights_path is not None:
        raise Exception('Wants to load weights to the complete segmenter insice create_segmenter method!')

    encoder = create_vit(model_cfg)
    decoder = create_decoder(encoder, decoder_cfg)
    return Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])
def load_model(model_path, decoder_only=False, variant_path=None):
    """Rebuild a segmenter from its variant file and load a checkpoint into it.

    Args:
        model_path: Path of the checkpoint; the loaded object must hold the
            state dict under the ``"model"`` key.
        decoder_only: When true, the state dict is loaded into
            ``model.decoder`` only; otherwise into the whole model. Loading is
            strict in both cases.
        variant_path: Path of the YAML variant file. Defaults to
            ``variant.yml`` in the directory of ``model_path``.

    Returns:
        A ``(model, variant)`` tuple: the segmenter with loaded weights and
        the parsed variant mapping.
    """
    if variant_path is None:
        variant_path = Path(model_path).parent / "variant.yml"

    with open(variant_path, "r") as variant_file:
        # NOTE(review): yaml.load with FullLoader accepts more than
        # yaml.safe_load does; only point this at variant files from a
        # trusted source.
        variant = yaml.load(variant_file, Loader=yaml.FullLoader)

    model = create_segmenter(variant["net_kwargs"])

    checkpoint = torch.load(model_path, map_location=ptu.device)["model"]
    target = model.decoder if decoder_only else model
    target.load_state_dict(checkpoint, strict=True)
    return model, variant