"""Rebalance a set of CSV/TFRecord shards to a target number of files. |
|
""" |
|
|
|
import argparse |
|
import datetime |
|
import os |
|
|
|
import apache_beam as beam |
|
import tensorflow as tf, tf_keras |
|
|
|
|
|
parser = argparse.ArgumentParser() |
parser.add_argument(
    "--input_path",
    required=True,
    help="Input file path or pattern.")
parser.add_argument(
    "--output_path",
    required=True,
    help="Output path prefix for the rebalanced shards.")
parser.add_argument(
    "--num_output_files",
    type=int,
    default=256,
    help="Number of output file shards.")
parser.add_argument(
    "--filetype",
    default="tfrecord",
    choices=["tfrecord", "csv"],
    help="File type, must be one of {tfrecord, csv}.")
parser.add_argument(
    "--project",
    default=None,
    help="ID (not name) of your Google Cloud project. Ignored by "
    "DirectRunner.")
parser.add_argument(
    "--runner",
    default="DirectRunner",
    help="Runner for Apache Beam, must be one of "
    "{DirectRunner, DataflowRunner}.")
parser.add_argument(
    "--region",
    default=None,
    help="Region to run the Dataflow job in. Ignored by DirectRunner.")

args = parser.parse_args()
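
# Example invocation (script name and paths are illustrative):
#   python rebalance_shards.py \
#     --input_path="gs://my-bucket/data/train-*.tfrecord" \
#     --output_path=gs://my-bucket/data/rebalanced/train \
#     --filetype=tfrecord \
#     --num_output_files=512
# For Dataflow, also pass --runner=DataflowRunner, --project, and --region.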


def rebalance_data_shards():
  """Rebalances data shards."""

  def csv_pipeline(pipeline: beam.Pipeline):
    """Rebalances a CSV dataset.

    Args:
      pipeline: Beam pipeline object.
    """
    _ = (
        pipeline
        | beam.io.ReadFromText(args.input_path)
        | beam.io.WriteToText(args.output_path,
                              num_shards=args.num_output_files))

  def tfrecord_pipeline(pipeline: beam.Pipeline):
    """Rebalances a TFRecord dataset.

    Args:
      pipeline: Beam pipeline object.
    """
    example_coder = beam.coders.ProtoCoder(tf.train.Example)
    _ = (
        pipeline
        | beam.io.ReadFromTFRecord(args.input_path, coder=example_coder)
        | beam.io.WriteToTFRecord(args.output_path,
                                  file_name_suffix=".tfrecord",
                                  coder=example_coder,
                                  num_shards=args.num_output_files))
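
  # Timestamped job name keeps repeated Dataflow runs distinguishable.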
  job_name = (
      f"shard-rebalancer-{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}")
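
  # Staging and temp files are kept under the output path so a single
  # bucket suffices; "save_main_session" pickles module-level state
  # (including the parsed args) for remote workers.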
  options = {
      "staging_location": os.path.join(args.output_path, "tmp", "staging"),
      "temp_location": os.path.join(args.output_path, "tmp"),
      "job_name": job_name,
      "project": args.project,
      "save_main_session": True,
      "region": args.region,
  }

  opts = beam.pipeline.PipelineOptions(flags=[], **options)

  with beam.Pipeline(args.runner, options=opts) as pipeline:
    if args.filetype == "tfrecord":
      tfrecord_pipeline(pipeline)
    elif args.filetype == "csv":
      csv_pipeline(pipeline)


if __name__ == "__main__":
  rebalance_data_shards()