winfred2027 commited on
Commit
89a587b
·
verified ·
1 Parent(s): cffb9c2

Update openshape/__init__.py

Browse files
Files changed (1) hide show
  1. openshape/__init__.py +15 -3
openshape/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
4
  from .ppat_rgb import Projected, PointPatchTransformer
 
5
 
6
 
7
  def module(state_dict: dict, name):
@@ -32,7 +33,7 @@ def B32(s):
32
  return model
33
 
34
 
35
- def G14_shapenet_mix(s):
36
  model = Projected(
37
  PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
38
  nn.Linear(512, 1280)
@@ -40,11 +41,17 @@ def G14_shapenet_mix(s):
40
  model.load_state_dict(module(s['state_dict'], 'module'))
41
  return model
42
 
 
 
 
 
 
43
  model_list = {
44
  "openshape-pointbert-vitb32-rgb": B32,
45
  "openshape-pointbert-vitl14-rgb": L14,
46
  "openshape-pointbert-vitg14-rgb": G14,
47
- "tripletmix-pointbert-shapenet": G14_shapenet_mix,
 
48
  }
49
 
50
 
@@ -57,8 +64,13 @@ def load_pc_encoder(name):
57
 
58
 
59
  def load_pc_encoder_mix(name):
 
60
  s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
61
  model = model_list[name](s).eval()
 
 
 
 
62
  if torch.cuda.is_available():
63
  model.cuda()
64
- return model
 
2
  import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
4
  from .ppat_rgb import Projected, PointPatchTransformer
5
+ from .mlp import MLP
6
 
7
 
8
  def module(state_dict: dict, name):
 
33
  return model
34
 
35
 
36
+ def G14_M(s):
37
  model = Projected(
38
  PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
39
  nn.Linear(512, 1280)
 
41
  model.load_state_dict(module(s['state_dict'], 'module'))
42
  return model
43
 
44
+ def PCA(s):
45
+ model = MLP(in_out_features=1280)
46
+ model.load_state_dict(module(s['pc_augment_adapter'], 'module'))
47
+ return model
48
+
49
  model_list = {
50
  "openshape-pointbert-vitb32-rgb": B32,
51
  "openshape-pointbert-vitl14-rgb": L14,
52
  "openshape-pointbert-vitg14-rgb": G14,
53
+ "tripletmix-pointbert-shapenet": G14_M,
54
+ "pc_adapter": PCA,
55
  }
56
 
57
 
 
64
 
65
 
66
  def load_pc_encoder_mix(name):
67
+ pc_adapter = None
68
  s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
69
  model = model_list[name](s).eval()
70
+ if name == "tripletmix-pointbert-shapenet":
71
+ pc_adapter = model_list["pc_adapter"](s).eval()
72
+ if torch.cuda.is_available():
73
+ pc_adapter.cuda()
74
  if torch.cuda.is_available():
75
  model.cuda()
76
+ return model, pc_adapter