Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Flags for managing compute devices. Currently only contains TPU flags.""" | |
from absl import flags | |
from absl import logging | |
from official.utils.flags._conventions import help_wrap | |
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for flags when a TPU is used.

  When the `tpu` flag is set (not None), every flag named in `flag_names` must
  hold a value starting with "gs://"; each offending flag is logged and the
  validator fails with a combined message. When `tpu` is None the check always
  passes.

  absl looks the flags up when the validator is registered, so `tpu` and every
  flag in `flag_names` should already be defined before this is called.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  # Materialize once: the names are used for the message, the validator's flag
  # list and the check loop. A generator would be exhausted by the first use,
  # and `["tpu"] + flag_names` only works when flag_names is a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    # `flag_values` maps each validated flag name to its current value.
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path; report it rather than raising
      # AttributeError on `.startswith`.
      if value is None or not value.startswith("gs://"):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation (`tpu`, `tpu_zone`,
      `tpu_gcp_project` and `num_tpu_shards`).

  Returns:
    A list of flags for core.py to mark as key flags. Only `tpu` is listed;
    the list is empty when `tpu` is False.
  """
  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu",
        default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone",
        default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located. If not "
            "specified, we will attempt to automatically detect the GCE "
            "zone from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project",
        default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(
        name="num_tpu_shards",
        default=8,
        help=help_wrap("Number of shards (TPU chips)."))

  return key_flags