|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
from omegaconf import OmegaConf |
|
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( |
|
ResNetFeatureExtractor, |
|
) |
|
from pytorch3d.implicitron.models.generic_model import GenericModel |
|
from pytorch3d.implicitron.models.global_encoder.global_encoder import ( |
|
SequenceAutodecoder, |
|
) |
|
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( |
|
IdrFeatureField, |
|
) |
|
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( |
|
NeuralRadianceFieldImplicitFunction, |
|
) |
|
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer |
|
from pytorch3d.implicitron.models.renderer.multipass_ea import ( |
|
MultiPassEmissionAbsorptionRenderer, |
|
) |
|
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import ( |
|
AngleWeightedIdentityFeatureAggregator, |
|
) |
|
from pytorch3d.implicitron.tools.config import ( |
|
get_default_args, |
|
remove_unused_components, |
|
) |
|
from tests.common_testing import get_tests_dir |
|
|
|
from .common_resources import provide_resnet34 |
|
|
|
# Developer switch: when True, test_create_gm_overrides additionally writes the
# generated config YAML (full and pruned) into DATA_DIR for inspection. Keep it
# False in committed code so test runs do not write files.
DEBUG: bool = False

# Directory containing the expected-output fixtures (e.g. ``overrides.yaml``).
DATA_DIR = get_tests_dir() / "implicitron/data"
|
|
|
|
|
|
|
|
|
class TestGenericModel(unittest.TestCase):
    """Checks that ``GenericModel`` assembles the sub-components its config selects."""

    def setUp(self):
        # Print the complete diff if the (long) YAML comparison ever fails.
        self.maxDiff = None

    def test_create_gm(self):
        """Building from the default args yields the default component set."""
        model = GenericModel(**get_default_args(GenericModel))

        self.assertIsInstance(model.renderer, MultiPassEmissionAbsorptionRenderer)
        self.assertIsInstance(
            model._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
        )
        # The optional components are all absent with the defaults.
        self.assertIsNone(model.global_encoder)
        # Implicit functions are reached through ``_implicit_functions``; no
        # ``implicit_function`` attribute should be exposed on the model.
        self.assertFalse(hasattr(model, "implicit_function"))
        self.assertIsNone(model.view_pooler)
        self.assertIsNone(model.image_feature_extractor)

    def test_create_gm_overrides(self):
        """Non-default component choices are honoured and serialise to the
        YAML stored in ``overrides.yaml``."""
        # Presumably makes the ResNet-34 resource available for
        # ResNetFeatureExtractor; it must run before the model is built.
        provide_resnet34()

        # Distinctive value used to prove a nested override reaches the
        # constructed implicit function.
        n_harmonics = 1729

        cfg = get_default_args(GenericModel)
        cfg.view_pooler_enabled = True
        cfg.view_pooler_args.feature_aggregator_class_type = (
            "AngleWeightedIdentityFeatureAggregator"
        )
        cfg.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        cfg.implicit_function_class_type = "IdrFeatureField"
        cfg.global_encoder_class_type = "SequenceAutodecoder"
        cfg.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = (
            n_harmonics
        )
        cfg.renderer_class_type = "LSTMRenderer"

        model = GenericModel(**cfg)

        self.assertIsInstance(model.renderer, LSTMRenderer)
        self.assertIsInstance(
            model.view_pooler.feature_aggregator,
            AngleWeightedIdentityFeatureAggregator,
        )
        idr_field = model._implicit_functions[0]._fn
        self.assertIsInstance(idr_field, IdrFeatureField)
        self.assertEqual(idr_field.n_harmonic_functions_xyz, n_harmonics)
        self.assertIsInstance(model.global_encoder, SequenceAutodecoder)
        self.assertIsInstance(model.image_feature_extractor, ResNetFeatureExtractor)
        self.assertFalse(hasattr(model, "implicit_function"))

        # Serialise the instantiated model's configuration back to YAML.
        config = OmegaConf.structured(model)
        if DEBUG:
            (DATA_DIR / "overrides_full.yaml").write_text(
                OmegaConf.to_yaml(config, sort_keys=False)
            )

        # remove_unused_components edits ``config`` in place; by its name it
        # presumably drops args of unselected components (verify in tools.config).
        remove_unused_components(config)
        pruned_yaml = OmegaConf.to_yaml(config, sort_keys=False)
        if DEBUG:
            # Written beside (not over) the fixture so the two can be diffed.
            (DATA_DIR / "overrides_.yaml").write_text(pruned_yaml)

        expected_yaml = (DATA_DIR / "overrides.yaml").read_text()
        self.assertEqual(pruned_yaml, expected_yaml)
|
|