|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Processes crawled content from news URLs by generating tfrecords.""" |
|
|
|
import os |
|
from absl import app |
|
from absl import flags |
|
from official.nlp.nhnet import raw_data_processor |
|
|
|
# Command-line flags; read by transform_as_tfrecords() and main() below.
FLAGS = flags.FLAGS

# Input: root folder of articles crawled with the news-please tool.
flags.DEFINE_string("crawled_articles", "/tmp/nhnet/",
                    "Folder path to the crawled articles using news-please.")
# BERT vocabulary file, forwarded to RawDataProcessor. Required (main() raises
# ValueError when unset).
flags.DEFINE_string("vocab", None, "Filepath of the BERT vocabulary.")
flags.DEFINE_bool("do_lower_case", True,
                  "Whether the vocabulary is uncased or not.")
# Size limits forwarded to RawDataProcessor; presumably truncation bounds for
# tokenized headlines/passages — confirm against RawDataProcessor.
flags.DEFINE_integer("len_title", 15,
                     "Maximum number of tokens in story headline.")
flags.DEFINE_integer("len_passage", 200,
                     "Maximum number of tokens in article passage.")
flags.DEFINE_integer("max_num_articles", 5,
                     "Maximum number of articles in a story.")
flags.DEFINE_bool("include_article_title_in_passage", False,
                  "Whether to include article title in article passage.")
# Folder holding train/valid/test .json inputs; sharded tfrecords are written
# under its "processed" subfolder. Required (main() raises ValueError if unset).
flags.DEFINE_string("data_folder", None,
                    "Folder path to the downloaded data folder (output).")
# Shard count per split; appears in the tfrecord filename suffix.
flags.DEFINE_integer("num_tfrecords_shards", 20,
                     "Number of shards for train/valid/test.")
|
|
|
|
|
def transform_as_tfrecords(data_processor, filename):
  """Transforms story from json to tfrecord (sharded).

  Reads `<FLAGS.data_folder>/<filename>.json` and writes
  `FLAGS.num_tfrecords_shards` tfrecord shards into
  `<FLAGS.data_folder>/processed/`.

  Args:
    data_processor: Instance of RawDataProcessor.
    filename: 'train', 'valid', or 'test'.
  """
  print("Transforming json to tfrecord for %s..." % filename)
  story_filepath = os.path.join(FLAGS.data_folder, filename + ".json")
  output_folder = os.path.join(FLAGS.data_folder, "processed")
  # Create the output directory if it is not already there.
  os.makedirs(output_folder, exist_ok=True)
  num_shards = FLAGS.num_tfrecords_shards
  # One output path per shard, e.g. "train.tfrecord-00003-of-00020".
  output_filepaths = [
      os.path.join(output_folder,
                   "%s.tfrecord-%.5d-of-%.5d" % (filename, shard, num_shards))
      for shard in range(num_shards)
  ]
  total_num_examples, generated_num_examples = (
      data_processor.generate_examples(story_filepath, output_filepaths))
  print("For %s, %d examples have been generated from %d stories in json." %
        (filename, generated_num_examples, total_num_examples))
|
|
|
|
|
def main(_):
  """Builds a RawDataProcessor from flags and converts every data split.

  Raises:
    ValueError: If the required --data_folder or --vocab flag is unset.
  """
  # Guard clauses for the two flags that have no usable default.
  if not FLAGS.data_folder:
    raise ValueError("data_folder must be set as the downloaded folder path.")
  if not FLAGS.vocab:
    raise ValueError("vocab must be set as the filepath of BERT vocabulary.")
  processor = raw_data_processor.RawDataProcessor(
      vocab=FLAGS.vocab,
      do_lower_case=FLAGS.do_lower_case,
      len_title=FLAGS.len_title,
      len_passage=FLAGS.len_passage,
      max_num_articles=FLAGS.max_num_articles,
      include_article_title_in_passage=FLAGS.include_article_title_in_passage,
      include_text_snippet_in_example=True)
  print("Loading crawled articles...")
  num_articles = processor.read_crawled_articles(FLAGS.crawled_articles)
  print("Total number of articles loaded: %d" % num_articles)
  print()
  # Convert each split in turn; all three share the loaded article pool.
  for split in ("train", "valid", "test"):
    transform_as_tfrecords(processor, split)
|
|
|
|
|
# Standard absl entry point: parses flags, then invokes main().
if __name__ == "__main__":
  app.run(main)
|
|