# Using flags in official models
1. **All common flags must be incorporated in the models.**
Common flags (e.g. `batch_size`, `model_dir`) are provided by various flag definition functions
and channeled through `official.utils.flags.core`. For instance, to define common supervised
learning parameters, one could use the following code:
```python
from absl import app as absl_app
from absl import flags

from official.utils.flags import core as flags_core


def define_flags():
  # Define the common base flags and mark them as key flags of this module.
  flags_core.define_base()
  flags.adopt_key_flags(flags_core)


def main(_):
  flags_obj = flags.FLAGS
  print(flags_obj)


if __name__ == "__main__":
  define_flags()
  absl_app.run(main)
```
2. **Validate flag values.**
See the [Validators](#validators) section for implementation details.
Validators in the official models repo should not access the file system (for example, to verify
that files exist) because of strict ordering requirements.
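As a minimal sketch using absl's `flags.validator` and `flags.multi_flags_validator` decorators, the validators below check flag values purely in memory, with no file-system access. The flags, the divisibility rule, and the helpers `make_validated_flag_values` and `parse_or_error` are hypothetical and for illustration only; a private `flags.FlagValues` container keeps the sketch independent of the global `flags.FLAGS`.

```python
from absl import flags


def make_validated_flag_values():
  """Builds an isolated FlagValues with hypothetical flags and validators."""
  fv = flags.FlagValues()
  flags.DEFINE_integer("batch_size", 32, "Global batch size.", flag_values=fv)
  flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to use.", flag_values=fv)

  @flags.validator("batch_size", message="--batch_size must be positive.",
                   flag_values=fv)
  def _check_batch_size(value):
    # Checks only the in-memory value; never touches the file system.
    return value > 0

  # Hypothetical cross-flag rule, for illustration only.
  @flags.multi_flags_validator(
      ["batch_size", "num_gpus"],
      message="--batch_size must be divisible by --num_gpus.",
      flag_values=fv)
  def _check_divisible(flag_dict):
    num_gpus = flag_dict["num_gpus"]
    return num_gpus < 1 or flag_dict["batch_size"] % num_gpus == 0

  return fv


def parse_or_error(argv):
  """Parses argv; returns (flag_values, None) or (None, error_message)."""
  fv = make_validated_flag_values()
  try:
    fv(argv)  # absl runs all registered validators right after parsing.
  except flags.IllegalFlagValueError as e:
    return None, str(e)
  return fv, None
```

For example, `parse_or_error(["prog", "--batch_size=0"])` returns an error message instead of flag values. Under `absl_app.run`, the same failure is reported as a flags parsing error before `main` is called.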
3. **Flag values should not be mutated.**
Instead of mutating flag values, use getter functions to return the desired values. An example
is the `get_tf_dtype` getter function below:
```python
import tensorflow as tf
from absl import flags

# Map string to TensorFlow dtype
DTYPE_MAP = {
    "fp16": tf.float16,
    "fp32": tf.float32,
}


def get_tf_dtype(flags_obj):
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # If the graph_rewrite is used, we build the graph with fp32, and let the
    # graph rewrite change ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]


def main(_):
  # Assumes the dtype and fp16_implementation flags are defined elsewhere.
  flags_obj = flags.FLAGS
  # Do not mutate flags_obj:
  # if flags_obj.fp16_implementation == "graph_rewrite":
  #   flags_obj.dtype = "fp32"  # Don't do this
  print(get_tf_dtype(flags_obj))
  ...
```
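The same getter pattern applies to any derived setting. Below is a minimal, TensorFlow-free sketch that assumes hypothetical `batch_size` and `num_gpus` flags held in a private `flags.FlagValues`; the `get_per_replica_batch_size` getter and the `make_flag_values` helper are also hypothetical. The getter computes the derived value on demand, while the parsed flag keeps exactly what the user passed.

```python
from absl import flags


def make_flag_values():
  """Builds an isolated FlagValues holding two hypothetical flags."""
  fv = flags.FlagValues()
  flags.DEFINE_integer("batch_size", 64, "Global batch size.", flag_values=fv)
  flags.DEFINE_integer("num_gpus", 0, "Number of GPUs (0 means CPU only).",
                       flag_values=fv)
  return fv


def get_per_replica_batch_size(flags_obj):
  """Derives the per-replica batch size without modifying flags_obj."""
  # Treat a CPU-only run (num_gpus == 0) as a single replica.
  num_replicas = max(flags_obj.num_gpus, 1)
  if flags_obj.batch_size % num_replicas:
    raise ValueError("batch_size must be divisible by the number of replicas.")
  return flags_obj.batch_size // num_replicas


fv = make_flag_values()
fv(["prog", "--batch_size=64", "--num_gpus=4"])
print(get_per_replica_batch_size(fv))  # 16
print(fv.batch_size)  # 64: the flag itself is never reassigned
```

Because the derivation lives in one function, it can be tested on its own, and any code that logs or re-reads the flags still sees the user's original input.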