|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Flags for managing compute devices. Currently only contains TPU flags.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from absl import flags |
|
from absl import logging |
|
|
|
from official.utils.flags._conventions import help_wrap |
|
|
|
|
|
def require_cloud_storage(flag_names):
  """Register a validator that requires GCS paths when a TPU is requested.

  The validator is attached to the "tpu" flag together with every flag named
  in `flag_names`. If "tpu" is unset (None) the check passes unconditionally;
  otherwise every named flag must hold a string starting with "gs://". Each
  offending flag is logged individually so that all problems are reported in
  a single run.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  # Materialize once: the names are needed for the message, for the validator
  # registration and for the check itself. A tuple would break the list
  # concatenation below and a one-shot iterable would be exhausted by join().
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path; report it as invalid rather
      # than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags
|
|
|
|
|
def define_device(tpu=True):
  """Register device specific flags.

  When `tpu` is true this defines the "tpu", "tpu_zone", "tpu_gcp_project"
  and "num_tpu_shards" flags. Only "tpu" is reported back as a key flag.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space is required: adjacent literals are concatenated
            # verbatim, and without it the text reads "theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "zone from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
|
|