PeiqingYang commited on
Commit
5e2bf3b
·
1 Parent(s): 9d0b2aa

load from HF

Browse files
hugging_face/app.py CHANGED
@@ -416,9 +416,14 @@ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type]
416
  model = MaskGenerator(sam_checkpoint, args)
417
 
418
  # initialize matanyone
419
- pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
420
- ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
421
- matanyone_model = get_matanyone_model(ckpt_path, args.device)
 
 
 
 
 
422
  matanyone_model = matanyone_model.to(args.device).eval()
423
  matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
424
 
 
416
  model = MaskGenerator(sam_checkpoint, args)
417
 
418
  # initialize matanyone
419
+ # load from ckpt
420
+ # pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
421
+ # ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
422
+ # matanyone_model = get_matanyone_model(ckpt_path, args.device)
423
+ # load from Hugging Face
424
+ from matanyone.model.matanyone import MatAnyone
425
+ matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
426
+
427
  matanyone_model = matanyone_model.to(args.device).eval()
428
  matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
429
 
matanyone/__init__.py ADDED
File without changes
matanyone/inference/inference_core.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Iterable, Dict
2
  import logging
3
  from omegaconf import DictConfig
4
 
@@ -302,7 +302,6 @@ class InferenceCore:
302
 
303
  mask, _ = pad_divide_by(mask, 16)
304
  if need_segment:
305
- print("HERE!!!!!!!!!!!")
306
  # merge predicted mask with the incomplete input mask
307
  pred_prob_no_bg = pred_prob_with_bg[1:]
308
  # use the mutual exclusivity of segmentation
 
1
+ from typing import List, Optional, Iterable
2
  import logging
3
  from omegaconf import DictConfig
4
 
 
302
 
303
  mask, _ = pad_divide_by(mask, 16)
304
  if need_segment:
 
305
  # merge predicted mask with the incomplete input mask
306
  pred_prob_no_bg = pred_prob_with_bg[1:]
307
  # use the mutual exclusivity of segmentation
matanyone/inference/memory_manager.py CHANGED
@@ -2,12 +2,11 @@ import logging
2
  from omegaconf import DictConfig
3
  from typing import List, Dict
4
  import torch
5
- import cv2
6
 
7
  from matanyone.inference.object_manager import ObjectManager
8
  from matanyone.inference.kv_memory_store import KeyValueMemoryStore
9
  from matanyone.model.matanyone import MatAnyone
10
- from matanyone.model.utils.memory_utils import *
11
 
12
  log = logging.getLogger()
13
 
@@ -128,8 +127,6 @@ class MemoryManager:
128
  bs = pix_feat.shape[0]
129
  assert last_mask.shape[0] == bs
130
 
131
- uncert_mask = uncert_output["mask"] if uncert_output is not None else None
132
-
133
  """
134
  Compute affinity and perform readout
135
  """
@@ -374,7 +371,6 @@ class MemoryManager:
374
  self.engaged = False
375
 
376
  def compress_features(self, bucket_id: int) -> None:
377
- HW = self.HW
378
 
379
  # perform memory consolidation
380
  prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
 
2
  from omegaconf import DictConfig
3
  from typing import List, Dict
4
  import torch
 
5
 
6
  from matanyone.inference.object_manager import ObjectManager
7
  from matanyone.inference.kv_memory_store import KeyValueMemoryStore
8
  from matanyone.model.matanyone import MatAnyone
9
+ from matanyone.model.utils.memory_utils import get_similarity, do_softmax
10
 
11
  log = logging.getLogger()
12
 
 
127
  bs = pix_feat.shape[0]
128
  assert last_mask.shape[0] == bs
129
 
 
 
130
  """
131
  Compute affinity and perform readout
132
  """
 
371
  self.engaged = False
372
 
373
  def compress_features(self, bucket_id: int) -> None:
 
374
 
375
  # perform memory consolidation
376
  prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
matanyone/model/big_modules.py CHANGED
@@ -8,14 +8,15 @@ g - usually denotes features that are not shared between objects
8
  The trailing number of a variable usually denotes the stride
