Spaces:
Sleeping
Sleeping
delele unnecessary dependency
Browse files- app.py +2 -0
- dataset/__init__.py +4 -2
- dataset/utils.py +47 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -128,6 +128,8 @@ def main(cfg: ProjectConfig):
|
|
128 |
# Setup model
|
129 |
runner = DemoRunner(cfg)
|
130 |
|
|
|
|
|
131 |
# Setup interface
|
132 |
demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
|
133 |
with demo:
|
|
|
128 |
# Setup model
|
129 |
runner = DemoRunner(cfg)
|
130 |
|
131 |
+
# runner = None # without model initialization, it shows one line of thumbnail
|
132 |
+
|
133 |
# Setup interface
|
134 |
demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
|
135 |
with demo:
|
dataset/__init__.py
CHANGED
@@ -15,16 +15,17 @@ from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
|
|
15 |
from pytorch3d.implicitron.tools.config import expand_args_fields
|
16 |
from pytorch3d.renderer.cameras import CamerasBase
|
17 |
from torch.utils.data import DataLoader
|
|
|
18 |
|
19 |
from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
|
20 |
-
from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
|
21 |
from .utils import DatasetMap
|
22 |
-
|
23 |
|
24 |
|
25 |
def get_dataset(cfg: ProjectConfig):
|
26 |
|
27 |
if cfg.dataset.type == 'co3dv2':
|
|
|
28 |
dataset_cfg: CO3DConfig = cfg.dataset
|
29 |
dataloader_cfg: DataloaderConfig = cfg.dataloader
|
30 |
|
@@ -100,6 +101,7 @@ def get_dataset(cfg: ProjectConfig):
|
|
100 |
dataloader_val.batch_sampler.drop_last = False
|
101 |
elif cfg.dataset.type == 'shapenet_r2n2':
|
102 |
# from ..configs.structured import ShapeNetR2N2Config
|
|
|
103 |
dataset_cfg: ShapeNetR2N2Config = cfg.dataset
|
104 |
# for k in dataset_cfg:
|
105 |
# print(k)
|
|
|
15 |
from pytorch3d.implicitron.tools.config import expand_args_fields
|
16 |
from pytorch3d.renderer.cameras import CamerasBase
|
17 |
from torch.utils.data import DataLoader
|
18 |
+
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
19 |
|
20 |
from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
|
|
|
21 |
from .utils import DatasetMap
|
22 |
+
|
23 |
|
24 |
|
25 |
def get_dataset(cfg: ProjectConfig):
|
26 |
|
27 |
if cfg.dataset.type == 'co3dv2':
|
28 |
+
from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
|
29 |
dataset_cfg: CO3DConfig = cfg.dataset
|
30 |
dataloader_cfg: DataloaderConfig = cfg.dataloader
|
31 |
|
|
|
101 |
dataloader_val.batch_sampler.drop_last = False
|
102 |
elif cfg.dataset.type == 'shapenet_r2n2':
|
103 |
# from ..configs.structured import ShapeNetR2N2Config
|
104 |
+
from .r2n2_my import R2N2Sample
|
105 |
dataset_cfg: ShapeNetR2N2Config = cfg.dataset
|
106 |
# for k in dataset_cfg:
|
107 |
# print(k)
|
dataset/utils.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Iterable, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def show_item(item: Dict):
|
8 |
+
for key in item.keys():
|
9 |
+
value = item[key]
|
10 |
+
if torch.is_tensor(value) and value.numel() < 5:
|
11 |
+
value_str = value
|
12 |
+
elif torch.is_tensor(value):
|
13 |
+
value_str = value.shape
|
14 |
+
elif isinstance(value, str):
|
15 |
+
value_str = ('...' + value[-52:]) if len(value) > 50 else value
|
16 |
+
elif isinstance(value, dict):
|
17 |
+
value_str = str({k: type(v) for k, v in value.items()})
|
18 |
+
else:
|
19 |
+
value_str = type(value)
|
20 |
+
print(f"{key:<30} {value_str}")
|
21 |
+
|
22 |
+
|
23 |
+
def normalize_to_zero_one(x: torch.Tensor):
|
24 |
+
return (x - x.min()) / (x.max() - x.min())
|
25 |
+
|
26 |
+
|
27 |
+
def default(x, d):
|
28 |
+
return d if x is None else x
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class DatasetMap:
|
33 |
+
train: Optional[Iterable] = None
|
34 |
+
val: Optional[Iterable] = None
|
35 |
+
test: Optional[Iterable] = None
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def create_grid_points(bound=1.0, res=128):
|
40 |
+
x_ = np.linspace(-bound, bound, res)
|
41 |
+
y_ = np.linspace(-bound, bound, res)
|
42 |
+
z_ = np.linspace(-bound, bound, res)
|
43 |
+
|
44 |
+
x, y, z = np.meshgrid(x_, y_, z_)
|
45 |
+
# print(x.shape, y.shape) # (res, res, res)
|
46 |
+
pts = np.concatenate([y.reshape(-1, 1), x.reshape(-1, 1), z.reshape(-1, 1)], axis=-1)
|
47 |
+
return pts
|
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ tqdm
|
|
13 |
transformers
|
14 |
wandb
|
15 |
trimesh
|
16 |
-
|
|
|
|
13 |
transformers
|
14 |
wandb
|
15 |
trimesh
|
16 |
+
gradio
|
17 |
+
"git+https://github.com/facebookresearch/pytorch3d.git"
|