Spaces:
Runtime error
Runtime error
# Using flags in official models | |
1. **All common flags must be incorporated in the models.** | |
Common flags (i.e. batch_size, model_dir, etc.) are provided by various flag definition functions, | |
and channeled through `official.utils.flags.core`. For instance to define common supervised | |
learning parameters one could use the following code: | |
```$xslt | |
from absl import app as absl_app | |
from absl import flags | |
from official.utils.flags import core as flags_core | |
def define_flags(): | |
flags_core.define_base() | |
flags.adopt_key_flags(flags_core) | |
def main(_): | |
flags_obj = flags.FLAGS | |
print(flags_obj) | |
if __name__ == "__main__" | |
absl_app.run(main) | |
``` | |
2. **Validate flag values.** | |
See the [Validators](#validators) section for implementation details. | |
Validators in the official model repo should not access the file system, such as verifying | |
that files exist, due to the strict ordering requirements. | |
3. **Flag values should not be mutated.** | |
Instead of mutating flag values, use getter functions to return the desired values. An example | |
getter function is `get_tf_dtype` function below: | |
``` | |
# Map string to TensorFlow dtype | |
DTYPE_MAP = { | |
"fp16": tf.float16, | |
"fp32": tf.float32, | |
} | |
def get_tf_dtype(flags_obj): | |
if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite": | |
# If the graph_rewrite is used, we build the graph with fp32, and let the | |
# graph rewrite change ops to fp16. | |
return tf.float32 | |
return DTYPE_MAP[flags_obj.dtype] | |
def main(_): | |
flags_obj = flags.FLAGS() | |
# Do not mutate flags_obj | |
# if flags_obj.fp16_implementation == "graph_rewrite": | |
# flags_obj.dtype = "float32" # Don't do this | |
print(get_tf_dtype(flags_obj)) | |
... | |
``` |