boris commited on
Commit
b7b619a
·
unverified ·
1 Parent(s): 7939874

feat(train): log norm and histograms (#143)

Browse files

* feat(train): log norm and histograms
* feat: update shampoo

tools/train/scalable_shampoo/distributed_shampoo.py CHANGED
@@ -832,8 +832,11 @@ def distributed_shampoo(
832
  if not _skip_preconditioning(param):
833
  sizes = [s[0] for s in shapes]
834
  shapes = preconditioner.shapes_for_preconditioners()
835
- statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
836
- preconditioners = [jnp.eye(max_size) for s in shapes]
 
 
 
837
  padded_statistics.extend(statistics)
838
  padded_preconditioners.extend(preconditioners)
839
  exponent = (
@@ -1244,8 +1247,10 @@ def distributed_shampoo(
1244
  preconditioners = []
1245
  if not _skip_preconditioning(param):
1246
  shapes = preconditioner.shapes_for_preconditioners()
1247
- statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
1248
- preconditioners = [jnp.eye(s[0]) for s in shapes]
 
 
1249
 
1250
  diagonal_statistics = []
1251
  if _graft_type_has_diagonal_statistics():
 
832
  if not _skip_preconditioning(param):
833
  sizes = [s[0] for s in shapes]
834
  shapes = preconditioner.shapes_for_preconditioners()
835
+ statistics = [
836
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
837
+ for s in shapes
838
+ ]
839
+ preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
840
  padded_statistics.extend(statistics)
841
  padded_preconditioners.extend(preconditioners)
842
  exponent = (
 
1247
  preconditioners = []
1248
  if not _skip_preconditioning(param):
1249
  shapes = preconditioner.shapes_for_preconditioners()
1250
+ statistics = [
1251
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
1252
+ ]
1253
+ preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
1254
 
1255
  diagonal_statistics = []
1256
  if _graft_type_has_diagonal_statistics():
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py CHANGED
@@ -16,10 +16,11 @@
16
  """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
 
18
  import functools
19
- from typing import List, Union
20
 
21
  import jax
22
  import jax.numpy as jnp
 
23
  from flax import struct
24
  from jax import lax
25
 
@@ -41,6 +42,7 @@ class SlicedSymmetricMatrix:
41
  def product_with_transpose(
42
  mat1,
43
  mat2,
 
44
  precision=lax.Precision.DEFAULT,
45
  ):
46
  """Returns mat1 * mat2^T for two matrices (possibly batched).
@@ -50,50 +52,85 @@ def product_with_transpose(
50
  Args:
51
  mat1: First matrix.
52
  mat2: Second matrix.
 
53
  precision: JAX precision to use for the multiplication.
54
  """
55
- return jnp.einsum("...ij,...kj->...ik", mat1, mat2, precision=precision)
56
 
57
 
58
- @functools.partial(jax.jit, static_argnames=("block_size", "precision"))
59
  def sliced_transposed_product(
60
  mat,
61
  block_size,
 
62
  precision=lax.Precision.DEFAULT,
63
  ):
64
- """Returns the blocked slices representing a symmetric matrix mat*mat^T.
 
 
 
65
 
66
  Args:
67
- mat: The matrix for which we will compute mat*mat^T. It does not need to be
68
- square, and may be batched.
69
  block_size: The size of row blocks to compute.
 
70
  precision: The precision to use in each computation.
71
 
72
  Raises:
73
  ValueError: Raised when the specified block size does not evenly divide
74
  the number of rows of the input mat.
75
  """
76
- num_rows = mat.shape[-2]
 
 
 
 
 
 
 
 
 
 
 
 
77
  if num_rows % block_size != 0:
78
  raise ValueError(
79
  "The row dimension must be divisible by block_size. "
80
  f"Instead got row dimension={num_rows} and block_size={block_size}."
81
  )
82
- block_rows = [
83
- product_with_transpose(
84
- mat[Ellipsis, i * block_size : (i + 1) * block_size, :],
85
- mat[Ellipsis, 0 : (i + 1) * block_size, :],
86
- precision,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
88
- for i in range(num_rows // block_size)
89
- ]
90
  return SlicedSymmetricMatrix(block_rows=block_rows)
91
 
92
 
93
- @functools.partial(jax.jit, static_argnames=("block_size", "precision"))
94
  def sliced_transposed_product_concat(
95
  mat,
96
  block_size,
 
97
  precision=lax.Precision.DEFAULT,
98
  ):
99
  """Returns the concatenated slices representing mat*mat^T.
@@ -102,6 +139,7 @@ def sliced_transposed_product_concat(
102
  mat: The matrix for which we will compute mat*mat^T. It does not need to be
103
  square, and may be batched.
104
  block_size: The size of row blocks to compute.
 
105
  precision: The precision to use in each computation.
106
 
107
  Raises:
@@ -109,7 +147,7 @@ def sliced_transposed_product_concat(
109
  the number of rows of the input mat.
110
  """
111
  sliced_symmetric_matrix = sliced_transposed_product(
112
- mat=mat, block_size=block_size, precision=precision
113
  )
114
  return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
115
 
@@ -179,12 +217,13 @@ def materialize_matrix_from_concat(
179
  return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
180
 
181
 
182
- @functools.partial(jax.jit, static_argnames=("alpha", "beta"))
183
  def update_sliced_rows(
184
  symmetric_matrix,
185
  mat,
186
  alpha,
187
  beta,
 
188
  ):
189
  """Implements the blocked equivalent of SYRK.
190
 
@@ -197,15 +236,45 @@ def update_sliced_rows(
197
  should match that of symmetric_matrix.
198
  alpha: The weight for the update.
199
  beta: The weight for the original symmetric matrix.
 
200
 
201
  Returns:
202
  The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
203
  """
204
  block_size = symmetric_matrix.block_rows[0].shape[-2]
205
- sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
206
  return SlicedSymmetricMatrix(
207
  block_rows=[
208
  update * alpha + row * beta
209
  for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
210
  ]
211
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
 
18
  import functools
19
+ from typing import Any, List, Sequence, Union
20
 
21
  import jax
22
  import jax.numpy as jnp
23
+ import numpy as np
24
  from flax import struct
25
  from jax import lax
26
 
 
42
  def product_with_transpose(
43
  mat1,
44
  mat2,
45
+ axes,
46
  precision=lax.Precision.DEFAULT,
47
  ):
48
  """Returns mat1 * mat2^T for two matrices (possibly batched).
 
52
  Args:
53
  mat1: First matrix.
54
  mat2: Second matrix.
55
+ axes: The axes over which to apply the product.
56
  precision: JAX precision to use for the multiplication.
57
  """
58
+ return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
59
 
60
 
61
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
62
  def sliced_transposed_product(
63
  mat,
64
  block_size,
65
+ axes=(-1,),
66
  precision=lax.Precision.DEFAULT,
67
  ):
68
+ """Returns the blocked slices representing a symmetric contraction.
69
+
70
+ Specifically, the output is a contraction of the input mat with itself, in the
71
+ specified axes.
72
 
73
  Args:
74
+ mat: The matrix for which we will compute a contraction with itself.
 
75
  block_size: The size of row blocks to compute.
76
+ axes: Axes to use for the contraction.
77
  precision: The precision to use in each computation.
78
 
79
  Raises:
80
  ValueError: Raised when the specified block size does not evenly divide
81
  the number of rows of the input mat.
82
  """
83
+ rank = len(mat.shape)
84
+
85
+ def _make_axis_positive(ax):
86
+ assert -rank <= ax < rank
87
+ return ax + rank if ax < 0 else ax
88
+
89
+ positive_axes = [_make_axis_positive(ax) for ax in axes]
90
+ assert len(positive_axes) == len(axes)
91
+ remaining_axes = set(range(rank)) - set(positive_axes)
92
+ assert len(remaining_axes) == 1
93
+ remaining_ax = remaining_axes.pop()
94
+
95
+ num_rows = mat.shape[remaining_ax]
96
  if num_rows % block_size != 0:
97
  raise ValueError(
98
  "The row dimension must be divisible by block_size. "
99
  f"Instead got row dimension={num_rows} and block_size={block_size}."
100
  )
101
+
102
+ block_rows = []
103
+ for i in range(num_rows // block_size):
104
+ start_indices = [0] * rank
105
+ start_indices[remaining_ax] = i * block_size
106
+
107
+ slice_sizes = list(mat.shape)
108
+ slice_sizes[remaining_ax] = block_size
109
+
110
+ slice_sizes_full = list(mat.shape)
111
+ slice_sizes_full[remaining_ax] = (i + 1) * block_size
112
+
113
+ block_rows.append(
114
+ product_with_transpose(
115
+ lax.dynamic_slice(
116
+ mat, start_indices=start_indices, slice_sizes=slice_sizes
117
+ ),
118
+ lax.dynamic_slice(
119
+ mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
120
+ ),
121
+ axes=(axes, axes),
122
+ precision=precision,
123
+ )
124
  )
125
+
 
126
  return SlicedSymmetricMatrix(block_rows=block_rows)
127
 
128
 
129
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
130
  def sliced_transposed_product_concat(
131
  mat,
132
  block_size,
133
+ axes=(-1,),
134
  precision=lax.Precision.DEFAULT,
135
  ):
136
  """Returns the concatenated slices representing mat*mat^T.
 
139
  mat: The matrix for which we will compute mat*mat^T. It does not need to be
140
  square, and may be batched.
141
  block_size: The size of row blocks to compute.
142
+ axes: Axes to use for the contraction.
143
  precision: The precision to use in each computation.
144
 
145
  Raises:
 
147
  the number of rows of the input mat.
148
  """
149
  sliced_symmetric_matrix = sliced_transposed_product(
150
+ mat=mat, block_size=block_size, axes=axes, precision=precision
151
  )
152
  return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
153
 
 
217
  return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
218
 
219
 
220
+ @functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
221
  def update_sliced_rows(
222
  symmetric_matrix,
223
  mat,
224
  alpha,
225
  beta,
226
+ axes=(-1,),
227
  ):
228
  """Implements the blocked equivalent of SYRK.
229
 
 
236
  should match that of symmetric_matrix.
237
  alpha: The weight for the update.
238
  beta: The weight for the original symmetric matrix.
239
+ axes: Axes to use for the contraction of the update.
240
 
241
  Returns:
242
  The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
243
  """
244
  block_size = symmetric_matrix.block_rows[0].shape[-2]
245
+ sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
246
  return SlicedSymmetricMatrix(
247
  block_rows=[
248
  update * alpha + row * beta
249
  for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
250
  ]
251
  )
252
+
253
+
254
+ def find_num_blocks(block_rows_concat):
255
+ """Returns the number of (row) blocks representing the concatenated matrix.
256
+
257
+ For example, an input with dimensions [256, 2560] represents 10 square blocks,
258
+ which matches 4 lower-triangular block rows (1+2+3+4). So this function will
259
+ return 4.
260
+
261
+ Use ordinary numpy functions here so that the returned value is static.
262
+
263
+ Args:
264
+ block_rows_concat: The concatenated block array.
265
+
266
+ Raises:
267
+ ValueError: When the dimensions of the matrix do not correspond to a lower
268
+ triangular block representation.
269
+ """
270
+ # Compute the number of square blocks used to represent the matrix.
271
+ total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
272
+ # Determine the number of block rows by inverting y = x*(x+1)/2.
273
+ num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
274
+ if num_blocks * (num_blocks + 1) / 2 != total_blocks:
275
+ raise ValueError(
276
+ "Could not determine an appropriate number of blocks for "
277
+ "the concatenated matrix."
278
+ )
279
+ else:
280
+ return num_blocks
tools/train/train.py CHANGED
@@ -37,7 +37,7 @@ import optax
37
  import transformers
38
  import wandb
39
  from datasets import Dataset
40
- from flax.core.frozen_dict import FrozenDict, freeze
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
43
  from flax.training.common_utils import onehot
@@ -405,6 +405,12 @@ class TrainingArguments:
405
  default=False,
406
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
407
  )
 
 
 
 
 
 
408
 
409
  seed_model: int = field(
410
  default=42,
@@ -514,10 +520,22 @@ class MetricsLogger:
514
 
515
  def log(self, metrics, prefix=None):
516
  if jax.process_index() == 0:
517
- log_metrics = {
518
- f"{prefix}/{k}" if prefix is not None else k: v
519
- for k, v in metrics.items()
520
- }
 
 
 
 
 
 
 
 
 
 
 
 
521
  wandb.log({**log_metrics, **self.state_dict})
522
 
523
 
@@ -1024,8 +1042,9 @@ def main():
1024
  lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
1025
  )
1026
 
1027
- # update state
1028
  grads = with_sharding_constraint(grads, param_spec)
 
 
1029
  state = state.apply_gradients(
1030
  grads=grads,
1031
  dropout_rng=dropout_rng,
@@ -1033,11 +1052,49 @@ def main():
1033
  train_samples=state.train_samples + batch_size_per_step,
1034
  )
1035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1036
  metrics = {
1037
  "loss": loss,
1038
  "learning_rate": learning_rate_fn(state.step),
 
 
1039
  }
1040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1041
  return state, metrics
1042
 
1043
  # Define eval fn
 
37
  import transformers
38
  import wandb
39
  from datasets import Dataset
40
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
43
  from flax.training.common_utils import onehot
 
405
  default=False,
406
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
407
  )
408
+ log_histograms: bool = field(
409
+ default=False,
410
+ metadata={
411
+ "help": "Log parameters and gradients histograms. Slows down training."
412
+ },
413
+ )
414
 
415
  seed_model: int = field(
416
  default=42,
 
520
 
521
  def log(self, metrics, prefix=None):
522
  if jax.process_index() == 0:
523
+ log_metrics = {}
524
+ for k, v in metrics.items():
525
+ if prefix is not None:
526
+ k = f"{prefix}/{k}"
527
+ if "_norm" in k:
528
+ log_metrics[f"{k}/"] = unfreeze(v)
529
+ elif "_hist" in k:
530
+ v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
531
+ v = jax.tree_map(
532
+ lambda x: wandb.Histogram(np_histogram=x),
533
+ v,
534
+ is_leaf=lambda x: isinstance(x, tuple),
535
+ )
536
+ log_metrics[f"{k}/"] = v
537
+ else:
538
+ log_metrics[k] = v
539
  wandb.log({**log_metrics, **self.state_dict})
540
 
541
 
 
1042
  lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
1043
  )
1044
 
 
1045
  grads = with_sharding_constraint(grads, param_spec)
1046
+
1047
+ # update state
1048
  state = state.apply_gradients(
1049
  grads=grads,
1050
  dropout_rng=dropout_rng,
 
1052
  train_samples=state.train_samples + batch_size_per_step,
1053
  )
1054
 
1055
+ # get norm and histogram of grads and params
1056
+ zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
1057
+
1058
+ def maybe_fn(fn, val, zeros):
1059
+ """Call fn only if it is a logging step"""
1060
+ return jax.lax.cond(
1061
+ state.step % training_args.logging_steps == 0,
1062
+ fn,
1063
+ lambda _: zeros,
1064
+ val,
1065
+ )
1066
+
1067
+ def norm(val):
1068
+ return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
1069
+
1070
+ gradients_norm = maybe_fn(norm, grads, zeros_norm)
1071
+ params_norm = maybe_fn(norm, state.params, zeros_norm)
1072
+
1073
  metrics = {
1074
  "loss": loss,
1075
  "learning_rate": learning_rate_fn(state.step),
1076
+ "gradients_norm": gradients_norm,
1077
+ "params_norm": params_norm,
1078
  }
1079
 
1080
+ if training_args.log_histograms:
1081
+ zeros_hist = jax.tree_map(
1082
+ lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
1083
+ )
1084
+
1085
+ def histogram(val):
1086
+ return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
1087
+
1088
+ gradients_hist = maybe_fn(histogram, grads, zeros_hist)
1089
+ params_hist = maybe_fn(histogram, state.params, zeros_hist)
1090
+
1091
+ metrics.update(
1092
+ {
1093
+ "params_hist": params_hist,
1094
+ "gradients_hist": gradients_hist,
1095
+ }
1096
+ )
1097
+
1098
  return state, metrics
1099
 
1100
  # Define eval fn