Spaces:
Sleeping
Sleeping
Update openshape/__init__.py
Browse files- openshape/__init__.py +15 -0
openshape/__init__.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
from .ppat_rgb import Projected, PointPatchTransformer
|
|
|
5 |
|
6 |
|
7 |
def module(state_dict: dict, name):
|
@@ -32,10 +33,16 @@ def B32(s):
|
|
32 |
return model
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
model_list = {
|
36 |
"openshape-pointbert-vitb32-rgb": B32,
|
37 |
"openshape-pointbert-vitl14-rgb": L14,
|
38 |
"openshape-pointbert-vitg14-rgb": G14,
|
|
|
39 |
}
|
40 |
|
41 |
|
@@ -45,3 +52,11 @@ def load_pc_encoder(name):
|
|
45 |
if torch.cuda.is_available():
|
46 |
model.cuda()
|
47 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn as nn
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
from .ppat_rgb import Projected, PointPatchTransformer
|
5 |
+
from .Minkowski import MinkResNet34
|
6 |
|
7 |
|
8 |
def module(state_dict: dict, name):
|
|
|
33 |
return model
|
34 |
|
35 |
|
36 |
+
def Mk34(s):
|
37 |
+
model = MinkResNet34()
|
38 |
+
model.load_state_dict(module(s, 'pc_encoder'))
|
39 |
+
return model
|
40 |
+
|
41 |
model_list = {
|
42 |
"openshape-pointbert-vitb32-rgb": B32,
|
43 |
"openshape-pointbert-vitl14-rgb": L14,
|
44 |
"openshape-pointbert-vitg14-rgb": G14,
|
45 |
+
"tripletmix-spconv-all": Mk34,
|
46 |
}
|
47 |
|
48 |
|
|
|
52 |
if torch.cuda.is_available():
|
53 |
model.cuda()
|
54 |
return model
|
55 |
+
|
56 |
+
|
57 |
+
def load_pc_encoder_mix(name):
|
58 |
+
s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
|
59 |
+
model = model_list[name](s).eval()
|
60 |
+
if torch.cuda.is_available():
|
61 |
+
model.cuda()
|
62 |
+
return model
|