9
  """
10
 
 
11
  from omegaconf import DictConfig
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
 
16
- from matanyone.model.group_modules import *
17
  from matanyone.model.utils import resnet
18
- from matanyone.model.modules import *
19
 
20
  class UncertPred(nn.Module):
21
  def __init__(self, model_cfg: DictConfig):
@@ -51,11 +52,14 @@ class PixelEncoder(nn.Module):
51
  super().__init__()
52
 
53
  self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
 
 
 
54
  if self.is_resnet:
55
  if model_cfg.pixel_encoder.type == 'resnet18':
56
- network = resnet.resnet18(pretrained=True)
57
  elif model_cfg.pixel_encoder.type == 'resnet50':
58
- network = resnet.resnet50(pretrained=True)
59
  else:
60
  raise NotImplementedError
61
  self.conv1 = network.conv1
@@ -127,10 +131,13 @@ class MaskEncoder(nn.Module):
127
  self.single_object = single_object
128
  extra_dim = 1 if single_object else 2
129
 
 
 
 
130
  if model_cfg.mask_encoder.type == 'resnet18':
131
- network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
132
  elif model_cfg.mask_encoder.type == 'resnet50':
133
- network = resnet.resnet50(pretrained=True, extra_dim=extra_dim)
134
  else:
135
  raise NotImplementedError
136
  self.conv1 = network.conv1
 
8
  The trailing number of a variable usually denotes the stride
9
  """
10
 
11
+ from typing import Iterable
12
  from omegaconf import DictConfig
13
  import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
+ from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18
  from matanyone.model.utils import resnet
19
+ from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
20
 
21
  class UncertPred(nn.Module):
22
  def __init__(self, model_cfg: DictConfig):
 
52
  super().__init__()
53
 
54
  self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
55
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
56
+ # else default to True
57
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
58
  if self.is_resnet:
59
  if model_cfg.pixel_encoder.type == 'resnet18':
60
+ network = resnet.resnet18(pretrained=is_pretrained_resnet)
61
  elif model_cfg.pixel_encoder.type == 'resnet50':
62
+ network = resnet.resnet50(pretrained=is_pretrained_resnet)
63
  else:
64
  raise NotImplementedError
65
  self.conv1 = network.conv1
 
131
  self.single_object = single_object
132
  extra_dim = 1 if single_object else 2
133
 
134
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
135
+ # else default to True
136
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
137
  if model_cfg.mask_encoder.type == 'resnet18':
138
+ network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
139
  elif model_cfg.mask_encoder.type == 'resnet50':
140
+ network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
141
  else:
142
  raise NotImplementedError
143
  self.conv1 = network.conv1
matanyone/model/matanyone.py CHANGED
@@ -1,21 +1,31 @@
1
- from typing import List, Dict
2
  import logging
3
  from omegaconf import DictConfig
4
  import torch
5
  import torch.nn as nn
 
 
 
6
 
7
- from matanyone.model.modules import *
8
- from matanyone.model.big_modules import *
9
  from matanyone.model.aux_modules import AuxComputer
10
- from matanyone.model.utils.memory_utils import *
11
  from matanyone.model.transformer.object_transformer import QueryTransformer
12
  from matanyone.model.transformer.object_summarizer import ObjectSummarizer
13
  from matanyone.utils.tensor_utils import aggregate
14
 
15
  log = logging.getLogger()
16
-
17
-
18
- class MatAnyone(nn.Module):
 
 
 
 
 
 
 
 
19
 
20
  def __init__(self, cfg: DictConfig, *, single_object=False):
21
  super().__init__()
@@ -304,7 +314,7 @@ class MatAnyone(nn.Module):
304
  finetune a trained model with single object datasets.
305
  """
306
  if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
307
- log.warning(f'Converting mask_encoder.conv1.weight from multiple objects to single object.'
308
  'This is not supposed to happen in standard training.')
309
  src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
310
  src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
 
1
+ from typing import List, Dict, Iterable
2
  import logging
3
  from omegaconf import DictConfig
4
  import torch
5
  import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from omegaconf import OmegaConf
8
+ from huggingface_hub import PyTorchModelHubMixin
9
 
10
+ from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
 
11
  from matanyone.model.aux_modules import AuxComputer
12
+ from matanyone.model.utils.memory_utils import get_affinity, readout
13
  from matanyone.model.transformer.object_transformer import QueryTransformer
14
  from matanyone.model.transformer.object_summarizer import ObjectSummarizer
15
  from matanyone.utils.tensor_utils import aggregate
16
 
17
  log = logging.getLogger()
18
+ class MatAnyone(nn.Module,
19
+ PyTorchModelHubMixin,
20
+ library_name="matanyone",
21
+ repo_url="https://github.com/pq-yang/MatAnyone",
22
+ coders={
23
+ DictConfig: (
24
+ lambda x: OmegaConf.to_container(x),
25
+ lambda data: OmegaConf.create(data),
26
+ )
27
+ },
28
+ ):
29
 
30
  def __init__(self, cfg: DictConfig, *, single_object=False):
31
  super().__init__()
 
314
  finetune a trained model with single object datasets.
315
  """
