boris commited on
Commit
901ff72
·
1 Parent(s): fdf7698

feat: shard by host is optional

Browse files
Files changed (2) hide show
  1. dalle_mini/data.py +7 -2
  2. tools/train/train.py +8 -2
dalle_mini/data.py CHANGED
@@ -27,6 +27,7 @@ class Dataset:
27
  do_train: bool = False
28
  do_eval: bool = True
29
  seed_dataset: int = None
 
30
  train_dataset: Dataset = field(init=False)
31
  eval_dataset: Dataset = field(init=False)
32
  rng_dataset: jnp.ndarray = field(init=False)
@@ -42,7 +43,11 @@ class Dataset:
42
  if isinstance(f, str):
43
  setattr(self, k, list(braceexpand(f)))
44
  # for list of files, split training data shards by host
45
- if isinstance(self.train_file, list) and self.multi_hosts:
 
 
 
 
46
  self.train_file = self.train_file[
47
  jax.process_index() :: jax.process_count()
48
  ]
@@ -185,7 +190,7 @@ class Dataset:
185
  first_loop = True
186
  while self.multi_hosts or first_loop:
187
  # in multi-host, we run forever (no epoch) as hosts need to stop
188
- # at same the time and we don't know how much data is on each host
189
  if not first_loop:
190
  # multi-host setting, we reshuffle shards
191
  epoch += 1
 
27
  do_train: bool = False
28
  do_eval: bool = True
29
  seed_dataset: int = None
30
+ shard_by_host: bool = False
31
  train_dataset: Dataset = field(init=False)
32
  eval_dataset: Dataset = field(init=False)
33
  rng_dataset: jnp.ndarray = field(init=False)
 
43
  if isinstance(f, str):
44
  setattr(self, k, list(braceexpand(f)))
45
  # for list of files, split training data shards by host
46
+ if (
47
+ isinstance(self.train_file, list)
48
+ and self.multi_hosts
49
+ and self.shard_by_host
50
+ ):
51
  self.train_file = self.train_file[
52
  jax.process_index() :: jax.process_count()
53
  ]
 
190
  first_loop = True
191
  while self.multi_hosts or first_loop:
192
  # in multi-host, we run forever (no epoch) as hosts need to stop
193
+ # at the same time and we don't know how much data is on each host
194
  if not first_loop:
195
  # multi-host setting, we reshuffle shards
196
  epoch += 1
tools/train/train.py CHANGED
@@ -112,16 +112,22 @@ class DataTrainingArguments:
112
  metadata={"help": "An optional input evaluation data file (glob acceptable)."},
113
  )
114
  # data loading should not be a bottleneck so we use "streaming" mode by default
115
- streaming: bool = field(
116
  default=True,
117
  metadata={"help": "Whether to stream the dataset."},
118
  )
119
- use_auth_token: bool = field(
120
  default=False,
121
  metadata={
122
  "help": "Whether to use the authentication token for private datasets."
123
  },
124
  )
 
 
 
 
 
 
125
  max_train_samples: Optional[int] = field(
126
  default=None,
127
  metadata={
 
112
  metadata={"help": "An optional input evaluation data file (glob acceptable)."},
113
  )
114
  # data loading should not be a bottleneck so we use "streaming" mode by default
115
+ streaming: Optional[bool] = field(
116
  default=True,
117
  metadata={"help": "Whether to stream the dataset."},
118
  )
119
+ use_auth_token: Optional[bool] = field(
120
  default=False,
121
  metadata={
122
  "help": "Whether to use the authentication token for private datasets."
123
  },
124
  )
125
+ shard_by_host: Optional[bool] = field(
126
+ default=False,
127
+ metadata={
128
+ "help": "Whether to shard data files by host in multi-host environments."
129
+ },
130
+ )
131
  max_train_samples: Optional[int] = field(
132
  default=None,
133
  metadata={