# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
from tests.common_testing import get_pytorch3d_dir

from .common_resources import provide_resnet34

IMPLICITRON_CONFIGS_DIR = (
    get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs"
)
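
# All tests below construct models directly on "cuda:0", so a CUDA-capable
# GPU is required to run this module.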


class TestGenericModel(unittest.TestCase):
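    """Smoke tests of GenericModel: the default configuration, every repro
    config shipped with implicitron_trainer, and the IDR and view-pooling
    variants."""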

    @classmethod
    def setUpClass(cls) -> None:
        # Make the pretrained ResNet-34 weights available to the tests that
        # use a ResNet feature extractor.
        provide_resnet34()

    def setUp(self):
        torch.manual_seed(42)

    def test_gm(self):
        # Simple test of a forward and backward pass of the default GenericModel.
        device = torch.device("cuda:0")
        # Expand the configurable fields so that GenericModel can be
        # instantiated directly with keyword overrides.
        expand_args_fields(GenericModel)
        model = GenericModel(render_image_height=80, render_image_width=80)
        model.to(device)
        self._one_model_test(model, device)

    def test_all_gm_configs(self):
        # Tests the model settings of every repro config in the
        # implicitron_trainer configs folder (base configs excluded).
        device = torch.device("cuda:0")
        config_files = []
        for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
            config_files.extend(
                [
                    f
                    for f in IMPLICITRON_CONFIGS_DIR.glob(pattern)
                    if not f.name.endswith("_base.yaml")
                ]
            )
        for config_file in config_files:
            with self.subTest(name=config_file.stem):
                cfg = _load_model_config_from_yaml(str(config_file))
                cfg.render_image_height = 80
                cfg.render_image_width = 80
                model = GenericModel(**cfg)
                model.to(device)
                self._one_model_test(
                    model,
                    device,
                    eval_test=True,
                    bw_test=True,
                )

    def _one_model_test(
        self,
        model,
        device,
        n_train_cameras: int = 5,
        eval_test: bool = True,
        bw_test: bool = True,
    ):
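        """Run a training-mode forward pass of `model` on random inputs and,
        optionally, a backward pass (`bw_test`) and an evaluation-mode
        render (`eval_test`)."""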
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)
        N, H, W = n_train_cameras, model.render_image_height, model.render_image_width
        random_args = {
            "camera": cameras,
            "fg_probability": _random_input_tensor(N, 1, H, W, True, device),
            "depth_map": _random_input_tensor(N, 1, H, W, False, device) + 0.1,
            "mask_crop": _random_input_tensor(N, 1, H, W, True, device),
            "sequence_name": ["sequence"] * N,
            "image_rgb": _random_input_tensor(N, 3, H, W, False, device),
        }

        # training forward pass
        model.train()
        train_preds = model(
            **random_args,
            evaluation_mode=EvaluationMode.TRAINING,
        )
        # check finiteness of the objective
        self.assertTrue(train_preds["objective"].isfinite().item())
        if bw_test:
            train_preds["objective"].backward()

        if eval_test:
            model.eval()
            with torch.no_grad():
                eval_preds = model(
                    **random_args,
                    evaluation_mode=EvaluationMode.EVALUATION,
                )
            # in evaluation mode, a single image is rendered
            self.assertEqual(
                eval_preds["images_render"].shape,
                (1, 3, model.render_image_height, model.render_image_width),
            )

    def test_idr(self):
        # Forward pass of GenericModel with the IDR implicit function and the
        # signed-distance-function renderer.
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.renderer_class_type = "SignedDistanceFunctionRenderer"
        args.implicit_function_class_type = "IdrFeatureField"
        args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
        model = GenericModel(**args)
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)
        defaulted_args = {
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }
        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        fg_probability = torch.rand(
            (n_train_cameras, 1, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            fg_probability=fg_probability,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

    def test_viewpool(self):
        # Forward pass of GenericModel with the view pooler enabled, using a
        # ResNet feature extractor for the source-view features.
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.view_pooler_enabled = True
        args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
        model = GenericModel(**args)
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)
        defaulted_args = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
        }
        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            sequence_name=["a"] * n_train_cameras,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)


def _random_input_tensor(
    N: int,
    C: int,
    H: int,
    W: int,
    is_binary: bool,
    device: torch.device,
) -> torch.Tensor:
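    """Return a random (N, C, H, W) tensor; if `is_binary`, threshold at 0.5
    to produce a {0.0, 1.0}-valued mask."""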
    T = torch.rand(N, C, H, W, device=device)
    if is_binary:
        T = (T > 0.5).float()
    return T


def _load_model_config_from_yaml(config_path: str) -> DictConfig:
    """Load the GenericModel config from a trainer yaml file, merged on top
    of the GenericModel defaults."""
    default_cfg = get_default_args(GenericModel)
    return _load_model_config_from_yaml_rec(default_cfg, config_path)


def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig:
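    """Merge the model config found in `config_path`, and recursively the
    configs named in its hydra-style "defaults" list, on top of `cfg`.
    Defaults are merged first, so values in the current file take precedence."""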
    cfg_loaded = OmegaConf.load(config_path)
    cfg_model_loaded = None
    if "model_factory_ImplicitronModelFactory_args" in cfg_loaded:
        factory_args = cfg_loaded.model_factory_ImplicitronModelFactory_args
        if "model_GenericModel_args" in factory_args:
            cfg_model_loaded = factory_args.model_GenericModel_args

    # First recurse into the configs listed in the "defaults" list, if any.
    defaults = cfg_loaded.pop("defaults", None)
    if defaults is not None:
        for default_name in defaults:
            if default_name in ("_self_", "default_config"):
                continue
            default_name = os.path.splitext(default_name)[0]
            defpath = os.path.join(
                os.path.dirname(config_path), default_name + ".yaml"
            )
            cfg = _load_model_config_from_yaml_rec(cfg, defpath)

    # Then merge this file's model config on top.
    if cfg_model_loaded is not None:
        cfg = OmegaConf.merge(cfg, cfg_model_loaded)
    return cfg
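

if __name__ == "__main__":
    # Convenience entry point for running this module directly; the suite is
    # normally discovered and run by a unittest/pytest runner.
    unittest.main()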