File size: 3,263 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rebalance a set of CSV/TFRecord shards to a target number of files.
"""

import argparse
import datetime
import os

import apache_beam as beam
import tensorflow as tf, tf_keras


# Command-line interface of the shard rebalancer. Only the parser is defined
# here; the flags are consumed through the module-level `args` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input_path",
    default=None,
    required=True,
    help="Input path.")
parser.add_argument(
    "--output_path",
    default=None,
    required=True,
    help="Output path.")
parser.add_argument(
    "--num_output_files",
    type=int,
    default=256,
    help="Number of output file shards.")
parser.add_argument(
    "--filetype",
    default="tfrecord",
    # Enforce what the help text promises: without `choices`, a typo such as
    # "tfrecords" is accepted and matches neither pipeline branch downstream.
    choices=["tfrecord", "csv"],
    help="File type, needs to be one of {tfrecord, csv}.")
parser.add_argument(
    "--project",
    default=None,
    help="ID (not name) of your project. Ignored by DirectRunner")
parser.add_argument(
    "--runner",
    help="Runner for Apache Beam, needs to be one of "
    "{DirectRunner, DataflowRunner}.",
    default="DirectRunner")
parser.add_argument(
    "--region",
    default=None,
    help="region")

# Parsed once at import time and kept at module level: rebalance_data_shards()
# and its nested pipeline builders read their configuration from this global.
args = parser.parse_args()


def rebalance_data_shards():
  """Rebalances data shards.

  Reads the records matched by `args.input_path` and writes them back out
  under `args.output_path`, split into `args.num_output_files` shards. The
  pipeline is run with the Beam runner named by `args.runner`.

  Raises:
    ValueError: If `args.filetype` is neither "tfrecord" nor "csv".
  """

  def csv_pipeline(pipeline: beam.Pipeline):
    """Rebalances CSV dataset.

    Lines are read and written as plain text; they are not parsed as CSV.

    Args:
      pipeline: Beam pipeline object.
    """
    _ = (
        pipeline
        | beam.io.ReadFromText(args.input_path)
        | beam.io.WriteToText(args.output_path,
                              num_shards=args.num_output_files))

  def tfrecord_pipeline(pipeline: beam.Pipeline):
    """Rebalances TFRecords dataset.

    Records are decoded and re-encoded as `tf.train.Example` protos.

    Args:
      pipeline: Beam pipeline object.
    """
    example_coder = beam.coders.ProtoCoder(tf.train.Example)
    _ = (
        pipeline
        | beam.io.ReadFromTFRecord(args.input_path, coder=example_coder)
        | beam.io.WriteToTFRecord(args.output_path, file_name_suffix="tfrecord",
                                  coder=example_coder,
                                  num_shards=args.num_output_files))

  # Resolve the pipeline builder first so that an unsupported file type fails
  # loudly instead of running a pipeline with no transforms attached.
  pipeline_builders = {
      "tfrecord": tfrecord_pipeline,
      "csv": csv_pipeline,
  }
  build_pipeline = pipeline_builders.get(args.filetype)
  if build_pipeline is None:
    raise ValueError(
        f"Unsupported filetype {args.filetype!r}; expected one of: "
        f"{', '.join(pipeline_builders)}.")

  # Timestamped so that successive runs get distinct job names.
  job_name = (
      f"shard-rebalancer-{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}")

  # Set up Beam pipeline. Staging and temp files live under the output path.
  options = {
      "staging_location": os.path.join(args.output_path, "tmp", "staging"),
      "temp_location": os.path.join(args.output_path, "tmp"),
      "job_name": job_name,
      "project": args.project,
      "save_main_session": True,
      "region": args.region,
  }

  # An explicit empty flag list means only the options above are used; the
  # script's own command line has already been parsed into `args`.
  opts = beam.pipeline.PipelineOptions(flags=[], **options)

  with beam.Pipeline(args.runner, options=opts) as pipeline:
    build_pipeline(pipeline)


if __name__ == "__main__":
  rebalance_data_shards()