Spaces:
Running
Running
File size: 3,727 Bytes
7088d16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
SequenceAutodecoder,
)
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
AngleWeightedIdentityFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
remove_unused_components,
)
from tests.common_testing import get_tests_dir
from .common_resources import provide_resnet34
# When True, the tests below also write YAML dumps (overrides_full.yaml and
# overrides_.yaml) into DATA_DIR so they can be inspected / diffed by hand.
DEBUG: bool = False
# Directory holding the expected YAML output (overrides.yaml) read by the tests.
DATA_DIR = get_tests_dir() / "implicitron/data"
class TestGenericModel(unittest.TestCase):
    """Tests the use of the config system in implicitron via GenericModel."""

    def setUp(self):
        # The YAML comparison below involves long strings; show the full diff.
        self.maxDiff = None

    def test_create_gm(self):
        """A GenericModel built from its default args uses the default plug-ins.

        Expects a MultiPassEmissionAbsorptionRenderer, a
        NeuralRadianceFieldImplicitFunction (reached via ``._fn`` on the first
        entry of ``_implicit_functions``), no global encoder, no view pooler,
        no image feature extractor, and no ``implicit_function`` attribute.
        """
        model = GenericModel(**get_default_args(GenericModel))
        self.assertIsInstance(model.renderer, MultiPassEmissionAbsorptionRenderer)
        first_fn = model._implicit_functions[0]._fn
        self.assertIsInstance(first_fn, NeuralRadianceFieldImplicitFunction)
        self.assertIsNone(model.global_encoder)
        self.assertFalse(hasattr(model, "implicit_function"))
        self.assertIsNone(model.view_pooler)
        self.assertIsNone(model.image_feature_extractor)

    def test_create_gm_overrides(self):
        """Overridden class types and a nested arg produce the requested parts.

        After building the model, its structured config (with unused
        components removed by ``remove_unused_components``) is serialized to
        YAML and compared against DATA_DIR / "overrides.yaml".
        """
        # NOTE(review): presumably makes the ResNet34 weights used by
        # ResNetFeatureExtractor available locally — confirm in common_resources.
        provide_resnet34()

        cfg = get_default_args(GenericModel)
        cfg.view_pooler_enabled = True
        pooler_cfg = cfg.view_pooler_args
        pooler_cfg.feature_aggregator_class_type = (
            "AngleWeightedIdentityFeatureAggregator"
        )
        cfg.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        cfg.implicit_function_class_type = "IdrFeatureField"
        cfg.global_encoder_class_type = "SequenceAutodecoder"
        # A distinctive value so we can check the nested arg reaches the model.
        harmonics = 1729
        cfg.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = harmonics
        cfg.renderer_class_type = "LSTMRenderer"

        model = GenericModel(**cfg)
        self.assertIsInstance(model.renderer, LSTMRenderer)
        self.assertIsInstance(
            model.view_pooler.feature_aggregator,
            AngleWeightedIdentityFeatureAggregator,
        )
        self.assertIsInstance(model._implicit_functions[0]._fn, IdrFeatureField)
        idr_fn = model._implicit_functions[0]._fn
        self.assertEqual(idr_fn.n_harmonic_functions_xyz, harmonics)
        self.assertIsInstance(model.global_encoder, SequenceAutodecoder)
        self.assertIsInstance(model.image_feature_extractor, ResNetFeatureExtractor)
        self.assertFalse(hasattr(model, "implicit_function"))

        structured = OmegaConf.structured(model)
        if DEBUG:
            # Dump the config before pruning, for manual inspection.
            unpruned_yaml = OmegaConf.to_yaml(structured, sort_keys=False)
            (DATA_DIR / "overrides_full.yaml").write_text(unpruned_yaml)

        remove_unused_components(structured)
        pruned_yaml = OmegaConf.to_yaml(structured, sort_keys=False)
        if DEBUG:
            # Written next to (not over) overrides.yaml so the two can be diffed.
            (DATA_DIR / "overrides_.yaml").write_text(pruned_yaml)

        expected_yaml = (DATA_DIR / "overrides.yaml").read_text()
        self.assertEqual(pruned_yaml, expected_yaml)
|