Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import unittest | |
from omegaconf import OmegaConf | |
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( | |
ResNetFeatureExtractor, | |
) | |
from pytorch3d.implicitron.models.generic_model import GenericModel | |
from pytorch3d.implicitron.models.global_encoder.global_encoder import ( | |
SequenceAutodecoder, | |
) | |
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( | |
IdrFeatureField, | |
) | |
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( | |
NeuralRadianceFieldImplicitFunction, | |
) | |
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer | |
from pytorch3d.implicitron.models.renderer.multipass_ea import ( | |
MultiPassEmissionAbsorptionRenderer, | |
) | |
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import ( | |
AngleWeightedIdentityFeatureAggregator, | |
) | |
from pytorch3d.implicitron.tools.config import ( | |
get_default_args, | |
remove_unused_components, | |
) | |
from tests.common_testing import get_tests_dir | |
from .common_resources import provide_resnet34 | |
# Directory holding the reference YAML files used by the tests in this module.
DATA_DIR = get_tests_dir() / "implicitron" / "data"
# When True, test_create_gm_overrides additionally writes its YAML dumps into
# DATA_DIR so they can be inspected by hand.
DEBUG: bool = False
# Tests the use of the config system in implicitron.


class TestGenericModel(unittest.TestCase):
    """Checks that GenericModel is built and customized through its config."""

    def setUp(self):
        # Show the complete diff if the long YAML comparison below fails.
        self.maxDiff = None

    def test_create_gm(self):
        # An untouched default config yields the multipass EA renderer and a
        # NeuralRadianceFieldImplicitFunction; the global encoder, view pooler
        # and image feature extractor are all None.
        default_args = get_default_args(GenericModel)
        model = GenericModel(**default_args)
        self.assertIsInstance(model.renderer, MultiPassEmissionAbsorptionRenderer)
        self.assertIsInstance(
            model._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
        )
        self.assertIsNone(model.global_encoder)
        # The model reaches its implicit function through `_implicit_functions`;
        # no `implicit_function` attribute should exist.
        self.assertFalse(hasattr(model, "implicit_function"))
        for disabled_part in ("view_pooler", "image_feature_extractor"):
            self.assertIsNone(getattr(model, disabled_part))

    def test_create_gm_overrides(self):
        # Swap each pluggable component for a non-default implementation, build
        # the model, then compare its pruned config against overrides.yaml.
        # NOTE(review): presumably makes ResNet34 weights available for
        # ResNetFeatureExtractor — TODO confirm against common_resources.
        provide_resnet34()
        cfg = get_default_args(GenericModel)
        cfg.view_pooler_enabled = True
        cfg.view_pooler_args.feature_aggregator_class_type = (
            "AngleWeightedIdentityFeatureAggregator"
        )
        cfg.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        cfg.implicit_function_class_type = "IdrFeatureField"
        cfg.global_encoder_class_type = "SequenceAutodecoder"
        # A distinctive value, so it is easy to verify it reached the IDR field.
        cfg.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 1729
        cfg.renderer_class_type = "LSTMRenderer"

        model = GenericModel(**cfg)
        self.assertIsInstance(model.renderer, LSTMRenderer)
        self.assertIsInstance(
            model.view_pooler.feature_aggregator,
            AngleWeightedIdentityFeatureAggregator,
        )
        idr_field = model._implicit_functions[0]._fn
        self.assertIsInstance(idr_field, IdrFeatureField)
        self.assertEqual(idr_field.n_harmonic_functions_xyz, 1729)
        self.assertIsInstance(model.global_encoder, SequenceAutodecoder)
        self.assertIsInstance(model.image_feature_extractor, ResNetFeatureExtractor)
        self.assertFalse(hasattr(model, "implicit_function"))

        model_cfg = OmegaConf.structured(model)
        if DEBUG:
            # Dump the config before any pruning, for manual inspection.
            unpruned_yaml = OmegaConf.to_yaml(model_cfg, sort_keys=False)
            (DATA_DIR / "overrides_full.yaml").write_text(unpruned_yaml)
        # The return value is ignored, so this presumably prunes model_cfg in
        # place.
        remove_unused_components(model_cfg)
        pruned_yaml = OmegaConf.to_yaml(model_cfg, sort_keys=False)
        if DEBUG:
            # The filename differs from the reference file, so the reference
            # is never overwritten by a debug run.
            (DATA_DIR / "overrides_.yaml").write_text(pruned_yaml)
        expected_yaml = (DATA_DIR / "overrides.yaml").read_text()
        self.assertEqual(pruned_yaml, expected_yaml)