|
from __gin__ import dynamic_registration |
|
import t5x |
|
import seqio |
|
import t5.data.mixtures |
|
from t5x import utils |
|
from t5x import models |
|
import t5x.infer |
|
from t5x import partitioning |
|
import rephrase_inference |
|
|
|
include 't5x/configs/runs/infer.gin' |
|
include 't5x/examples/t5/t5_1_1/base.gin' |
|
|
|
# ----------------------------- Run-level macros ------------------------------
# These assignments come after the `include` lines, so they take precedence over
# same-named macros from the included configs (last binding wins in gin).

# Model: dropout is turned off for this run. Presumably consumed by the
# included model config (t5_1_1/base.gin) -- confirm.
DROPOUT_RATE = 0.0

# Checkpoint to restore weights from (path is relative to the working dir).
CHECKPOINT_PATH = "./checkpoint_1020000"

# Data: which seqio task/mixture to run, and its per-feature lengths.
MIXTURE_OR_TASK_NAME = "hf_inference_task"
TASK_FEATURE_LENGTHS = {'inputs': 1024, 'targets': 200}

# Toggle for reading pre-cached seqio tasks.
USE_CACHED_TASKS = False
|
|
|
|
|
# Partitioner: use 4 model partitions. Bound after the `include` of the run
# config, so this replaces any num_partitions value set there.
partitioning.PjitPartitioner.num_partitions = 4
|
|
|
# Decoding: request a single decode per input example at predict time.
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
|
|
|
# Dataset config bound under the `infer_eval` scope.
# NOTE(review): scoped bindings only apply where the configurable is requested
# as `@infer_eval/utils.DatasetConfig()`. Nothing in this file does that; if the
# included infer.gin wires inference through the unscoped
# `@utils.DatasetConfig()`, these values are inert for this run -- verify
# against t5x/configs/runs/infer.gin.
infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  # NOTE(review): `split` normally names a seqio split; here it is given the
  # FILE_PATH macro, presumably because the custom `hf_inference_task`
  # (likely registered via `rephrase_inference`) maps split -> file. Confirm.
  split = %FILE_PATH
  batch_size = 32
  shuffle = False
  seed = 42
  pack = False
  # Follow the run-level macro instead of a hard-coded literal, so that
  # flipping USE_CACHED_TASKS actually reaches this config (same value today).
  use_cached = %USE_CACHED_TASKS
  module = None
|
|
|
# Unscoped dataset config: take the split from the required FILE_PATH macro.
# NOTE(review): presumably this is the DatasetConfig the included infer.gin
# actually instantiates for inference -- confirm.
utils.DatasetConfig.split = %FILE_PATH
|
|
|
# seqio evaluator settings.
# num_examples = None: presumably "no cap on the number of examples" -- confirm.
seqio.Evaluator.num_examples = None
seqio.Evaluator.use_memory_cache = True
# Logger classes are passed as references (no `()`), not instances; order kept.
seqio.Evaluator.logger_cls = [
    @seqio.PyLoggingLogger,
    @seqio.TensorBoardLogger,
    @seqio.JSONLogger
]
|
|
|
# Input file path; must be supplied at launch (e.g. via --gin.FILE_PATH=...),
# otherwise gin raises on the %gin.REQUIRED reference.
FILE_PATH = %gin.REQUIRED