# [Extraction residue, not part of the module] Spaces: Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper code to run complete models from within python."""

import os
import shutil
import sys
import tempfile

from absl import flags
from absl.testing import flagsaver

from official.utils.flags import core as flags_core


def run_synthetic(main,
                  tmp_root,
                  extra_flags=None,
                  synth=True,
                  train_epochs=1,
                  epochs_between_evals=1):
    """Performs a minimal run of a model.

    This function is intended to test for syntax errors throughout a model. A
    very limited run is performed using synthetic data.

    A fresh model directory is created under `tmp_root` for the run and is
    removed again before this function returns or raises.

    Args:
      main: The primary function used to exercise a code path. Generally this
        function is "<MODULE>.main(argv)". It is called with `flags.FLAGS`
        once `flags_core.parse_flags` has been given the assembled argv.
      tmp_root: Root path for the temp directory created by the test class.
      extra_flags: Additional flags passed by the caller of this function.
        Any iterable of strings is accepted (list, tuple, ...); None means
        no extra flags.
      synth: Use synthetic data (adds --use_synthetic_data when truthy).
      train_epochs: Value of the --train_epochs flag. None omits the flag.
      epochs_between_evals: Value of the --epochs_between_evals flag. None
        omits the flag.
    """
    model_dir = tempfile.mkdtemp(dir=tmp_root)

    try:
        # argv is assembled inside the try block so that the freshly created
        # model_dir is still removed if assembling the arguments fails.
        args = [sys.argv[0], "--model_dir", model_dir]

        # extend() accepts any iterable, so callers may pass a tuple or other
        # sequence of flags rather than strictly a list.
        if extra_flags is not None:
            args.extend(extra_flags)

        if synth:
            args.append("--use_synthetic_data")

        if train_epochs is not None:
            args.extend(["--train_epochs", str(train_epochs)])

        if epochs_between_evals is not None:
            args.extend(["--epochs_between_evals", str(epochs_between_evals)])

        flags_core.parse_flags(argv=args)
        main(flags.FLAGS)
    finally:
        # The model under test may already have removed the directory itself.
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)