boris committed
Commit bcd360f · 2 Parent(s): 7f2f8ed 728a3c3

Merge branch 'main' of https://github.com/borisdayma/dalle-mini into main

.github/workflows/sync_to_hub.yml CHANGED
@@ -17,4 +17,4 @@ jobs:
       - name: Push to hub
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: git push https://boris:[email protected]/spaces/flax-community/dalle-mini main
+        run: git push https://boris:[email protected]/spaces/dalle-mini/dalle-mini main
.github/workflows/sync_to_hub_debug.yml CHANGED
@@ -14,4 +14,4 @@ jobs:
       - name: Push to hub
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: git push --force https://boris:[email protected]/spaces/flax-community/dalle-mini-debug +HEAD:main
+        run: git push --force https://boris:[email protected]/spaces/dalle-mini/dalle-mini-debug +HEAD:main
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: green
 sdk: streamlit
 app_file: app/streamlit/app.py
 pinned: True
+license: apache-2.0
 ---
 
 # DALL·E Mini
@@ -18,7 +19,7 @@ _Generate images from a text prompt_
 
 Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
 
-You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
+You can create your own pictures with [the demo](https://huggingface.co/spaces/dalle-mini/dalle-mini).
 
 ## How does it work?
 
app/streamlit/app.py CHANGED
@@ -78,7 +78,7 @@ if prompt != "":
                 </div>
             </div>
         </div>
-        <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
+        <small><i>Predictions may take up to 5mn under high load. Please stand by.</i></small>
         """,
         unsafe_allow_html=True,
     )
setup.cfg CHANGED
@@ -27,6 +27,7 @@ install_requires =
     einops
     unidecode
     ftfy
+    emoji
     pillow
     jax
    flax
src/dalle_mini/data.py CHANGED
@@ -43,6 +43,8 @@ class Dataset:
         if self.seed_dataset is None:
             # create a random seed
             self.seed_dataset = random.randint(0, 2**32 - 1)
+        # set numpy rng
+        self.np_rng = np.random.default_rng(self.seed_dataset)
         self.multi_hosts = jax.process_count() > 1
         # feed blank captions only in streaming mode for now
         # otherwise dataset could be cached with same blanked captions
@@ -173,6 +175,7 @@ class Dataset:
             blank_caption_function,
             text_column=self.text_column,
             blank_caption_prob=self.blank_caption_prob,
+            rng=self.np_rng,
         )
         if hasattr(self, "train_dataset"):
             self.train_dataset = (
@@ -180,7 +183,9 @@ class Dataset:
                 if self.streaming
                 else self.train_dataset.map(
                     partial_blank_caption_function,
-                    num_proc=self.preprocessing_num_workers,
+                    num_proc=None
+                    if self.seed_dataset
+                    else self.preprocessing_num_workers,
                     load_from_cache_file=False,
                     desc="Blanking some captions",
                 )
@@ -316,8 +321,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
     return shifted_input_ids
 
 
-def blank_caption_function(example, text_column, blank_caption_prob):
-    if blank_caption_prob and np.random.rand() < blank_caption_prob:
+def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
+    if (
+        blank_caption_prob
+        and (rng.random() if rng is not None else np.random.random())
+        < blank_caption_prob
+    ):
         example[text_column] = ""
     return example
 
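
The change above threads a seeded `numpy.random.Generator` into `blank_caption_function` (and switches `datasets.map` to `num_proc=None` when a dataset seed is set) so caption blanking becomes reproducible. A minimal sketch of the idea, using a hypothetical standalone `blank_caption` helper rather than the repo's `Dataset` class:

```python
import numpy as np

def blank_caption(example, text_column, blank_caption_prob, rng=None):
    # Prefer the supplied Generator so blanking is reproducible; fall back to
    # numpy's global RNG otherwise, mirroring the patched blank_caption_function.
    draw = rng.random() if rng is not None else np.random.random()
    if blank_caption_prob and draw < blank_caption_prob:
        example[text_column] = ""
    return example

rng = np.random.default_rng(seed=1234)  # seed value is illustrative
examples = [{"caption": f"a photo of item {i}"} for i in range(5)]
blanked = [blank_caption(dict(e), "caption", 0.2, rng) for e in examples]
# Re-creating default_rng(1234) and replaying the loop yields the same captions,
# which is why the map call above avoids multiprocessing when a seed is given.
```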
src/dalle_mini/model/text.py CHANGED
@@ -8,6 +8,7 @@ import random
 import re
 from pathlib import Path
 
+import emoji
 import ftfy
 from huggingface_hub import hf_hub_download
 from unidecode import unidecode
@@ -213,6 +214,8 @@ class TextNormalizer:
         t = ftfy.fix_text(t)
         # fix html
         t = fix_html(t)
+        # decode emojis (would be removed by unidecode)
+        t = emoji.demojize(t)
         # decode and simplify text: see unidecode library
         t = unidecode(t)
         # lower case
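
The new `emoji.demojize` step converts emoji into `:name:` aliases before `unidecode` runs, since `unidecode` would otherwise drop them. A small sketch of that ordering (the exact alias strings depend on the installed `emoji` package version):

```python
import emoji
from unidecode import unidecode

text = "I love 🍕 and café"
step1 = emoji.demojize(text)  # e.g. "I love :pizza: and café"
step2 = unidecode(step1)      # ASCII-fold what remains: "I love :pizza: and cafe"
print(step2.lower())          # "i love :pizza: and cafe"
```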
tools/train/config/mega/config.json CHANGED
@@ -1,30 +1,49 @@
 {
   "activation_dropout": 0.0,
-  "activation_function": "gelu",
+  "activation_function": "swish",
   "attention_dropout": 0.0,
   "bos_token_id": 16385,
   "d_model": 2048,
   "decoder_attention_heads": 32,
-  "decoder_ffn_dim": 8192,
+  "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers": 24,
+  "decoder_layers": 25,
   "decoder_start_token_id": 16384,
+  "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
-  "encoder_ffn_dim": 8192,
+  "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers": 24,
-  "encoder_vocab_size": 50264,
+  "encoder_layers": 25,
+  "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
+  "force_ln_scale": false,
+  "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16391,
+  "image_vocab_size": 16415,
   "init_std": 0.01,
   "is_encoder_decoder": true,
+  "ln_positions": "normformer",
+  "ln_type": "layernorm",
+  "max_length": 257,
   "max_text_length": 64,
+  "min_length": 257,
   "model_type": "dallebart",
   "normalize_text": true,
   "pad_token_id": 16385,
   "scale_embedding": false,
+  "sinkhorn_iters": 1,
+  "tau_init": 0.05,
   "tie_word_embeddings": false,
-  "use_cache": true
+  "use_absolute_position_embeddings": true,
+  "use_alibi": false,
+  "use_bias": false,
+  "use_cache": true,
+  "use_cosine_attention": false,
+  "use_deepnet_scaling": false,
+  "use_final_ln_decoder": true,
+  "use_final_ln_encoder": true,
+  "use_glu": true,
+  "use_head_scale": false,
+  "use_swin_position_embeddings": false
 }
tools/train/config/mini/config.json CHANGED
@@ -16,7 +16,7 @@
   "eos_token_id": 16385,
   "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16384,
+  "image_vocab_size": 16391,
   "init_std": 0.02,
   "is_encoder_decoder": true,
   "max_text_length": 64,
tools/train/scalable_shampoo/README.md CHANGED
@@ -4,4 +4,4 @@ Files copied from [google-research/scalable_shampoo/optax](https://github.com/go
 
 Imports have been modified to be relative.
 
-This will be replaced with `optax-shampoo` package eventually.
+This will eventually be replaced with `optax-shampoo` package.
tools/train/scalable_shampoo/distributed_shampoo.py CHANGED
@@ -25,13 +25,12 @@
 # Authors: Rohan Anil (rohananil at google dot com)
 #   & Vineet Gupta (vineet at google dot com)
 #
-
 """Distributed Shampoo Implementation."""
 
 import enum
 import functools
 import itertools
-from typing import Any, List, NamedTuple
+from typing import Any, List, NamedTuple, Tuple
 
 import chex
 import jax
@@ -43,6 +42,7 @@ from flax import struct
 from jax import lax
 
 from .quantization_utils import QuantizedValue
+from .symmetric_matrices import symmetric_matrices
 
 # Dtype for inverse-pth root routine
 # Switch to f64 if you have hardware that supports it. Enable the jax flag
@@ -141,7 +141,10 @@ class GraftingType(enum.IntEnum):
 
 
 def power_iteration(
-    matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST
+    matrix,
+    num_iters=100,
+    error_tolerance=1e-6,
+    precision=lax.Precision.HIGHEST,
 ):
     r"""Power iteration algorithm.
 
@@ -156,10 +159,10 @@ def power_iteration(
       matrix: the symmetric PSD matrix.
       num_iters: Number of iterations.
       error_tolerance: Iterative exit condition.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
       eigen vector, eigen value
@@ -196,7 +199,11 @@ def power_iteration(
     return v_out, s_out
 
 
-def mat_power(mat_m, p, precision=lax.Precision.HIGHEST):
+def mat_power(
+    mat_m,
+    p,
+    precision=lax.Precision.HIGHEST,
+):
     """A simple matrix power method. M^p where p can be TracedValue."""
     power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
 
@@ -245,15 +252,19 @@ def matrix_inverse_pth_root(
       num_iters: Maximum number of iterations.
       ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Error indicator, useful for early termination.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
      matrix^(-1/p)
    """
 
+    # If the input is not square, materialize it from the concatenated form.
+    if matrix.shape[0] != matrix.shape[1]:
+        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)
+
     assert matrix.shape[0] == matrix.shape[1]
 
     # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
@@ -336,8 +347,8 @@ def merge_small_dims(shape_to_merge, max_dim):
     return resulting_shape
 
 
-def pad_matrix(mat, max_size):
-    """Pad a matrix to a max_size.
+def pad_square_matrix(mat, max_size):
+    """Pad a square matrix up to max_size.
 
     Args:
       mat: a matrix to pad.
@@ -346,19 +357,132 @@ def pad_square_matrix(mat, max_size):
 
     Returns:
       Given M returns [[M, 0], [0, I]]
    """
-    size = mat.shape[0]
-    assert size <= max_size
-    if size == max_size:
+    rows, cols = mat.shape
+    if rows != cols:
+        raise ValueError(
+            "Must have rows == cols, instead got " f"rows={rows}, cols={cols}"
+        )
+    if cols > max_size:
+        raise ValueError(
+            "Must have cols <= max_size. Instead got "
+            f"cols={cols}, max_size={max_size}."
+        )
+    if rows == max_size:
         return mat
-    pad_size = max_size - size
-    zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
-    zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
+    pad_size = max_size - rows
+
+    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
+    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
     eye = jnp.eye(pad_size, dtype=mat.dtype)
     mat = jnp.concatenate([mat, zs1], 1)
     mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
     return mat
 
 
+def make_sliced_padding(
+    symmetric_block_size,
+    num_blocks,
+    starting_block,
+    dtype,
+):
+    """Returns padding for symmetric block matrix.
+
+    Specifically, the padding is given concatenated rectangular matrices
+    representing the lower-triangular rows below the starting block. For example,
+    if we want to pad the symmetric matrix
+
+    M = [[A, B^T]
+         [B, C]],
+
+    the desired output (in terms of the full matrix) with num_blocks = 4 is
+
+    M_padded = [[A, B^T, 0, 0]
+                [B, C, 0, 0]
+                [0, 0, I, 0]
+                 0, 0, 0, I].
+
+    We would represent M as the block matrix mat = [A, B, C]. In this form, the
+    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
+    triangular parts in the third and fourth rows).
+
+    Args:
+      symmetric_block_size: The size of each block.
+      num_blocks: The total number of blocks.
+      starting_block: The block where to start the padding.
+      dtype: The type to use for the blocks.
+    """
+    if starting_block == num_blocks:
+        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)
+
+    blocks = []
+    for i in range(starting_block, num_blocks):
+        blocks.append(
+            jnp.zeros(
+                shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
+            )
+        )
+        blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def pad_block_symmetric_matrix(
+    mat,
+    symmetric_block_size,
+    max_num_blocks,
+):
+    """Returns the padded blocked symmetric matrix.
+
+    The size of the padded matrix will be:
+      [symmetric_block_size, symmetric_block_size * max_num_blocks]
+
+    The input matrix can either:
+      - Be square with size less or equal to symmetric_block_size. In this case,
+        mat will first be padded to a square matrix of size symmetric_block_size,
+        and then be padded again up to the full size of the blocked matrix.
+      - Be a rectangle with number of rows equal to block size.
+        In this case, number of columns must be a multiple of number of rows, and
+        the ratio must correspond to a block representation of a symmetric matrix.
+        That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
+        number of block rows represented by the matrix.
+
+    Args:
+      mat: The input block matrix.
+      symmetric_block_size: The size of blocks.
+      max_num_blocks: The largest number of blocks to pad to.
+    """
+    rows, cols = mat.shape
+    if rows > symmetric_block_size:
+        raise ValueError(
+            "Must have rows <= symmetric_block_size. Instead got "
+            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
+        )
+    if rows > cols:
+        raise ValueError(
+            "Must have rows <= cols, instead got " f"rows={rows}, cols={cols}."
+        )
+    if cols > symmetric_block_size * max_num_blocks:
+        raise ValueError(
+            "Must have cols <= symmetric_block_size * max_num_blocks "
+            f"Instead got cols={cols}, "
+            f"symmetric_block_size={symmetric_block_size}, "
+            f"max_num_blocks={max_num_blocks}."
+        )
+    if rows < symmetric_block_size:
+        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
+    # Update rows and cols after possibly padding in pad_square_matrix.
+    rows, cols = mat.shape
+    assert rows == symmetric_block_size
+    assert cols % rows == 0
+    filled_blocks = cols // rows
+    padding_blocks = make_sliced_padding(
+        symmetric_block_size=symmetric_block_size,
+        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
+        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
+        dtype=mat.dtype,
+    )
+    return jnp.concatenate([mat, padding_blocks], axis=-1)
+
+
 def pad_vector(vec, max_size):
     """Pad a vector to a max_size.
 
@@ -694,18 +818,17 @@ def distributed_shampoo(
     num_devices_for_pjit: Number of devices to parallelize over when using pjit.
     shard_optimizer_states: Shard optimizer states to save memory in model
       parallel training.
-    best_effort_memory_usage_reduction: Best effort memory usage reduction.
-      diagonal_statistics -> jnp.bfloat16
-      momentum buffers (2x) -> jnp.int8
+    best_effort_memory_usage_reduction: Best effort memory usage reduction. -
+      diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
       statistics, preconditioners -> jnp.int16 + diagonals
     inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
      determine that using this threshold.
     moving_average_for_momentum: Whether to use moving average for momentum
      instead of exponential moving average.
     skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
-      greater than this value.
-    clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
-      when using RMSProp Grafting).
+      greater than this value.
+    clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
+      using RMSProp Grafting).
     precision: precision XLA related flag, the available options are: a)
      lax.Precision.DEFAULT (better step time, but not precise) b)
      lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
@@ -1167,7 +1290,7 @@ def distributed_shampoo(
     new_padded_statistics = []
     for stat in new_stats_flat:
         new_padded_statistics.extend(
-            [pad_matrix(stat, max_size) for stat in stat.statistics]
+            [pad_square_matrix(stat, max_size) for stat in stat.statistics]
         )
 
     # Create global stats
@@ -1388,7 +1511,7 @@ def distributed_shampoo(
     num_devices = lax.psum(1, batch_axis_name)
     num_statistics = len(statistics)
     # Pad statistics and exponents to next multiple of num_devices.
-    packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+    packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
     to_pad = -num_statistics % num_devices
     packed_statistics.extend(
         [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
@@ -1540,7 +1663,7 @@ def distributed_shampoo(
     # diagonals [d] f32
     # bucket_sizes [d] f32
     packed_quantized_statistics = [
-        pad_matrix(stat.quantized, max_size) for stat in statistics
+        pad_square_matrix(stat.quantized, max_size) for stat in statistics
    ]
     packed_quantized_diagonals = [
         pad_vector(stat.diagonal, max_size) for stat in statistics
@@ -1772,7 +1895,7 @@ def distributed_shampoo(
    """
     num_statistics = len(statistics)
     to_pad = -num_statistics % num_devices_for_pjit
-    padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+    padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
     padded_statistics.extend(
         [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
     )
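
In this file `pad_matrix` becomes `pad_square_matrix` (with explicit shape validation), and a blocked variant `pad_block_symmetric_matrix` is added for the concatenated symmetric representation. A quick sketch of the square case, restating the helper from the diff to show the `[[M, 0], [0, I]]` layout it produces:

```python
import jax.numpy as jnp

def pad_square_matrix(mat, max_size):
    # Same contract as the patched helper: given M, return [[M, 0], [0, I]].
    rows, cols = mat.shape
    assert rows == cols and cols <= max_size
    if rows == max_size:
        return mat
    pad_size = max_size - rows
    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
    eye = jnp.eye(pad_size, dtype=mat.dtype)
    mat = jnp.concatenate([mat, zs1], 1)
    return jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)

m = jnp.full((2, 2), 3.0)
print(pad_square_matrix(m, 4))
# [[3. 3. 0. 0.]
#  [3. 3. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]
```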
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py CHANGED
@@ -16,7 +16,7 @@
 """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
 
 import functools
-from typing import Any, List, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -192,7 +192,7 @@ def materialize_matrix(symmetric_matrix):
 @functools.partial(jax.jit, static_argnames=("num_blocks"))
 def materialize_matrix_from_concat(
     block_rows_concat,
-    num_blocks,
+    num_blocks=None,
 ):
     """Returns a materialized symmetric matrix from concatenated slices.
 
@@ -200,7 +200,11 @@ def materialize_matrix_from_concat(
       block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric matrix.
+        If not specified, it is inferred from the shape of block_rows_concat.
    """
+    if num_blocks is None:
+        num_blocks = find_num_blocks(block_rows_concat)
+
     block_size = block_rows_concat.shape[-2]
 
     block_rows = [
@@ -251,6 +255,28 @@ def update_sliced_rows(
     )
 
 
+def num_blocks_from_total_blocks(total_blocks):
+    """Returns the number of blocks (i.e.
+
+    block rows) from the total blocks.
+
+    This is the inverse of the function x -> x*(x+1)/2.
+
+    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
+    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.
+
+    Args:
+      total_blocks: The total blocks used to represent the matrix.
+    """
+    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
+    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
+        raise ValueError(
+            f"total_blocks={total_blocks} does not correspond to "
+            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
+        )
+    return num_blocks
+
+
 def find_num_blocks(block_rows_concat):
     """Returns the number of (row) blocks representing the concatenated matrix.
 
@@ -270,11 +296,147 @@ def find_num_blocks(block_rows_concat):
     # Compute the number of square blocks used to represent the matrix.
     total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
     # Determine the number of block rows by inverting y = x*(x+1)/2.
-    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
-    if num_blocks * (num_blocks + 1) / 2 != total_blocks:
+    return num_blocks_from_total_blocks(total_blocks)
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix(
+    mat,
+    block_size,
+):
+    """Returns sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    num_rows = mat.shape[-2]
+    num_cols = mat.shape[-1]
+    if num_rows != num_cols:
+        raise ValueError("mat is not square.")
+    if num_rows % block_size != 0:
         raise ValueError(
-            "Could not determine an appropriate number of blocks for "
-            "the concatenated matrix."
+            "block size does not evenly divide rows. "
+            f"num_rows={num_rows}, block_size={block_size}"
         )
-    else:
-        return num_blocks
+    return SlicedSymmetricMatrix(
+        block_rows=[
+            mat[
+                Ellipsis,
+                i * block_size : (i + 1) * block_size,
+                0 : (i + 1) * block_size,
+            ]
+            for i in range(num_rows // block_size)
+        ]
+    )
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix_concat(
+    mat,
+    block_size,
+):
+    """Returns the concatenated sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size)
+    return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
+
+
+def sliced_matrix_diag(mat):
+    """Returns the diagonal of the symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented in concatenated block form.
+    """
+    rows, cols = mat.shape
+    total_blocks = cols // rows
+    num_blocks = num_blocks_from_total_blocks(total_blocks)
+    diags = []
+    for i in range(num_blocks):
+        last_index = rows * ((i + 2) * (i + 1)) // 2
+        first_index = last_index - rows
+        diags.append(jnp.diag(mat[Ellipsis, first_index:last_index]))
+    return jnp.concatenate(diags, axis=-1)
+
+
+def diag_as_concat(diag, block_size):
+    """Returns the representation of a diagonal matrix in symmetric block form.
+
+    Args:
+      diag: The 1D array for the diagonals.
+      block_size: The size of blocks to use. Must divide the length of diag.
+    """
+    assert len(diag.shape) == 1  # diag must be 1D.
+    assert len(diag) % block_size == 0
+    num_diag_blocks = len(diag) // block_size
+    blocks = []
+    for i in range(num_diag_blocks):
+        blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype))
+        blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size]))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def row_abs_maxes(mat):
+    """Returns the max of the absolute values of the rows of the full matrix.
+
+    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
+    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
+    aboslute row maxes of the original symmetric matrix, [6, 6].
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+    """
+    rows, cols = mat.shape
+
+    # Find col and row max for each block.
+    col_maxes = []
+    row_maxes = []
+    for i in range(cols // rows):
+        block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
+        col_maxes.append(jnp.max(block, axis=1))
+        row_maxes.append(jnp.max(block, axis=0))
+
+    # global row max from block maxes.
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    maxes = []
+    for i in range(num_blocks):
+        maxes.append(
+            jnp.concatenate(
+                row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
+                + [
+                    col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
+                    for j in range(i + 1, num_blocks)
+                ],
+                axis=-1,
+            )
+        )
+
+    return jnp.max(jnp.stack(maxes), axis=0)
+
+
+def times_vector(mat, vec):
+    """Returns the symmetric block-concatenated matrix multiplied by a vector.
+
+    Specifically, each value in the vector is multiplied by a row of the full
+    matrix. That is, the vector is broadcast and multiplied element-wise. Note
+    this would be the transpose of full_mat * vec if full_mat represented the full
+    symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+      vec: The vector, having the same dimension as the materialized matrix.
+    """
+    rows, cols = mat.shape
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    multiplied = []
+    for i in range(num_blocks):
+        mat_block = mat[
+            Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
+        ]
+        vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
+        multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block))
+    return jnp.concatenate(multiplied, axis=-1)
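
The new `num_blocks_from_total_blocks` helper inverts the triangular count x*(x+1)/2 that relates stored blocks to block rows, and `materialize_matrix_from_concat` now uses it to infer `num_blocks` when the argument is omitted. A small self-contained check of that inversion (plain numpy, mirroring the helper above):

```python
import numpy as np

def num_blocks_from_total_blocks(total_blocks):
    # Inverse of x -> x * (x + 1) / 2, as in the helper added in this commit.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
        raise ValueError(f"total_blocks={total_blocks} is not of the form x*(x+1)/2.")
    return num_blocks

# A 2x2 block-symmetric matrix M = [[A, B^T], [B, C]] is stored as [A, B, C]:
# 3 stored blocks correspond to 2 block rows.
print(num_blocks_from_total_blocks(3))   # 2
print(num_blocks_from_total_blocks(10))  # 4
```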
tools/train/train.py CHANGED
@@ -368,6 +368,12 @@ class TrainingArguments:
             "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )
+    shard_shampoo_across: str = field(
+        default="dp",
+        metadata={
+            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
+        },
+    )
 
     num_train_epochs: int = field(
         default=3, metadata={"help": "Total number of training epochs to perform."}
@@ -450,6 +456,11 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )
 
+    use_vmap_trick: bool = field(
+        default=True,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     mp_devices: Optional[int] = field(
         default=1,
         metadata={
@@ -500,6 +511,11 @@ class TrainingArguments:
                 f"Output directory ({self.output_dir}) already exists and is not empty."
                 "Use --overwrite_output_dir to overcome."
             )
+        assert self.shard_shampoo_across in [
+            "dp",
+            "mp",
+            "2d",
+        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
         assert (
             self.mp_devices > 0
         ), f"Number of devices for model parallelism must be > 0"
@@ -530,6 +546,12 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    # check arguments
+    if training_args.mp_devices > jax.local_device_count():
+        assert (
+            data_args.seed_dataset is not None
+        ), "Seed dataset must be provided when model is split over multiple hosts"
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -748,8 +770,20 @@ def main():
         graft_type=graft_type,
         nesterov=False,
         exponent_override=0,
-        statistics_partition_spec=PartitionSpec(None, "dp", None),
-        preconditioner_partition_spec=PartitionSpec("dp", None, None),
+        statistics_partition_spec=PartitionSpec(
+            None, training_args.shard_shampoo_across, None
+        )
+        if training_args.shard_shampoo_across != "2d"
+        else PartitionSpec(None, "dp", "mp"),
+        preconditioner_partition_spec=PartitionSpec(
+            training_args.shard_shampoo_across, None, None
+        )
+        if training_args.shard_shampoo_across != "2d"
+        else PartitionSpec(
+            "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
+            None,
+            None,
+        ),
         num_devices_for_pjit=training_args.dp_devices,
         shard_optimizer_states=True,
         inverse_failure_threshold=0.1,
@@ -917,7 +951,7 @@ def main():
 
     # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
     # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
-    use_vmap_trick = True
+    use_vmap_trick = training_args.use_vmap_trick
 
     # make grad_param_spec for vmap
     if use_vmap_trick:
@@ -1145,7 +1179,8 @@ def main():
            self.log_time("train_per_log", delta_time, offset=False)
 
        def log_time(self, key, duration, offset=True):
-            wandb.log({f"time/{key}": duration, **self.state_dict})
+            if jax.process_index() == 0:
+                wandb.log({f"time/{key}": duration, **self.state_dict})
            if offset:
                self.offset_time += duration
 
@@ -1191,7 +1226,11 @@ def main():
        # ======================== Evaluating ==============================
        if training_args.do_eval:
            start_eval_time = time.perf_counter()
-            eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
+            eval_loader = dataset.dataloader(
+                "eval",
+                eval_batch_size_per_step
+                * max(1, training_args.mp_devices // jax.local_device_count()),
+            )
            eval_steps = (
                len_eval_dataset // eval_batch_size_per_step
                if len_eval_dataset is not None
@@ -1353,10 +1392,12 @@ def main():
        metrics_logger.update_state_metrics(local_state)
        metrics_logger.log({})
 
-        # Generate an epoch by shuffling sampling indices from the train dataset
+        # load data - may be replicated on multiple nodes
+        node_groups = max(1, training_args.mp_devices // jax.local_device_count())
+        loader_bs = batch_size_per_node * node_groups
        train_loader = dataset.dataloader(
            "train",
-            batch_size_per_node,
+            loader_bs,
            epoch,
        )
        # train
@@ -1373,12 +1414,12 @@ def main():
 
            # set correct shape to batch
            # - add grad_step dim if gradient_accumulation_steps > 1
-            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
            bs_shape = (
-                (batch_size_per_node_per_grad_step,)
+                (batch_size_per_node_per_grad_step * node_groups,)
                if not use_vmap_trick
                else (
                    jax.local_device_count()
+                    * node_groups
                    // training_args.mp_devices,  # local dp devices
                    training_args.per_device_train_batch_size,
                )
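
The new `shard_shampoo_across` argument selects which mesh axis the Shampoo statistics and preconditioners are sharded over; `2d` spreads statistics over both axes and places preconditioners on whichever of `dp`/`mp` is larger. A hedged sketch of that selection logic, factored into a hypothetical helper that returns plain axis tuples (in train.py these are built inline as `PartitionSpec` objects):

```python
def shampoo_partition_axes(shard_across: str, dp_devices: int, mp_devices: int):
    """Return (statistics_axes, preconditioner_axes) for the chosen sharding mode."""
    if shard_across in ("dp", "mp"):
        # Shard along a single logical mesh axis.
        return (None, shard_across, None), (shard_across, None, None)
    if shard_across == "2d":
        # Statistics over both axes; preconditioners on the larger mesh axis.
        bigger = "mp" if mp_devices > dp_devices else "dp"
        return (None, "dp", "mp"), (bigger, None, None)
    raise ValueError(f"Shard shampoo across {shard_across} not supported.")

print(shampoo_partition_axes("dp", dp_devices=8, mp_devices=1))
print(shampoo_partition_axes("2d", dp_devices=4, mp_devices=8))
```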