Spaces: Running on L4
Commit · 5e2bf3b
1 Parent(s): 9d0b2aa
load from HF
Browse files
- hugging_face/app.py +8 -3
- matanyone/__init__.py +0 -0
- matanyone/inference/inference_core.py +1 -2
- matanyone/inference/memory_manager.py +1 -5
- matanyone/model/big_modules.py +13 -6
- matanyone/model/matanyone.py +18 -8
- matanyone/model/modules.py +3 -24
- matanyone/model/transformer/object_summarizer.py +1 -1
- matanyone/model/transformer/object_transformer.py +1 -1
- matanyone/model/utils/resnet.py +1 -1
- matanyone/utils/get_default_model.py +8 -4
hugging_face/app.py
CHANGED
@@ -416,9 +416,14 @@ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type]
|
|
416 |
model = MaskGenerator(sam_checkpoint, args)
|
417 |
|
418 |
# initialize matanyone
|
419 |
-
|
420 |
-
|
421 |
-
|
|
|
|
|
|
|
|
|
|
|
422 |
matanyone_model = matanyone_model.to(args.device).eval()
|
423 |
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
424 |
|
|
|
416 |
model = MaskGenerator(sam_checkpoint, args)
|
417 |
|
418 |
# initialize matanyone
|
419 |
+
# load from ckpt
|
420 |
+
# pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
|
421 |
+
# ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
|
422 |
+
# matanyone_model = get_matanyone_model(ckpt_path, args.device)
|
423 |
+
# load from Hugging Face
|
424 |
+
from matanyone.model.matanyone import MatAnyone
|
425 |
+
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
|
426 |
+
|
427 |
matanyone_model = matanyone_model.to(args.device).eval()
|
428 |
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
429 |
|
matanyone/__init__.py
ADDED
File without changes
|
matanyone/inference/inference_core.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import List, Optional, Iterable
|
2 |
import logging
|
3 |
from omegaconf import DictConfig
|
4 |
|
@@ -302,7 +302,6 @@ class InferenceCore:
|
|
302 |
|
303 |
mask, _ = pad_divide_by(mask, 16)
|
304 |
if need_segment:
|
305 |
-
print("HERE!!!!!!!!!!!")
|
306 |
# merge predicted mask with the incomplete input mask
|
307 |
pred_prob_no_bg = pred_prob_with_bg[1:]
|
308 |
# use the mutual exclusivity of segmentation
|
|
|
1 |
+
from typing import List, Optional, Iterable
|
2 |
import logging
|
3 |
from omegaconf import DictConfig
|
4 |
|
|
|
302 |
|
303 |
mask, _ = pad_divide_by(mask, 16)
|
304 |
if need_segment:
|
|
|
305 |
# merge predicted mask with the incomplete input mask
|
306 |
pred_prob_no_bg = pred_prob_with_bg[1:]
|
307 |
# use the mutual exclusivity of segmentation
|
matanyone/inference/memory_manager.py
CHANGED
@@ -2,12 +2,11 @@ import logging
|
|
2 |
from omegaconf import DictConfig
|
3 |
from typing import List, Dict
|
4 |
import torch
|
5 |
-
import cv2
|
6 |
|
7 |
from matanyone.inference.object_manager import ObjectManager
|
8 |
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
|
9 |
from matanyone.model.matanyone import MatAnyone
|
10 |
-
from matanyone.model.utils.memory_utils import
|
11 |
|
12 |
log = logging.getLogger()
|
13 |
|
@@ -128,8 +127,6 @@ class MemoryManager:
|
|
128 |
bs = pix_feat.shape[0]
|
129 |
assert last_mask.shape[0] == bs
|
130 |
|
131 |
-
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
|
132 |
-
|
133 |
"""
|
134 |
Compute affinity and perform readout
|
135 |
"""
|
@@ -374,7 +371,6 @@ class MemoryManager:
|
|
374 |
self.engaged = False
|
375 |
|
376 |
def compress_features(self, bucket_id: int) -> None:
|
377 |
-
HW = self.HW
|
378 |
|
379 |
# perform memory consolidation
|
380 |
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
|
|
2 |
from omegaconf import DictConfig
|
3 |
from typing import List, Dict
|
4 |
import torch
|
|
|
5 |
|
6 |
from matanyone.inference.object_manager import ObjectManager
|
7 |
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
|
8 |
from matanyone.model.matanyone import MatAnyone
|
9 |
+
from matanyone.model.utils.memory_utils import get_similarity, do_softmax
|
10 |
|
11 |
log = logging.getLogger()
|
12 |
|
|
|
127 |
bs = pix_feat.shape[0]
|
128 |
assert last_mask.shape[0] == bs
|
129 |
|
|
|
|
|
130 |
"""
|
131 |
Compute affinity and perform readout
|
132 |
"""
|
|
|
371 |
self.engaged = False
|
372 |
|
373 |
def compress_features(self, bucket_id: int) -> None:
|
|
|
374 |
|
375 |
# perform memory consolidation
|
376 |
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
matanyone/model/big_modules.py
CHANGED
@@ -8,14 +8,15 @@ g - usually denotes features that are not shared between objects
|
|
8 |
The trailing number of a variable usually denotes the stride
|
9 |
"""
|
10 |
|
|
|
11 |
from omegaconf import DictConfig
|
12 |
import torch
|
13 |
import torch.nn as nn
|
14 |
import torch.nn.functional as F
|
15 |
|
16 |
-
from matanyone.model.group_modules import
|
17 |
from matanyone.model.utils import resnet
|
18 |
-
from matanyone.model.modules import
|
19 |
|
20 |
class UncertPred(nn.Module):
|
21 |
def __init__(self, model_cfg: DictConfig):
|
@@ -51,11 +52,14 @@ class PixelEncoder(nn.Module):
|
|
51 |
super().__init__()
|
52 |
|
53 |
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
|
|
|
|
|
|
|
54 |
if self.is_resnet:
|
55 |
if model_cfg.pixel_encoder.type == 'resnet18':
|
56 |
-
network = resnet.resnet18(pretrained=
|
57 |
elif model_cfg.pixel_encoder.type == 'resnet50':
|
58 |
-
network = resnet.resnet50(pretrained=
|
59 |
else:
|
60 |
raise NotImplementedError
|
61 |
self.conv1 = network.conv1
|
@@ -127,10 +131,13 @@ class MaskEncoder(nn.Module):
|
|
127 |
self.single_object = single_object
|
128 |
extra_dim = 1 if single_object else 2
|
129 |
|
|
|
|
|
|
|
130 |
if model_cfg.mask_encoder.type == 'resnet18':
|
131 |
-
network = resnet.resnet18(pretrained=
|
132 |
elif model_cfg.mask_encoder.type == 'resnet50':
|
133 |
-
network = resnet.resnet50(pretrained=
|
134 |
else:
|
135 |
raise NotImplementedError
|
136 |
self.conv1 = network.conv1
|
|
|
8 |
The trailing number of a variable usually denotes the stride
|
9 |
"""
|
10 |
|
11 |
+
from typing import Iterable
|
12 |
from omegaconf import DictConfig
|
13 |
import torch
|
14 |
import torch.nn as nn
|
15 |
import torch.nn.functional as F
|
16 |
|
17 |
+
from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
|
18 |
from matanyone.model.utils import resnet
|
19 |
+
from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
|
20 |
|
21 |
class UncertPred(nn.Module):
|
22 |
def __init__(self, model_cfg: DictConfig):
|
|
|
52 |
super().__init__()
|
53 |
|
54 |
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
|
55 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
56 |
+
# else default to True
|
57 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
58 |
if self.is_resnet:
|
59 |
if model_cfg.pixel_encoder.type == 'resnet18':
|
60 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet)
|
61 |
elif model_cfg.pixel_encoder.type == 'resnet50':
|
62 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet)
|
63 |
else:
|
64 |
raise NotImplementedError
|
65 |
self.conv1 = network.conv1
|
|
|
131 |
self.single_object = single_object
|
132 |
extra_dim = 1 if single_object else 2
|
133 |
|
134 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
135 |
+
# else default to True
|
136 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
137 |
if model_cfg.mask_encoder.type == 'resnet18':
|
138 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
139 |
elif model_cfg.mask_encoder.type == 'resnet50':
|
140 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
141 |
else:
|
142 |
raise NotImplementedError
|
143 |
self.conv1 = network.conv1
|
matanyone/model/matanyone.py
CHANGED
@@ -1,21 +1,31 @@
|
|
1 |
-
from typing import List, Dict
|
2 |
import logging
|
3 |
from omegaconf import DictConfig
|
4 |
import torch
|
5 |
import torch.nn as nn
|
|
|
|
|
|
|
6 |
|
7 |
-
from matanyone.model.
|
8 |
-
from matanyone.model.big_modules import *
|
9 |
from matanyone.model.aux_modules import AuxComputer
|
10 |
-
from matanyone.model.utils.memory_utils import
|
11 |
from matanyone.model.transformer.object_transformer import QueryTransformer
|
12 |
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
|
13 |
from matanyone.utils.tensor_utils import aggregate
|
14 |
|
15 |
log = logging.getLogger()
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def __init__(self, cfg: DictConfig, *, single_object=False):
|
21 |
super().__init__()
|
@@ -304,7 +314,7 @@ class MatAnyone(nn.Module):
|
|
304 |
finetune a trained model with single object datasets.
|
305 |
"""
|
306 |
if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
|
307 |
-
log.warning(
|
308 |
'This is not supposed to happen in standard training.')
|
309 |
src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
|
310 |
src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
|
|
|
1 |
+
from typing import List, Dict, Iterable
|
2 |
import logging
|
3 |
from omegaconf import DictConfig
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
9 |
|
10 |
+
from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
|
|
|
11 |
from matanyone.model.aux_modules import AuxComputer
|
12 |
+
from matanyone.model.utils.memory_utils import get_affinity, readout
|
13 |
from matanyone.model.transformer.object_transformer import QueryTransformer
|
14 |
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
|
15 |
from matanyone.utils.tensor_utils import aggregate
|
16 |
|
17 |
log = logging.getLogger()
|
18 |
+
class MatAnyone(nn.Module,
|
19 |
+
PyTorchModelHubMixin,
|
20 |
+
library_name="matanyone",
|
21 |
+
repo_url="https://github.com/pq-yang/MatAnyone",
|
22 |
+
coders={
|
23 |
+
DictConfig: (
|
24 |
+
lambda x: OmegaConf.to_container(x),
|
25 |
+
lambda data: OmegaConf.create(data),
|
26 |
+
)
|
27 |
+
},
|
28 |
+
):
|
29 |
|
30 |
def __init__(self, cfg: DictConfig, *, single_object=False):
|
31 |
super().__init__()
|
|
|
314 |
finetune a trained model with single object datasets.
|
315 |
"""
|
316 |
if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
|
317 |
+
log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.'
|
318 |
'This is not supposed to happen in standard training.')
|
319 |
src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
|
320 |
src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
|
matanyone/model/modules.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
from typing import List, Iterable
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
4 |
|
5 |
-
from matanyone.model.group_modules import
|
6 |
|
7 |
|
8 |
class UpsampleBlock(nn.Module):
|
@@ -145,26 +146,4 @@ class ResBlock(nn.Module):
|
|
145 |
|
146 |
g = self.downsample(g)
|
147 |
|
148 |
-
return out_g + g
|
149 |
-
|
150 |
-
def __init__(self, in_dim, reduction_dim, bins):
|
151 |
-
super(PPM, self).__init__()
|
152 |
-
self.features = []
|
153 |
-
for bin in bins:
|
154 |
-
self.features.append(nn.Sequential(
|
155 |
-
nn.AdaptiveAvgPool2d(bin),
|
156 |
-
nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
|
157 |
-
nn.PReLU()
|
158 |
-
))
|
159 |
-
self.features = nn.ModuleList(self.features)
|
160 |
-
self.fuse = nn.Sequential(
|
161 |
-
nn.Conv2d(in_dim+reduction_dim*4, in_dim, kernel_size=3, padding=1, bias=False),
|
162 |
-
nn.PReLU())
|
163 |
-
|
164 |
-
def forward(self, x):
|
165 |
-
x_size = x.size()
|
166 |
-
out = [x]
|
167 |
-
for f in self.features:
|
168 |
-
out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
|
169 |
-
out_feat = self.fuse(torch.cat(out, 1))
|
170 |
-
return out_feat
|
|
|
1 |
from typing import List, Iterable
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
|
6 |
+
from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
|
7 |
|
8 |
|
9 |
class UpsampleBlock(nn.Module):
|
|
|
146 |
|
147 |
g = self.downsample(g)
|
148 |
|
149 |
+
return out_g + g
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
matanyone/model/transformer/object_summarizer.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import
|
2 |
from omegaconf import DictConfig
|
3 |
|
4 |
import torch
|
|
|
1 |
+
from typing import Optional
|
2 |
from omegaconf import DictConfig
|
3 |
|
4 |
import torch
|
matanyone/model/transformer/object_transformer.py
CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
6 |
from matanyone.model.group_modules import GConv2d
|
7 |
from matanyone.utils.tensor_utils import aggregate
|
8 |
from matanyone.model.transformer.positional_encoding import PositionalEncoding
|
9 |
-
from matanyone.model.transformer.transformer_layers import
|
10 |
|
11 |
|
12 |
class QueryTransformerBlock(nn.Module):
|
|
|
6 |
from matanyone.model.group_modules import GConv2d
|
7 |
from matanyone.utils.tensor_utils import aggregate
|
8 |
from matanyone.model.transformer.positional_encoding import PositionalEncoding
|
9 |
+
from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
|
10 |
|
11 |
|
12 |
class QueryTransformerBlock(nn.Module):
|
matanyone/model/utils/resnet.py
CHANGED
@@ -15,7 +15,7 @@ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
|
|
15 |
new_dict = OrderedDict()
|
16 |
|
17 |
for k1, v1 in target.state_dict().items():
|
18 |
-
if
|
19 |
if k1 in source_state:
|
20 |
tar_v = source_state[k1]
|
21 |
|
|
|
15 |
new_dict = OrderedDict()
|
16 |
|
17 |
for k1, v1 in target.state_dict().items():
|
18 |
+
if 'num_batches_tracked' not in k1:
|
19 |
if k1 in source_state:
|
20 |
tar_v = source_state[k1]
|
21 |
|
matanyone/utils/get_default_model.py
CHANGED
@@ -6,9 +6,8 @@ from hydra import compose, initialize
|
|
6 |
|
7 |
import torch
|
8 |
from matanyone.model.matanyone import MatAnyone
|
9 |
-
from matanyone.inference.utils.args_utils import get_dataset_cfg
|
10 |
|
11 |
-
def get_matanyone_model(ckpt_path, device) -> MatAnyone:
|
12 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
13 |
cfg = compose(config_name="eval_matanyone_config")
|
14 |
|
@@ -16,8 +15,13 @@ def get_matanyone_model(ckpt_path, device) -> MatAnyone:
|
|
16 |
cfg['weights'] = ckpt_path
|
17 |
|
18 |
# Load the network weights
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
21 |
matanyone.load_weights(model_weights)
|
22 |
|
23 |
return matanyone
|
|
|
6 |
|
7 |
import torch
|
8 |
from matanyone.model.matanyone import MatAnyone
|
|
|
9 |
|
10 |
+
def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
|
11 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
12 |
cfg = compose(config_name="eval_matanyone_config")
|
13 |
|
|
|
15 |
cfg['weights'] = ckpt_path
|
16 |
|
17 |
# Load the network weights
|
18 |
+
if device is not None:
|
19 |
+
matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
|
20 |
+
model_weights = torch.load(cfg.weights, map_location=device)
|
21 |
+
else: # if device is not specified, `.cuda()` by default
|
22 |
+
matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
|
23 |
+
model_weights = torch.load(cfg.weights)
|
24 |
+
|
25 |
matanyone.load_weights(model_weights)
|
26 |
|
27 |
return matanyone
|