Spaces:
Build error
Build error
# Lazy (detectron2 LazyConfig) definition of a ViTMatte model.
#
# Nothing is instantiated in this file: every ``L(cls)(...)`` call only records
# the target class and its keyword arguments. Module-level names exported for
# consumers of this config: ``embed_dim``, ``num_heads`` and ``model``.
from functools import partial

import torch.nn as nn
from detectron2.config import LazyCall as L

from modeling import ViTMatte, MattingCriterion, Detail_Capture, ViT

# Backbone width and attention-head count, shared with the ViT config below.
# NOTE(review): the original comment here read "Base", but 384 / 6 matches the
# sizing usually called ViT-Small -- confirm which label is intended.
embed_dim, num_heads = 384, 6

model = L(ViTMatte)(
    backbone=L(ViT)(  # Single-scale ViT backbone
        # NOTE(review): 4 input channels -- presumably RGB plus one extra
        # guidance channel (e.g. a trimap); confirm against the data pipeline.
        in_chans=4,
        img_size=512,
        patch_size=16,
        embed_dim=embed_dim,
        depth=12,
        num_heads=num_heads,
        drop_path_rate=0,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        # Blocks listed here use windowed attention; the remaining blocks
        # (2, 5, 8, 11) are left out so that they use global attention.
        window_block_indexes=[
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        # Same indices as the global-attention blocks excluded above.
        residual_block_indexes=[2, 5, 8, 11],
        use_rel_pos=True,
        out_feature="last_feat",
    ),
    criterion=L(MattingCriterion)(
        losses=[
            'unknown_l1_loss',
            'known_l1_loss',
            'loss_pha_laplacian',
            'loss_gradient_penalty',
        ],
    ),
    # Per-channel normalisation statistics, rescaled from the 0-255 range to
    # 0-1 by the division. Three entries each, in the channel order named by
    # ``input_format``.
    pixel_mean=[123.675 / 255., 116.280 / 255., 103.530 / 255.],
    pixel_std=[58.395 / 255., 57.120 / 255., 57.375 / 255.],
    input_format="RGB",
    size_divisibility=32,
    decoder=L(Detail_Capture)(),
)