# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `official.modeling.hyperparams.oneof.OneOfConfig`."""

import dataclasses

import tensorflow as tf, tf_keras

from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof


@dataclasses.dataclass
class ResNet(base_config.Config):
  """Minimal nested config used as the `resnet` alternative of `Backbone`."""

  # Default depth; `test_get_oneof` expects this value to be exported.
  model_depth: int = 50


@dataclasses.dataclass
class Backbone(oneof.OneOfConfig):
  """One-of fixture offering a nested config or a plain int.

  `type` defaults to `'resnet'`. The tests expect `get()` on a default
  instance to yield the `ResNet` sub-config. They also expect `as_dict()` to
  export only `type` plus the field it names.
  """

  type: str = 'resnet'
  # Nested-config alternative; each Backbone gets its own ResNet instance.
  resnet: ResNet = dataclasses.field(default_factory=ResNet)
  # Scalar alternative, never selected by these tests.
  not_resnet: int = 2


@dataclasses.dataclass
class OutputLayer(oneof.OneOfConfig):
  """One-of fixture whose alternatives are both plain ints."""

  type: str = 'single'
  single: int = 1
  multi_head: int = 2


@dataclasses.dataclass
class Network(base_config.Config):
  """Top-level config nesting a `Backbone` and an `OutputLayer` one-of."""

  backbone: Backbone = dataclasses.field(default_factory=Backbone)
  output_layer: OutputLayer = dataclasses.field(default_factory=OutputLayer)


class OneOfTest(tf.test.TestCase):
  """Exercises export and selection behavior of one-of configs."""

  def test_to_dict(self):
    """A nested params dict round-trips through `Network.as_dict()`."""
    backbone_params = {'type': 'resnet', 'resnet': {'model_depth': 50}}
    output_layer_params = {'type': 'single', 'single': 1000}
    network_params = {
        'backbone': backbone_params,
        'output_layer': output_layer_params,
    }

    exported = Network(network_params).as_dict()

    self.assertEqual(exported, network_params)

  def test_get_oneof(self):
    """`get()` on a default `Backbone` returns its `ResNet` alternative."""
    selected = Backbone().get()

    self.assertIsInstance(selected, ResNet)
    self.assertEqual(selected.as_dict(), {'model_depth': 50})


if __name__ == '__main__':
  tf.test.main()