# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common flags for SuperGLUE finetuning binary."""
from typing import Callable | |
from absl import flags | |
from absl import logging | |
def define_flags():
  """Registers the flags of the SuperGLUE finetuning binary via `absl.flags`.

  Two groups of flags are defined, in this order:
    * Binary flags: run mode, task name, data/model locations and the
      gin / `ExperimentConfig` override files.
    * Tuning hyperparameters: batch size, learning rate, number of epochs,
      warmup ratio and evaluation frequency.

  The function takes no arguments and returns nothing; its only effect is the
  flag registration performed by the `flags.DEFINE_*` calls.
  """
  # ===========================================================================
  # SuperGlue binary flags.
  # ===========================================================================
  run_modes = ['train_eval_and_predict', 'train_eval', 'predict']
  flags.DEFINE_enum(
      name='mode',
      default='train_eval_and_predict',
      enum_values=run_modes,
      help=('The mode to run the binary. If `train_eval_and_predict` '
            'it will (1) train on the training data and (2) evaluate on '
            'the validation data and (3) finally generate predictions '
            'on the prediction data; if `train_eval`, it will only '
            'run training and evaluation; if `predict`, it will only '
            'run prediction using the model in `model_dir`.'))

  superglue_tasks = [
      'AX-b',
      'CB',
      'COPA',
      'MULTIRC',
      'RTE',
      'WiC',
      'WSC',
      'BoolQ',
      'ReCoRD',
      'AX-g',
  ]
  flags.DEFINE_enum(
      name='task_name',
      default=None,
      enum_values=superglue_tasks,
      help='The type of SuperGLUE task.')

  # (flag name, default value, help text) for every single-valued string flag,
  # listed in registration order.
  string_flag_specs = (
      ('train_input_path', None, 'The file path to the training data.'),
      ('validation_input_path', None, 'The file path to the evaluation data.'),
      ('test_input_path', None, 'The file path to the test input data.'),
      ('test_output_path', None, 'The file path to the test output data.'),
      ('model_dir', '',
       ('The model directory containing '
        'subdirectories for each task. Only needed for "predict" '
        'mode. For all other modes, if not provided, a unique '
        'directory will be created automatically for each run.')),
      ('input_meta_data_path', None,
       ('Path to file that contains '
        'metadata about input file. It is output by the '
        '`create_finetuning_data` '
        'binary. Required for all modes except "predict".')),
      ('init_checkpoint', '',
       'Initial checkpoint from a pre-trained BERT model.'),
      ('model_config_file', '',
       ('The config file specifying the architecture '
        'of the pre-trained model. The file can be either a bert_config.json '
        'file or `encoders.EncoderConfig` in yaml file.')),
      ('hub_module_url', '',
       ('TF-Hub path/url to a pretrained model. If '
        'specified, `init_checkpoint` and `model_config_file` flag should '
        'not be '
        'used.')),
  )
  for flag_name, default_value, description in string_flag_specs:
    flags.DEFINE_string(flag_name, default_value, description)

  # (flag name, help text) for the repeatable string flags; all default to
  # None.
  multi_string_flag_specs = (
      ('gin_file', 'List of paths to the gin config files.'),
      ('gin_params', 'Newline separated list of gin parameter bindings.'),
      ('config_file',
       ('This is the advanced usage to specify the '
        '`ExperimentConfig` directly. When specified, '
        'we will ignore FLAGS related to `ExperimentConfig` such as '
        '`train_input_path`, `validation_input_path` and following '
        'hparams.')),
  )
  for flag_name, description in multi_string_flag_specs:
    flags.DEFINE_multi_string(flag_name, None, description)

  # ===========================================================================
  # Tuning hparams.
  # ===========================================================================
  flags.DEFINE_integer(
      name='global_batch_size',
      default=32,
      help='Global batch size for train/eval/predict.')
  flags.DEFINE_float(
      name='learning_rate', default=3e-5, help='Initial learning rate.')
  flags.DEFINE_integer(
      name='num_epoch', default=3, help='Number of training epochs.')
  flags.DEFINE_float(
      name='warmup_ratio',
      default=0.1,
      help='Proportion of learning rate warmup steps.')
  flags.DEFINE_integer(
      name='num_eval_per_epoch',
      default=2,
      help='Number of evaluations to run per epoch.')
def validate_flags(flags_obj: flags.FlagValues,
                   file_exists_fn: Callable[[str], bool]):
  """Checks that the flags form a consistent configuration for the run mode.

  Validation proceeds in a fixed order and stops at the first problem found:
    1. Modes containing "train": `train_input_path` and
       `input_meta_data_path` must be set and exist; every `gin_file` and
       `config_file` entry, if any, must exist.
    2. Modes containing "eval": `validation_input_path` must be set and exist.
    3. Mode exactly "predict": `model_dir` must be set and exist.
    4. Modes containing "predict": `test_input_path` must be set and exist.
    5. When no `config_file` is given and the mode is not "predict": either
       `hub_module_url` is used alone, or both `init_checkpoint` and
       `model_config_file` are given (and `model_config_file` exists).

  Args:
    flags_obj: A `flags.FlagValues` object, usually `flags.FLAGS`.
    file_exists_fn: A callable that returns whether a given path exists.

  Raises:
    ValueError: If a required flag is empty, a referenced path does not exist,
      or the pretrained-model flags are combined inconsistently.
  """
  mode = flags_obj.mode

  def _require_existing(path, name):
    """Raises if `file_exists_fn` reports that `path` is missing."""
    if file_exists_fn(path):
      return
    raise ValueError('Flag `%s` at %s does not exist.' % (name, path))

  def _require_set_and_existing(path, name):
    """Raises if `path` is empty/None or does not exist."""
    if not path:
      raise ValueError('Flag `%s` must be provided in mode %s.' % (name, mode))
    _require_existing(path, name)

  if 'train' in mode:
    for required_flag in ('train_input_path', 'input_meta_data_path'):
      _require_set_and_existing(getattr(flags_obj, required_flag), required_flag)
    # These multi-valued flags are optional, but every listed file must exist.
    for optional_flag in ('gin_file', 'config_file'):
      for listed_path in getattr(flags_obj, optional_flag) or ():
        _require_existing(listed_path, optional_flag)

  if 'eval' in mode:
    _require_set_and_existing(flags_obj.validation_input_path,
                              'validation_input_path')

  if mode == 'predict':
    # model_dir is only needed strictly in 'predict' mode.
    _require_set_and_existing(flags_obj.model_dir, 'model_dir')

  if 'predict' in mode:
    _require_set_and_existing(flags_obj.test_input_path, 'test_input_path')

  # The pretrained-model flags only matter when no `config_file` is supplied
  # and the run is not prediction-only.
  if flags_obj.config_file or mode == 'predict':
    return

  hub_url = flags_obj.hub_module_url
  checkpoint = flags_obj.init_checkpoint
  model_config = flags_obj.model_config_file

  if hub_url:
    if checkpoint or model_config:
      raise ValueError(
          'When `hub_module_url` is specified, `init_checkpoint` and '
          '`model_config_file` should be empty.')
    logging.info('Using the pretrained tf.hub from %s', hub_url)
    return

  if not checkpoint or not model_config:
    raise ValueError('Both `init_checkpoint` and `model_config_file` '
                     'should be specified if `config_file` is not '
                     'specified.')
  _require_set_and_existing(model_config, 'model_config_file')
  logging.info(
      'Using the pretrained checkpoint from %s and model_config_file from '
      '%s.', checkpoint, model_config)