# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create LM TF examples for XLNet.""" | |
import dataclasses | |
import json | |
import math | |
import os | |
import random | |
from typing import Iterable, Mapping, List, Optional, Tuple | |
import unicodedata | |
# Import libraries | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.nlp.tools import tokenization | |
special_symbols = { | |
"<unk>": 0, | |
"<s>": 1, | |
"</s>": 2, | |
"<cls>": 3, | |
"<sep>": 4, | |
"<pad>": 5, | |
"<mask>": 6, | |
"<eod>": 7, | |
"<eop>": 8, | |
} | |
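
# These IDs are assumed to mirror the special-symbol IDs of the original XLNet
# SentencePiece vocabulary; the <eod> symbol below is used to mark document
# boundaries during preprocessing.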

FLAGS = flags.FLAGS

flags.DEFINE_integer("seq_length", 512,
                     help="Sequence length.")
flags.DEFINE_integer("reuse_length", 256, | |
help="Number of token that can be reused as memory. " | |
"Could be half of `seq_len`.") | |
flags.DEFINE_string("input_file", None, | |
"Input raw text file (or comma-separated list of files).") | |
flags.DEFINE_string( | |
"save_dir", None, | |
"Directory for saving processed data.") | |
flags.DEFINE_string("sp_model_file", "", | |
"The path to the model used by sentence piece tokenizer.") | |
flags.DEFINE_bool("use_eod_token", True, | |
"Whether or not to include EOD tokens.") | |
flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.") | |
flags.DEFINE_bool( | |
"do_lower_case", True, | |
"Whether to lower case the input text. Should be True for uncased " | |
"models and False for cased models.") | |
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.") | |
flags.DEFINE_integer("num_cores_per_host", 16, | |
"The number of (TPU) cores per host.") | |
flags.DEFINE_string("prefix", "", "Filename prefix.") | |
flags.DEFINE_string("suffix", "", "Filename suffix.") | |
flags.DEFINE_integer("task_id", None, | |
"The id of the current task.") | |
flags.DEFINE_integer("num_tasks", None, | |
"The total number of tasks.") | |
flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.") | |


@dataclasses.dataclass
class TrainingInstance:
  """Representation of a single XLNet Pretraining instance."""
  data: Iterable[int]
  segment_ids: Iterable[int]
  boundary_indices: Iterable[int]
  label: int

  def to_feature(self) -> Mapping[str, tf.train.Feature]:
    feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
    return dict(
        input_word_ids=feat(self.data),
        input_type_ids=feat(self.segment_ids),
        boundary_indices=feat(self.boundary_indices),
        label=feat([self.label]))

  def to_example(self) -> tf.train.Example:
    return tf.train.Example(
        features=tf.train.Features(feature=self.to_feature()))

  def __str__(self):
    def seq_to_str(seq):
      return " ".join([str(x) for x in seq])

    s = ""
    s += "tokens: %s\n" % seq_to_str(self.data)
    s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
    s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
    s += "label: %s\n" % self.label
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()


def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
  """Preprocesses an individual raw text line.

  This function will:
    - Remove extraneous spaces.
    - Replace `` with ", and '' with ".
    - Remove accents.
    - Apply lower casing.

  Args:
    line: The input line to preprocess.
    do_lower_case: Whether or not to lower case the text.

  Returns:
    The preprocessed line.
  """
line = " ".join(line.split()) | |
line = line.replace("``", "\"").replace("''", "\"") | |
# Replace accents. | |
line = unicodedata.normalize("NFKD", line) | |
line = "".join([c for c in line if not unicodedata.combining(c)]) | |
if do_lower_case: | |
line = line.lower() | |
return line | |


def preprocess_and_tokenize_input_files(
    input_files: Iterable[str],
    tokenizer: tokenization.FullSentencePieceTokenizer,
    use_eod: bool = True,
    do_lower_case: bool = False,
    log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
  """Preprocesses and encodes raw text from input files.

  This function preprocesses raw text and encodes it into tokens using a
  `SentencePieceModel` tokenization method. This also provides the sentence
  indicator for each token.

  Args:
    input_files: The list of input file names.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
      not included.
    do_lower_case: Whether or not to apply lower casing during raw text
      preprocessing.
    log_example_freq: The optional field for how many lines to process before
      emitting an info log.

  Returns:
    The preprocessed list. Each entry in the list is a tuple consisting of
    the token IDs and the sentence IDs.
  """
  all_data = []
  eod_symbol = special_symbols["<eod>"]

  total_number_of_lines = 0

  # Input file format:
  # (1) One sentence per line. These should ideally be actual sentences, not
  # entire paragraphs or arbitrary spans of text. (Because we use the
  # sentence boundaries for the "next sentence prediction" task).
  # (2) Blank lines between documents. Document boundaries are needed so
  # that the "next sentence prediction" task doesn't span between documents.
  for input_file in input_files:
    line_count = 0
    logging.info("Preprocessing %s", input_file)

    all_tokens = []
    all_sentence_ids = []

    sentence_id = True

    with tf.io.gfile.GFile(input_file, "rb") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break

        line_count += 1
        if line_count % log_example_freq == 0:
          logging.info("Loading line %d", line_count)

        line = line.strip()

        if not line:
          if use_eod:
            token_ids = [eod_symbol]
            sentence_id = not sentence_id
          else:
            continue
        else:
          preprocessed_line = _preprocess_line(
              line=line, do_lower_case=do_lower_case)
          token_ids = tokenization.encode_ids(
              sp_model=tokenizer.sp_model, text=preprocessed_line)

        all_tokens.extend(token_ids)
        all_sentence_ids.extend([sentence_id] * len(token_ids))
        sentence_id = not sentence_id

    logging.info("Finished processing %s. Number of lines: %d",
                 input_file, line_count)
    if line_count == 0:
      continue

    total_number_of_lines += line_count

    all_tokens = np.array(all_tokens, dtype=np.int64)
    all_sentence_ids = np.array(all_sentence_ids, dtype=bool)
    all_data.append((all_tokens, all_sentence_ids))

  logging.info("Completed text preprocessing. Total number of lines: %d",
               total_number_of_lines)
  return all_data


def _reshape_to_batch_dimensions(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int) -> Tuple[np.array, np.array]:
  """Truncates and reshapes input data with a batch major dimension.

  Args:
    tokens: The input token ids. This should have the same shape as
      `sentence_ids`.
    sentence_ids: The input sentence ids. This should have the same shape as
      `tokens`.
    per_host_batch_size: The target per-host batch size.

  Returns:
    The tuple of reshaped tokens and sentence_ids.
  """
  num_steps = len(tokens) // per_host_batch_size
  truncated_data_length = num_steps * per_host_batch_size
  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("num_steps: %d", num_steps)

  def truncate_and_reshape(a):
    return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))

  return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))


def _create_a_and_b_segments(
    tokens: np.array,
    sentence_ids: np.array,
    begin_index: int,
    total_length: int,
    no_cut_probability: float = 0.5):
  """Splits segments A and B from a single instance of tokens and sentence ids.

  Args:
    tokens: The 1D input token ids. This represents an individual entry within
      a batch.
    sentence_ids: The 1D input sentence ids. This represents an individual
      entry within a batch. This should be the same length as `tokens`.
    begin_index: The reference beginning index to split data.
    total_length: The target combined length of segments A and B.
    no_cut_probability: The probability of not cutting a segment despite
      a cut possibly existing.

  Returns:
    A tuple consisting of A data, B data, and label.
  """
  data_length = tokens.shape[0]
  if begin_index + total_length >= data_length:
    logging.info("[_create_segments]: begin_index %d + total_length %d >= "
                 "data_length %d", begin_index, total_length, data_length)
    return None

  end_index = begin_index + 1
  cut_indices = []
  # Identify all indices where sentence IDs change from one to the next.
  while end_index < data_length:
    if sentence_ids[end_index] != sentence_ids[end_index - 1]:
      if end_index - begin_index >= total_length:
        break
      cut_indices.append(end_index)
    end_index += 1

  a_begin = begin_index
  if not cut_indices or random.random() < no_cut_probability:
    # Segments A and B are contained within the same sentence.
    label = 0
    if not cut_indices:
      a_end = end_index
    else:
      a_end = random.choice(cut_indices)
    b_length = max(1, total_length - (a_end - a_begin))
    b_begin = random.randint(0, data_length - 1 - b_length)
    b_end = b_begin + b_length

    while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
      b_begin -= 1
    while (b_end < data_length - 1 and
           sentence_ids[b_end - 1] == sentence_ids[b_end]):
      b_end += 1
  else:
    # Segments A and B are different sentences.
    label = 1
    a_end = random.choice(cut_indices)
    b_begin = a_end
    b_end = end_index

  while a_end - a_begin + b_end - b_begin > total_length:
    if a_end - a_begin > b_end - b_begin:
      # Delete only the right side for the LM objective.
      a_end -= 1
    else:
      b_end -= 1

  if a_end >= data_length or b_end >= data_length:
    logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
                 a_end, b_end, data_length)
    return None

  a_data = tokens[a_begin: a_end]
  b_data = tokens[b_begin: b_end]
  return a_data, b_data, label


def _is_functional_piece(piece: str) -> bool:
  return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")


def _is_start_piece(piece: str) -> bool:
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
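  # "▁" (U+2581) is the SentencePiece marker prepended to pieces that begin a
  # new word, so such pieces (and standalone punctuation) count as word starts.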
  if (piece.startswith("▁") or piece in special_pieces):
    return True
  else:
    return False


def _get_boundary_indices(
    data: np.array,
    tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
  """Gets the boundary indices of whole words."""
  seq_length = len(data)
  boundary_indices = []
  for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
    if _is_start_piece(piece) and not _is_functional_piece(piece):
      boundary_indices.append(index)

  boundary_indices.append(seq_length)
  return boundary_indices


def _convert_tokens_to_instances(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    tokenizer: tokenization.FullSentencePieceTokenizer,
    num_cores_per_host: int = 0,
    logging_frequency: int = 500) -> List[TrainingInstance]:
  """Converts tokens and sentence IDs into individual training instances.

  The format of data in the XLNet pretraining task is very similar to the
  BERT pretraining task. Two segments A and B are randomly sampled, and the
  concatenation of A and B into a single sequence is used to perform
  language modeling.

  To create an XLNet Pretraining instance from a single long sequence, S:
  - Create a segment of length `reuse_length`. This first segment represents
    past tokens. During modeling, this segment is used to cache obtained
    content representations for the segment recurrence mechanism.
  - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
    composed of A and B segments.
    For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".

  Args:
    tokens: All tokens concatenated into a single list.
    sentence_ids: All sentence IDs concatenated into a single list.
    per_host_batch_size: The target batch size per host.
    seq_length: The max sequence length.
    reuse_length: The number of tokens to use from the previous segment.
    bi_data: Whether or not to use bidirectional data.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    num_cores_per_host: The number of cores per host. This is required if
      `bi_data` = `True`.
    logging_frequency: The frequency at which to log status updates.

  Returns:
    A list of `TrainingInstance` objects.
  """
  instances = []
  per_core_batch_size = (per_host_batch_size // num_cores_per_host
                         if bi_data else None)

  if bi_data:
    logging.info("Bi-directional data enabled.")
    assert per_host_batch_size % (2 * num_cores_per_host) == 0
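    # Half of each core's batch carries the forward stream and the other half
    # carries its reversed copy, so the per-host batch must split evenly
    # across 2 * num_cores_per_host.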
    forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size // 2)
    forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)

    forward_tokens = forward_tokens.reshape(forward_data_shape)
    forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)

    backwards_tokens = forward_tokens[:, :, :, ::-1]
    backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]

    tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
        per_host_batch_size, -1)
    # Concatenate along the same axis as `tokens` so the forward/backward rows
    # stay aligned between tokens and sentence IDs.
    sentence_ids = np.concatenate(
        [forward_sentence_ids, backwards_sentence_ids], 1).reshape(
            per_host_batch_size, -1)
  else:
    logging.info("Bi-directional data disabled.")
    tokens, sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size)

  logging.info("Tokens shape: %s", tokens.shape)

  data_length = tokens.shape[1]
  sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
  cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
  # 2 sep, 1 cls
  num_special_tokens = 3

  data_index = 0
  batch_number = 0
  step_size = reuse_length if reuse_length else seq_length
  num_batches = math.ceil(data_length / step_size)
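
  # The window advances by `reuse_length` (or by `seq_length` when reuse is
  # disabled), so the next instance's cached prefix starts where the current
  # A/B region began.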
  while data_index + seq_length <= data_length:
    if batch_number % logging_frequency == 0:
      logging.info("Processing batch %d of %d", batch_number, num_batches)

    for batch_index in range(per_host_batch_size):
      previous_segment_tokens = tokens[
          batch_index, data_index: data_index + reuse_length]

      results = _create_a_and_b_segments(
          tokens=tokens[batch_index],
          sentence_ids=sentence_ids[batch_index],
          begin_index=data_index + reuse_length,
          total_length=seq_length - reuse_length - num_special_tokens)

      if results is None:
        logging.info("Stopping at data index: %d", data_index)
        break
      a_data, b_data, label = results

      data = np.concatenate(
          [previous_segment_tokens, a_data, sep, b_data, sep, cls])
      a_length = a_data.shape[0]
      b_length = b_data.shape[0]
      segment_ids = ([0] * (reuse_length + a_length) + [0]
                     + [1] * b_length + [1] + [2])
      boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
                                               data=data)
      assert len(data) == seq_length
      assert len(segment_ids) == seq_length
      assert len(boundary_indices) > 0  # pylint: disable=g-explicit-length-test

      instances.append(TrainingInstance(
          data=data,
          segment_ids=segment_ids,
          boundary_indices=boundary_indices,
          label=label))
    batch_number += 1
    data_index += step_size
  return instances


def write_instances_to_tfrecord(
    instances: Iterable[TrainingInstance],
    save_path: str):
  """Writes instances to TFRecord."""
  record_writer = tf.io.TFRecordWriter(save_path)
  logging.info("Start writing to %s.", save_path)

  for i, instance in enumerate(instances):
    if i < 5:
      logging.info("Instance %d: %s", i, str(instance))
    record_writer.write(instance.to_example().SerializeToString())

  record_writer.close()
  logging.info("Done writing %s.", save_path)


def shuffle_and_combine_preprocessed_data(
    all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
  """Shuffles and combines preprocessed token/sentence IDs from documents."""
  document_permutation = np.random.permutation(len(all_data))
  previous_sentence_id = None

  all_tokens, all_sentence_ids = [], []
  for document_index in document_permutation:
    tokens, sentence_ids = all_data[document_index]
    # pylint: disable=g-explicit-length-test
    if len(tokens) == 0:
      continue
    if (previous_sentence_id is not None and
        sentence_ids[0] == previous_sentence_id):
      sentence_ids = np.logical_not(sentence_ids)

    all_tokens.append(tokens)
    all_sentence_ids.append(sentence_ids)

    previous_sentence_id = sentence_ids[-1]

  return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)


def get_tfrecord_name(
    per_host_batch_size: int,
    num_cores_per_host: int,
    seq_length: int,
    bi_data: bool,
    reuse_length: int,
    do_lower_case: bool,
    use_eod_token: bool,
    prefix: str = "",
    suffix: str = "",
    pass_id: int = 0,
    num_passes: int = 1,
    task_id: Optional[int] = None,
    num_tasks: Optional[int] = None) -> str:
  """Formats the resulting TFRecord name based on provided inputs."""
  components = []
  if prefix:
    components.append(prefix)
  components.append("seqlen-{}".format(seq_length))
  if reuse_length == 0:
    components.append("memless")
  else:
    components.append("reuse-{}".format(reuse_length))
  components.append("bs-{}".format(per_host_batch_size))
  components.append("cores-{}".format(num_cores_per_host))
  if do_lower_case:
    components.append("uncased")
  else:
    components.append("cased")
  if use_eod_token:
    components.append("eod")
  if bi_data:
    components.append("bi")
  else:
    components.append("uni")
  if suffix:
    components.append(suffix)

  s = "_".join(components) + ".tfrecord"

  if num_passes == 1 and task_id is None:
    return s

  if task_id is None:
    num_tasks = 1
    task_id = 0

  current_shard = task_id * num_passes + pass_id
  total_shards = num_tasks * num_passes
  return s + "-{}-of-{}".format(current_shard, total_shards)


def create_tfrecords(
    tokenizer: tokenization.FullSentencePieceTokenizer,
    input_file_or_files: str,
    use_eod_token: bool,
    do_lower_case: bool,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    num_cores_per_host: int,
    save_dir: str,
    prefix: str = "",
    suffix: str = "",
    num_tasks: Optional[int] = None,
    task_id: Optional[int] = None,
    num_passes: int = 1):
  """Runs the end-to-end preprocessing pipeline."""
logging.info("Input configuration:") | |
logging.info("input file(s): %s", input_file_or_files) | |
logging.info("use_eod_token: %s", use_eod_token) | |
logging.info("do_lower_case: %s", do_lower_case) | |
logging.info("per_host_batch_size: %d", per_host_batch_size) | |
logging.info("seq_length: %d", seq_length) | |
logging.info("reuse_length: %d", reuse_length) | |
logging.info("bi_data: %s", bi_data) | |
logging.info("num_cores_per_host: %d", num_cores_per_host) | |
logging.info("save_dir: %s", save_dir) | |
if task_id is not None and num_tasks is not None: | |
logging.info("task_id: %d", task_id) | |
logging.info("num_tasks: %d", num_tasks) | |
input_files = [] | |
for input_pattern in input_file_or_files.split(","): | |
input_files.extend(tf.io.gfile.glob(input_pattern)) | |
logging.info("*** Reading from input files ***") | |
for input_file in input_files: | |
logging.info(" %s", input_file) | |
logging.info("Shuffling the files with a fixed random seed.") | |
np.random.shuffle(input_files) | |
if num_tasks is not None: | |
assert task_id is not None | |
logging.info("Total number of input files: %d", len(input_files)) | |
logging.info("Splitting into %d shards of %d files each.", | |
num_tasks, len(input_files) // num_tasks) | |
input_files = input_files[task_id::num_tasks] | |
all_data = preprocess_and_tokenize_input_files( | |
input_files=input_files, | |
tokenizer=tokenizer, | |
use_eod=use_eod_token, | |
do_lower_case=do_lower_case) | |
for pass_id in range(num_passes): | |
logging.info("Beginning pass %d of %d", pass_id, num_passes) | |
tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data) | |
assert len(tokens) == len(sentence_ids) | |
filename = get_tfrecord_name( | |
per_host_batch_size=per_host_batch_size, | |
num_cores_per_host=num_cores_per_host, | |
seq_length=seq_length, | |
bi_data=bi_data, | |
use_eod_token=use_eod_token, | |
reuse_length=reuse_length, | |
do_lower_case=do_lower_case, | |
prefix=prefix, | |
suffix=suffix, | |
pass_id=pass_id, | |
num_passes=num_passes, | |
num_tasks=num_tasks, | |
task_id=task_id) | |
save_path = os.path.join(save_dir, filename) | |
if os.path.exists(save_path): | |
# If the path already exists, then we were probably preempted but | |
# previously wrote this file. | |
logging.info("%s already exists, skipping this batch.", save_path) | |
else: | |
instances = _convert_tokens_to_instances( | |
tokenizer=tokenizer, | |
tokens=tokens, | |
sentence_ids=sentence_ids, | |
per_host_batch_size=per_host_batch_size, | |
seq_length=seq_length, | |
reuse_length=reuse_length, | |
bi_data=bi_data, | |
num_cores_per_host=num_cores_per_host) | |
write_instances_to_tfrecord(instances=instances, save_path=save_path) | |
if task_id is None or task_id == 0: | |
corpus_info = { | |
"vocab_size": 32000, | |
"per_host_batch_size": per_host_batch_size, | |
"num_cores_per_host": num_cores_per_host, | |
"seq_length": seq_length, | |
"reuse_length": reuse_length, | |
"do_lower_case": do_lower_case, | |
"bi_data": bi_data, | |
"use_eod_token": use_eod_token, | |
} | |
corpus_fname = os.path.basename(filename) + ".json" | |
corpus_destination = os.path.join(save_dir, corpus_fname) | |
logging.info("Saving corpus info to %s", corpus_destination) | |
with tf.io.gfile.GFile(corpus_destination, "w") as fp: | |
json.dump(corpus_info, fp) | |


def main(_):
  tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)

  create_tfrecords(
      tokenizer=tokenizer,
      input_file_or_files=FLAGS.input_file,
      use_eod_token=FLAGS.use_eod_token,
      do_lower_case=FLAGS.do_lower_case,
      per_host_batch_size=FLAGS.per_host_batch_size,
      seq_length=FLAGS.seq_length,
      reuse_length=FLAGS.reuse_length,
      bi_data=FLAGS.bi_data,
      num_cores_per_host=FLAGS.num_cores_per_host,
      save_dir=FLAGS.save_dir,
      prefix=FLAGS.prefix,
      suffix=FLAGS.suffix,
      num_tasks=FLAGS.num_tasks,
      task_id=FLAGS.task_id,
      num_passes=FLAGS.num_passes)


if __name__ == "__main__":
  np.random.seed(0)
  logging.set_verbosity(logging.INFO)
  app.run(main)