winfred2027 commited on
Commit
0d10d86
·
verified ·
1 Parent(s): 711eaf6

Update openshape/__init__.py

Browse files
Files changed (1) hide show
  1. openshape/__init__.py +15 -0
openshape/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
4
  from .ppat_rgb import Projected, PointPatchTransformer
 
5
 
6
 
7
  def module(state_dict: dict, name):
@@ -32,10 +33,16 @@ def B32(s):
32
  return model
33
 
34
 
 
 
 
 
 
35
  model_list = {
36
  "openshape-pointbert-vitb32-rgb": B32,
37
  "openshape-pointbert-vitl14-rgb": L14,
38
  "openshape-pointbert-vitg14-rgb": G14,
 
39
  }
40
 
41
 
@@ -45,3 +52,11 @@ def load_pc_encoder(name):
45
  if torch.cuda.is_available():
46
  model.cuda()
47
  return model
 
 
 
 
 
 
 
 
 
2
  import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
4
  from .ppat_rgb import Projected, PointPatchTransformer
5
+ from .Minkowski import MinkResNet34
6
 
7
 
8
  def module(state_dict: dict, name):
 
33
  return model
34
 
35
 
36
+ def Mk34(s):
37
+ model = MinkResNet34()
38
+ model.load_state_dict(module(s, 'pc_encoder'))
39
+ return model
40
+
41
  model_list = {
42
  "openshape-pointbert-vitb32-rgb": B32,
43
  "openshape-pointbert-vitl14-rgb": L14,
44
  "openshape-pointbert-vitg14-rgb": G14,
45
+ "tripletmix-spconv-all": Mk34,
46
  }
47
 
48
 
 
52
  if torch.cuda.is_available():
53
  model.cuda()
54
  return model
55
+
56
+
57
+ def load_pc_encoder_mix(name):
58
+ s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
59
+ model = model_list[name](s).eval()
60
+ if torch.cuda.is_available():
61
+ model.cuda()
62
+ return model