Spaces:
Sleeping
Sleeping
Commit
·
f1a62b0
1
Parent(s):
bf95f5a
updates
Browse files- modeling/official/modeling/multitask/base_trainer.py +4 -4
- modeling/official/modeling/multitask/configs.py +3 -3
- modeling/official/modeling/multitask/evaluator.py +4 -4
- modeling/official/modeling/multitask/interleaving_trainer.py +5 -5
- modeling/official/modeling/multitask/multitask.py +7 -7
- modeling/official/modeling/multitask/task_sampler.py +1 -1
- modeling/official/modeling/multitask/test_utils.py +4 -4
- modeling/official/modeling/multitask/train_lib.py +11 -11
- modeling/official/modeling/privacy/configs.py +1 -1
- modeling/official/modeling/privacy/configs_test.py +1 -1
- modeling/official/modeling/privacy/ops_test.py +1 -1
modeling/official/modeling/multitask/base_trainer.py
CHANGED
@@ -19,12 +19,12 @@ The trainer derives from the Orbit `StandardTrainer` class.
|
|
19 |
from typing import Union
|
20 |
|
21 |
import gin
|
22 |
-
import orbit
|
23 |
import tensorflow as tf, tf_keras
|
24 |
|
25 |
-
from official.modeling import optimization
|
26 |
-
from official.modeling.multitask import base_model
|
27 |
-
from official.modeling.multitask import multitask
|
28 |
|
29 |
|
30 |
@gin.configurable
|
|
|
19 |
from typing import Union
|
20 |
|
21 |
import gin
|
22 |
+
import modeling.orbit
|
23 |
import tensorflow as tf, tf_keras
|
24 |
|
25 |
+
from modeling.official.modeling import optimization
|
26 |
+
from modeling.official.modeling.multitask import base_model
|
27 |
+
from modeling.official.modeling.multitask import multitask
|
28 |
|
29 |
|
30 |
@gin.configurable
|
modeling/official/modeling/multitask/configs.py
CHANGED
@@ -16,9 +16,9 @@
|
|
16 |
import dataclasses
|
17 |
from typing import Optional, Tuple
|
18 |
|
19 |
-
from official.core import config_definitions as cfg
|
20 |
-
from official.modeling import hyperparams
|
21 |
-
from official.modeling.privacy import configs as dp_configs
|
22 |
|
23 |
|
24 |
@dataclasses.dataclass
|
|
|
16 |
import dataclasses
|
17 |
from typing import Optional, Tuple
|
18 |
|
19 |
+
from modeling.official.core import config_definitions as cfg
|
20 |
+
from modeling.official.modeling import hyperparams
|
21 |
+
from modeling.official.modeling.privacy import configs as dp_configs
|
22 |
|
23 |
|
24 |
@dataclasses.dataclass
|
modeling/official/modeling/multitask/evaluator.py
CHANGED
@@ -18,12 +18,12 @@ The evaluator implements the Orbit `AbstractEvaluator` interface.
|
|
18 |
"""
|
19 |
from typing import Dict, List, Optional, Union
|
20 |
import gin
|
21 |
-
import orbit
|
22 |
import tensorflow as tf, tf_keras
|
23 |
|
24 |
-
from official.core import base_task
|
25 |
-
from official.core import train_utils
|
26 |
-
from official.modeling.multitask import base_model
|
27 |
|
28 |
|
29 |
@gin.configurable
|
|
|
18 |
"""
|
19 |
from typing import Dict, List, Optional, Union
|
20 |
import gin
|
21 |
+
import modeling.orbit
|
22 |
import tensorflow as tf, tf_keras
|
23 |
|
24 |
+
from modeling.official.core import base_task
|
25 |
+
from modeling.official.core import train_utils
|
26 |
+
from modeling.official.modeling.multitask import base_model
|
27 |
|
28 |
|
29 |
@gin.configurable
|
modeling/official/modeling/multitask/interleaving_trainer.py
CHANGED
@@ -15,12 +15,12 @@
|
|
15 |
"""Multitask trainer that interleaves each task's train step."""
|
16 |
from typing import Union
|
17 |
import gin
|
18 |
-
import orbit
|
19 |
import tensorflow as tf, tf_keras
|
20 |
-
from official.modeling.multitask import base_model
|
21 |
-
from official.modeling.multitask import base_trainer
|
22 |
-
from official.modeling.multitask import multitask
|
23 |
-
from official.modeling.multitask import task_sampler as sampler
|
24 |
|
25 |
|
26 |
@gin.configurable
|
|
|
15 |
"""Multitask trainer that interleaves each task's train step."""
|
16 |
from typing import Union
|
17 |
import gin
|
18 |
+
import modeling.orbit
|
19 |
import tensorflow as tf, tf_keras
|
20 |
+
from modeling.official.modeling.multitask import base_model
|
21 |
+
from modeling.official.modeling.multitask import base_trainer
|
22 |
+
from modeling.official.modeling.multitask import multitask
|
23 |
+
from modeling.official.modeling.multitask import task_sampler as sampler
|
24 |
|
25 |
|
26 |
@gin.configurable
|
modeling/official/modeling/multitask/multitask.py
CHANGED
@@ -17,13 +17,13 @@ import abc
|
|
17 |
from typing import Dict, List, Optional, Text, Union
|
18 |
|
19 |
import tensorflow as tf, tf_keras
|
20 |
-
from official.core import base_task
|
21 |
-
from official.core import config_definitions
|
22 |
-
from official.core import task_factory
|
23 |
-
from official.modeling import optimization
|
24 |
-
from official.modeling.multitask import base_model
|
25 |
-
from official.modeling.multitask import configs
|
26 |
-
from official.modeling.privacy import configs as dp_configs
|
27 |
|
28 |
OptimizationConfig = optimization.OptimizationConfig
|
29 |
RuntimeConfig = config_definitions.RuntimeConfig
|
|
|
17 |
from typing import Dict, List, Optional, Text, Union
|
18 |
|
19 |
import tensorflow as tf, tf_keras
|
20 |
+
from modeling.official.core import base_task
|
21 |
+
from modeling.official.core import config_definitions
|
22 |
+
from modeling.official.core import task_factory
|
23 |
+
from modeling.official.modeling import optimization
|
24 |
+
from modeling.official.modeling.multitask import base_model
|
25 |
+
from modeling.official.modeling.multitask import configs
|
26 |
+
from modeling.official.modeling.privacy import configs as dp_configs
|
27 |
|
28 |
OptimizationConfig = optimization.OptimizationConfig
|
29 |
RuntimeConfig = config_definitions.RuntimeConfig
|
modeling/official/modeling/multitask/task_sampler.py
CHANGED
@@ -17,7 +17,7 @@ import abc
|
|
17 |
from typing import Union, Dict, Text
|
18 |
import tensorflow as tf, tf_keras
|
19 |
|
20 |
-
from official.modeling.multitask import configs
|
21 |
|
22 |
|
23 |
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
|
|
|
17 |
from typing import Union, Dict, Text
|
18 |
import tensorflow as tf, tf_keras
|
19 |
|
20 |
+
from modeling.official.modeling.multitask import configs
|
21 |
|
22 |
|
23 |
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
|
modeling/official/modeling/multitask/test_utils.py
CHANGED
@@ -15,10 +15,10 @@
|
|
15 |
"""Testing utils for mock models and tasks."""
|
16 |
from typing import Dict, Text
|
17 |
import tensorflow as tf, tf_keras
|
18 |
-
from official.core import base_task
|
19 |
-
from official.core import config_definitions as cfg
|
20 |
-
from official.core import task_factory
|
21 |
-
from official.modeling.multitask import base_model
|
22 |
|
23 |
|
24 |
class MockFooModel(tf_keras.Model):
|
|
|
15 |
"""Testing utils for mock models and tasks."""
|
16 |
from typing import Dict, Text
|
17 |
import tensorflow as tf, tf_keras
|
18 |
+
from modeling.official.core import base_task
|
19 |
+
from modeling.official.core import config_definitions as cfg
|
20 |
+
from modeling.official.core import task_factory
|
21 |
+
from modeling.official.modeling.multitask import base_model
|
22 |
|
23 |
|
24 |
class MockFooModel(tf_keras.Model):
|
modeling/official/modeling/multitask/train_lib.py
CHANGED
@@ -17,18 +17,18 @@
|
|
17 |
import os
|
18 |
from typing import Any, List, Mapping, Optional, Tuple, Union
|
19 |
from absl import logging
|
20 |
-
import orbit
|
21 |
import tensorflow as tf, tf_keras
|
22 |
-
from official.core import base_task
|
23 |
-
from official.core import base_trainer as core_lib
|
24 |
-
from official.core import train_utils
|
25 |
-
from official.modeling.multitask import base_model
|
26 |
-
from official.modeling.multitask import base_trainer
|
27 |
-
from official.modeling.multitask import configs
|
28 |
-
from official.modeling.multitask import evaluator as evaluator_lib
|
29 |
-
from official.modeling.multitask import interleaving_trainer
|
30 |
-
from official.modeling.multitask import multitask
|
31 |
-
from official.modeling.multitask import task_sampler
|
32 |
|
33 |
TRAINERS = {
|
34 |
'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
|
|
|
17 |
import os
|
18 |
from typing import Any, List, Mapping, Optional, Tuple, Union
|
19 |
from absl import logging
|
20 |
+
import modeling.orbit
|
21 |
import tensorflow as tf, tf_keras
|
22 |
+
from modeling.official.core import base_task
|
23 |
+
from modeling.official.core import base_trainer as core_lib
|
24 |
+
from modeling.official.core import train_utils
|
25 |
+
from modeling.official.modeling.multitask import base_model
|
26 |
+
from modeling.official.modeling.multitask import base_trainer
|
27 |
+
from modeling.official.modeling.multitask import configs
|
28 |
+
from modeling.official.modeling.multitask import evaluator as evaluator_lib
|
29 |
+
from modeling.official.modeling.multitask import interleaving_trainer
|
30 |
+
from modeling.official.modeling.multitask import multitask
|
31 |
+
from modeling.official.modeling.multitask import task_sampler
|
32 |
|
33 |
TRAINERS = {
|
34 |
'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
|
modeling/official/modeling/privacy/configs.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
"""Configs for differential privacy."""
|
16 |
import dataclasses
|
17 |
|
18 |
-
from official.modeling.hyperparams import base_config
|
19 |
|
20 |
|
21 |
@dataclasses.dataclass
|
|
|
15 |
"""Configs for differential privacy."""
|
16 |
import dataclasses
|
17 |
|
18 |
+
from modeling.official.modeling.hyperparams import base_config
|
19 |
|
20 |
|
21 |
@dataclasses.dataclass
|
modeling/official/modeling/privacy/configs_test.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
"""Tests for configs."""
|
16 |
|
17 |
import tensorflow as tf, tf_keras
|
18 |
-
from official.modeling.privacy import configs
|
19 |
|
20 |
|
21 |
class ConfigsTest(tf.test.TestCase):
|
|
|
15 |
"""Tests for configs."""
|
16 |
|
17 |
import tensorflow as tf, tf_keras
|
18 |
+
from modeling.official.modeling.privacy import configs
|
19 |
|
20 |
|
21 |
class ConfigsTest(tf.test.TestCase):
|
modeling/official/modeling/privacy/ops_test.py
CHANGED
@@ -18,7 +18,7 @@ from unittest import mock
|
|
18 |
|
19 |
import tensorflow as tf, tf_keras
|
20 |
|
21 |
-
from official.modeling.privacy import ops
|
22 |
|
23 |
|
24 |
class OpsTest(tf.test.TestCase):
|
|
|
18 |
|
19 |
import tensorflow as tf, tf_keras
|
20 |
|
21 |
+
from modeling.official.modeling.privacy import ops
|
22 |
|
23 |
|
24 |
class OpsTest(tf.test.TestCase):
|