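# Spaces: Runtime error
# NOTE(review): the line above appears to be a status banner captured from the
# hosting page together with the source; it is not part of the module.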
from LightHQSAM.tiny_vit_sam import TinyViT
from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
def setup_model():
    """Assemble a ``Sam`` model with a TinyViT encoder and an HQ mask decoder.

    The function only constructs the modules with fixed hyper-parameters; it
    does not load any checkpoint, move the model to a device, or change its
    training/eval mode.

    Returns:
        The newly constructed ``Sam`` instance.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # 1024 // 16 == 64: side length of the embedding grid given to the
    # prompt encoder.
    image_embedding_size = image_size // vit_patch_size

    # Sub-modules are built in the same order as the original nested call
    # (encoder, prompt encoder, transformer, decoder) so that any
    # RNG-dependent parameter initialisation is unchanged.
    image_encoder = TinyViT(
        # Use the shared local instead of repeating the literal 1024 so the
        # encoder and the prompt encoder cannot drift apart.
        img_size=image_size,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8,
    )

    prompt_encoder = PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    )

    transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    )

    mask_decoder = MaskDecoderHQ(
        num_multimask_outputs=3,
        transformer=transformer,
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        # NOTE(review): 160 equals embed_dims[2] of the encoder above;
        # presumably the decoder consumes that stage's features - confirm
        # against MaskDecoderHQ before changing either value.
        vit_dim=160,
    )

    mobile_sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        # These equal the commonly used ImageNet per-channel mean/std
        # multiplied by 255 (e.g. 0.485 * 255 == 123.675).
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return mobile_sam