316
  if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
317
+ log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.'
318
  'This is not supposed to happen in standard training.')
319
  src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
320
  src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
matanyone/model/modules.py CHANGED
@@ -1,8 +1,9 @@
1
  from typing import List, Iterable
2
  import torch
3
  import torch.nn as nn
 
4
 
5
- from matanyone.model.group_modules import *
6
 
7
 
8
  class UpsampleBlock(nn.Module):
@@ -145,26 +146,4 @@ class ResBlock(nn.Module):
145
 
146
  g = self.downsample(g)
147
 
148
- return out_g + g
149
-
150
- def __init__(self, in_dim, reduction_dim, bins):
151
- super(PPM, self).__init__()
152
- self.features = []
153
- for bin in bins:
154
- self.features.append(nn.Sequential(
155
- nn.AdaptiveAvgPool2d(bin),
156
- nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
157
- nn.PReLU()
158
- ))
159
- self.features = nn.ModuleList(self.features)
160
- self.fuse = nn.Sequential(
161
- nn.Conv2d(in_dim+reduction_dim*4, in_dim, kernel_size=3, padding=1, bias=False),
162
- nn.PReLU())
163
-
164
- def forward(self, x):
165
- x_size = x.size()
166
- out = [x]
167
- for f in self.features:
168
- out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
169
- out_feat = self.fuse(torch.cat(out, 1))
170
- return out_feat
 
1
  from typing import List, Iterable
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
 
6
+ from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
7
 
8
 
9
  class UpsampleBlock(nn.Module):
 
146
 
147
  g = self.downsample(g)
148
 
149
+ return out_g + g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
matanyone/model/transformer/object_summarizer.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Dict, Optional
2
  from omegaconf import DictConfig
3
 
4
  import torch
 
1
+ from typing import Optional
2
  from omegaconf import DictConfig
3
 
4
  import torch
matanyone/model/transformer/object_transformer.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
6
  from matanyone.model.group_modules import GConv2d
7
  from matanyone.utils.tensor_utils import aggregate
8
  from matanyone.model.transformer.positional_encoding import PositionalEncoding
9
- from matanyone.model.transformer.transformer_layers import *
10
 
11
 
12
  class QueryTransformerBlock(nn.Module):
 
6
  from matanyone.model.group_modules import GConv2d
7
  from matanyone.utils.tensor_utils import aggregate
8
  from matanyone.model.transformer.positional_encoding import PositionalEncoding
9
+ from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10
 
11
 
12
  class QueryTransformerBlock(nn.Module):
matanyone/model/utils/resnet.py CHANGED
@@ -15,7 +15,7 @@ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15
  new_dict = OrderedDict()
16
 
17
  for k1, v1 in target.state_dict().items():
18
- if not 'num_batches_tracked' in k1:
19
  if k1 in source_state:
20
  tar_v = source_state[k1]
21
 
 
15
  new_dict = OrderedDict()
16
 
17
  for k1, v1 in target.state_dict().items():
18
+ if 'num_batches_tracked' not in k1:
19
  if k1 in source_state:
20
  tar_v = source_state[k1]
21
 
matanyone/utils/get_default_model.py CHANGED
@@ -6,9 +6,8 @@ from hydra import compose, initialize
6
 
7
  import torch
8
  from matanyone.model.matanyone import MatAnyone
9
- from matanyone.inference.utils.args_utils import get_dataset_cfg
10
 
11
- def get_matanyone_model(ckpt_path, device) -> MatAnyone:
12
  initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
13
  cfg = compose(config_name="eval_matanyone_config")
14
 
@@ -16,8 +15,13 @@ def get_matanyone_model(ckpt_path, device) -> MatAnyone:
16
  cfg['weights'] = ckpt_path
17
 
18
  # Load the network weights
19
- matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
20
- model_weights = torch.load(cfg.weights, map_location=device)
 
 
 
 
 
21
  matanyone.load_weights(model_weights)
22
 
23
  return matanyone
 
6
 
7
  import torch
8
  from matanyone.model.matanyone import MatAnyone
 
9
 
10
+ def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
11
  initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
12
  cfg = compose(config_name="eval_matanyone_config")
13
 
 
15
  cfg['weights'] = ckpt_path
16
 
17
  # Load the network weights
18
+ if device is not None:
19
+ matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
20
+ model_weights = torch.load(cfg.weights, map_location=device)
21
+ else: # if device is not specified, `.cuda()` by default
22
+ matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
23
+ model_weights = torch.load(cfg.weights)
24
+
25
  matanyone.load_weights(model_weights)
26
 
27
  return matanyone