Spaces:
Sleeping
Sleeping
Update openshape/__init__.py
Browse files- openshape/__init__.py +15 -3
openshape/__init__.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
from .ppat_rgb import Projected, PointPatchTransformer
|
|
|
5 |
|
6 |
|
7 |
def module(state_dict: dict, name):
|
@@ -32,7 +33,7 @@ def B32(s):
|
|
32 |
return model
|
33 |
|
34 |
|
35 |
-
def
|
36 |
model = Projected(
|
37 |
PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
|
38 |
nn.Linear(512, 1280)
|
@@ -40,11 +41,17 @@ def G14_shapenet_mix(s):
|
|
40 |
model.load_state_dict(module(s['state_dict'], 'module'))
|
41 |
return model
|
42 |
|
|
|
|
|
|
|
|
|
|
|
43 |
model_list = {
|
44 |
"openshape-pointbert-vitb32-rgb": B32,
|
45 |
"openshape-pointbert-vitl14-rgb": L14,
|
46 |
"openshape-pointbert-vitg14-rgb": G14,
|
47 |
-
"tripletmix-pointbert-shapenet":
|
|
|
48 |
}
|
49 |
|
50 |
|
@@ -57,8 +64,13 @@ def load_pc_encoder(name):
|
|
57 |
|
58 |
|
59 |
def load_pc_encoder_mix(name):
|
|
|
60 |
s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
|
61 |
model = model_list[name](s).eval()
|
|
|
|
|
|
|
|
|
62 |
if torch.cuda.is_available():
|
63 |
model.cuda()
|
64 |
-
return model
|
|
|
2 |
import torch.nn as nn
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
from .ppat_rgb import Projected, PointPatchTransformer
|
5 |
+
from .mlp import MLP
|
6 |
|
7 |
|
8 |
def module(state_dict: dict, name):
|
|
|
33 |
return model
|
34 |
|
35 |
|
36 |
+
def G14_M(s):
|
37 |
model = Projected(
|
38 |
PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
|
39 |
nn.Linear(512, 1280)
|
|
|
41 |
model.load_state_dict(module(s['state_dict'], 'module'))
|
42 |
return model
|
43 |
|
44 |
+
def PCA(s):
|
45 |
+
model = MLP(in_out_features=1280)
|
46 |
+
model.load_state_dict(module(s['pc_augment_adapter'], 'module'))
|
47 |
+
return model
|
48 |
+
|
49 |
model_list = {
|
50 |
"openshape-pointbert-vitb32-rgb": B32,
|
51 |
"openshape-pointbert-vitl14-rgb": L14,
|
52 |
"openshape-pointbert-vitg14-rgb": G14,
|
53 |
+
"tripletmix-pointbert-shapenet": G14_M,
|
54 |
+
"pc_adapter": PCA,
|
55 |
}
|
56 |
|
57 |
|
|
|
64 |
|
65 |
|
66 |
def load_pc_encoder_mix(name):
|
67 |
+
pc_adapter = None
|
68 |
s = torch.load(hf_hub_download("TripletMix/" + name, "model.pt"), map_location='cpu')
|
69 |
model = model_list[name](s).eval()
|
70 |
+
if name == "tripletmix-pointbert-shapenet":
|
71 |
+
pc_adapter = model_list["pc_adapter"](s).eval()
|
72 |
+
if torch.cuda.is_available():
|
73 |
+
pc_adapter.cuda()
|
74 |
if torch.cuda.is_available():
|
75 |
model.cuda()
|
76 |
+
return model, pc_adapter
|