winfred2027 commited on
Commit
e6dd343
·
verified ·
1 Parent(s): fbeb0dc

Update openshape/__init__.py

Browse files
Files changed (1) hide show
  1. openshape/__init__.py +2 -6
openshape/__init__.py CHANGED
@@ -58,6 +58,7 @@ model_list = {
58
  "openshape-pointbert-vitl14-rgb": L14,
59
  "openshape-pointbert-vitg14-rgb": G14,
60
  "openshape-pointbert-shapenet": S14,
 
61
  "tripletmix-pointbert-shapenet": G14_M,
62
  "pc_adapter": PCA,
63
  }
@@ -72,13 +73,8 @@ def load_pc_encoder(name):
72
 
73
 
74
  def load_pc_encoder_mix(name):
75
- pc_adapter = None
76
  s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
77
  model = model_list[name](s).eval()
78
- if name == "tripletmix-pointbert-shapenet":
79
- pc_adapter = model_list["pc_adapter"](s).eval()
80
- if torch.cuda.is_available():
81
- pc_adapter.cuda()
82
  if torch.cuda.is_available():
83
  model.cuda()
84
- return model, pc_adapter
 
58
  "openshape-pointbert-vitl14-rgb": L14,
59
  "openshape-pointbert-vitg14-rgb": G14,
60
  "openshape-pointbert-shapenet": S14,
61
+ "tripletmix-pointbert-all-modelnet40": G14_M,
62
  "tripletmix-pointbert-shapenet": G14_M,
63
  "pc_adapter": PCA,
64
  }
 
73
 
74
 
75
  def load_pc_encoder_mix(name):
 
76
  s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
77
  model = model_list[name](s).eval()
 
 
 
 
78
  if torch.cuda.is_available():
79
  model.cuda()
80
+ return model