Spaces:
Runtime error
Runtime error
File size: 14,491 Bytes
6742988 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 |
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
import functools
from typing import Any, List, Optional, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import lax
@struct.dataclass
class SlicedSymmetricMatrix:
    """A symmetric matrix represented by lower-triangular block row slices.

    For example, the symmetric matrix M = [[a, b^T], [b, c]] would be
    represented by the block rows a and [b, c].

    The matrix may be batched, in which case each entry of block_rows may have
    dimension greater than 2. The last two dimensions represent the rows and
    cols.
    """

    # One array per block row, ordered top to bottom. The producers in this
    # file (sliced_transposed_product, slice_symmetric_matrix) make entry k
    # span the first (k + 1) * block_size columns of the full matrix, i.e. only
    # the blocks on or below the diagonal.
    block_rows: List[jnp.ndarray]
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Contracts mat1 against mat2 over the requested axes.

    When both inputs are plain matrices and the contraction runs over their
    column axes, the result equals mat1 * mat2^T.

    Args:
      mat1: Left operand.
      mat2: Right operand.
      axes: Contraction axes, forwarded unchanged to `jnp.tensordot`.
      precision: JAX precision used for the multiplication.

    Returns:
      The tensor contraction of mat1 and mat2 over `axes`.
    """
    contraction = jnp.tensordot(mat1, mat2, axes=axes, precision=precision)
    return contraction
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Computes the lower-triangular block rows of mat contracted with itself.

    The contraction runs over `axes`; exactly one axis of `mat` must be left
    over, and that axis plays the role of the row dimension that gets blocked.

    Args:
      mat: The array to contract with itself.
      block_size: Number of rows in each block row.
      axes: Axes of `mat` to contract over (negative values are allowed).
      precision: The precision to use in each computation.

    Returns:
      A SlicedSymmetricMatrix whose k-th entry is the product of the k-th row
      block of `mat` with the first (k + 1) row blocks of `mat`.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    rank = len(mat.shape)

    def _to_positive(ax):
        assert -rank <= ax < rank
        return ax if ax >= 0 else ax + rank

    contracted_axes = [_to_positive(ax) for ax in axes]
    assert len(contracted_axes) == len(axes)
    leftover_axes = set(range(rank)) - set(contracted_axes)
    assert len(leftover_axes) == 1
    (row_axis,) = leftover_axes
    num_rows = mat.shape[row_axis]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    def _row_slab(first_row, row_count):
        # Full extent on every axis except the row axis, which is windowed.
        starts = [0] * rank
        starts[row_axis] = first_row
        sizes = list(mat.shape)
        sizes[row_axis] = row_count
        return lax.dynamic_slice(mat, start_indices=starts, slice_sizes=sizes)

    block_rows = [
        product_with_transpose(
            # Current row block ...
            _row_slab(idx * block_size, block_size),
            # ... against every row up to and including that block.
            _row_slab(0, (idx + 1) * block_size),
            axes=(axes, axes),
            precision=precision,
        )
        for idx in range(num_rows // block_size)
    ]
    return SlicedSymmetricMatrix(block_rows=block_rows)
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Computes the block rows of mat*mat^T joined along the last axis.

    Args:
      mat: The matrix for which we will compute mat*mat^T. It does not need to
        be square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      A single array formed by concatenating, along the last axis, the block
      rows produced by `sliced_transposed_product`.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    block_rows = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes, precision=precision
    ).block_rows
    return jnp.concatenate(block_rows, axis=-1)
@jax.jit
def materialize_matrix(symmetric_matrix):
    """Expands a block-row representation into the full symmetric matrix.

    Args:
      symmetric_matrix: The matrix represented by lower-triangular block
        slices; only its `block_rows` attribute is read.

    Returns:
      The dense matrix assembled from the stored blocks, with every
      above-diagonal block filled in as the transpose (over the last two axes)
      of its below-diagonal mirror.
    """
    block_rows = symmetric_matrix.block_rows
    num_blocks = len(block_rows)
    block_size = block_rows[0].shape[-2]

    def _stored_block(row_idx, col_idx):
        first_col = col_idx * block_size
        return block_rows[row_idx][Ellipsis, first_col : first_col + block_size]

    # lower[r][c] is block (r, c) for c <= r; only the first r + 1 blocks of
    # each stored row are ever read.
    lower = [
        [_stored_block(r, c) for c in range(r + 1)] for r in range(num_blocks)
    ]

    full_rows = []
    for r in range(num_blocks):
        # Blocks to the right of the diagonal come from the rows below it.
        mirrored = [
            jnp.swapaxes(lower[below][r], -1, -2)
            for below in range(r + 1, num_blocks)
        ]
        full_rows.append(lower[r] + mirrored)
    return jnp.block(full_rows)
@functools.partial(jax.jit, static_argnames=("num_blocks",))
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks=None,
):
    """Returns a materialized symmetric matrix from concatenated slices.

    Args:
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric
        matrix. If not specified, it is inferred from the shape of
        block_rows_concat.

    Returns:
      The dense symmetric matrix obtained by splitting `block_rows_concat` back
      into block rows and passing them to `materialize_matrix`.
    """
    if num_blocks is None:
        num_blocks = find_num_blocks(block_rows_concat)
    block_size = block_rows_concat.shape[-2]
    # Block row k holds k + 1 square blocks and starts after the
    # 1 + 2 + ... + k = k * (k + 1) / 2 blocks of the rows above it, so it
    # ends at block index (k + 1) * (k + 2) / 2. The previous upper bound
    # carried a stray "+ 1" that made every row but the last one block too
    # wide.
    block_rows = [
        block_rows_concat[
            Ellipsis,
            (k * (k + 1)) // 2 * block_size : ((k + 1) * (k + 2)) // 2 * block_size,
        ]
        for k in range(num_blocks)
    ]
    return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Blocked analogue of SYRK on a lower-triangular block-row matrix.

    The block size is taken from the first stored block row, the sliced
    product of `mat` with itself is computed at that block size, and the two
    are combined row by row.

    Args:
      symmetric_matrix: The symmetric matrix to update.
      mat: The matrix to use for the update = mat * mat^T. The number of rows
        should match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
    """
    existing_rows = symmetric_matrix.block_rows
    block_size = existing_rows[0].shape[-2]
    sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)

    blended_rows = []
    for update, row in zip(sym_prod.block_rows, existing_rows):
        blended_rows.append(update * alpha + row * beta)
    return SlicedSymmetricMatrix(block_rows=blended_rows)
def num_blocks_from_total_blocks(total_blocks):
    """Returns the number of block rows implied by a total block count.

    This is the inverse of the function x -> x*(x+1)/2.

    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.

    Args:
      total_blocks: The total blocks used to represent the matrix.

    Returns:
      The number of block rows, as a numpy int32 scalar.

    Raises:
      ValueError: If total_blocks is not of the form x*(x+1)/2.
    """
    discriminant_root = np.sqrt(8 * total_blocks + 1)
    candidate = np.round((discriminant_root - 1) / 2).astype(np.int32)
    # Round-trip through the forward formula to reject non-triangular counts.
    implied_total = (candidate * (candidate + 1)) / 2
    if implied_total == total_blocks:
        return candidate
    raise ValueError(
        f"total_blocks={total_blocks} does not correspond to "
        "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
    )
def find_num_blocks(block_rows_concat):
    """Infers the number of block rows from a concatenated block array.

    For example, an input with dimensions [256, 2560] represents 10 square
    blocks, which matches 4 lower-triangular block rows (1+2+3+4). So this
    function will return 4.

    Only the (static) shape is inspected, using ordinary Python/numpy
    arithmetic, so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array.

    Returns:
      The number of block rows.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a
        lower triangular block representation.
    """
    shape = block_rows_concat.shape
    concat_width, block_height = shape[-1], shape[-2]
    # True division on purpose: a fractional block count is passed through and
    # rejected by the triangular-number check downstream.
    square_block_count = concat_width / block_height
    return num_blocks_from_total_blocks(square_block_count)
@functools.partial(jax.jit, static_argnames="block_size")
def slice_symmetric_matrix(
    mat,
    block_size,
):
    """Splits a square matrix into lower-triangular block rows.

    Args:
      mat: A symmetric matrix (possibly batched; the last two axes are used).
      block_size: The size of the row slices.

    Returns:
      A SlicedSymmetricMatrix whose i-th entry holds rows
      [i * block_size, (i + 1) * block_size) and columns
      [0, (i + 1) * block_size) of mat.

    Raises:
      ValueError: If mat is not square in its last two axes, or block_size
        does not evenly divide the number of rows.
    """
    num_rows, num_cols = mat.shape[-2], mat.shape[-1]
    if num_rows != num_cols:
        raise ValueError("mat is not square.")
    if num_rows % block_size != 0:
        raise ValueError(
            "block size does not evenly divide rows. "
            f"num_rows={num_rows}, block_size={block_size}"
        )

    block_rows = []
    # row_end walks the exclusive end of each block row: block_size,
    # 2 * block_size, ..., num_rows.
    for row_end in range(block_size, num_rows + 1, block_size):
        block_rows.append(mat[Ellipsis, row_end - block_size : row_end, :row_end])
    return SlicedSymmetricMatrix(block_rows=block_rows)
@functools.partial(jax.jit, static_argnames="block_size")
def slice_symmetric_matrix_concat(
    mat,
    block_size,
):
    """Slices a symmetric matrix into block rows and joins them into one array.

    Args:
      mat: A symmetric matrix.
      block_size: The size of the row slices.

    Returns:
      The block rows from `slice_symmetric_matrix`, concatenated along the
      last axis.
    """
    block_rows = slice_symmetric_matrix(mat=mat, block_size=block_size).block_rows
    return jnp.concatenate(block_rows, axis=-1)
def sliced_matrix_diag(mat):
    """Extracts the diagonal of a symmetric matrix stored in concatenated form.

    Args:
      mat: The symmetric matrix represented in concatenated block form. It is
        unpacked as a 2D (rows, cols) array.

    Returns:
      A 1D array made of the diagonals of the last square block of every block
      row, joined in order.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    # The diagonal block of block row i is the last one in that row, so it
    # ends at column rows * (i + 1) * (i + 2) / 2.
    diag_block_ends = (
        rows * ((i + 2) * (i + 1)) // 2 for i in range(num_blocks)
    )
    diags = [
        jnp.diag(mat[Ellipsis, end - rows : end]) for end in diag_block_ends
    ]
    return jnp.concatenate(diags, axis=-1)
def diag_as_concat(diag, block_size):
    """Builds the concatenated block form of a diagonal matrix.

    Args:
      diag: The 1D array for the diagonals.
      block_size: The size of blocks to use. Must divide the length of diag.

    Returns:
      A 2D array with block_size rows: for each block row i, i zero blocks
      followed by the diagonal block built from the i-th chunk of diag.
    """
    assert len(diag.shape) == 1  # diag must be 1D.
    diag_len = len(diag)
    assert diag_len % block_size == 0

    pieces = []
    for block_idx in range(diag_len // block_size):
        chunk_start = block_idx * block_size
        # Everything left of the diagonal block in this block row is zero.
        zero_pad = jnp.zeros(
            shape=(block_size, block_size * block_idx), dtype=diag.dtype
        )
        diag_block = jnp.diag(diag[chunk_start : chunk_start + block_size])
        pieces.extend((zero_pad, diag_block))
    return jnp.concatenate(pieces, axis=-1)
def row_abs_maxes(mat):
    """Returns the max of the absolute values of the rows of the full matrix.

    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
    absolute row maxes of the original symmetric matrix, [6, 6].

    Args:
      mat: The symmetric matrix represented as the concatenated blocks. It is
        unpacked as a 2D (rows, cols) array.

    Returns:
      A 1D array with one absolute maximum per row of the full matrix.
    """
    rows, cols = mat.shape
    total_blocks = cols // rows

    # Per-block reductions of |block| along each of its two axes.
    abs_blocks = [
        jnp.abs(mat[Ellipsis, b * rows : (b + 1) * rows])
        for b in range(total_blocks)
    ]
    max_along_axis1 = [jnp.max(blk, axis=1) for blk in abs_blocks]
    max_along_axis0 = [jnp.max(blk, axis=0) for blk in abs_blocks]

    num_blocks = num_blocks_from_total_blocks(total_blocks)

    def _flat_index(block_row, block_col):
        # Position of block (block_row, block_col) in the concatenated layout.
        return block_row * (block_row + 1) // 2 + block_col

    per_block_row = []
    for i in range(num_blocks):
        # Stored blocks (i, 0..i) are reduced over their rows ...
        stored = max_along_axis0[_flat_index(i, 0) : _flat_index(i + 1, 0)]
        # ... while the blocks right of the diagonal are transposes of the
        # stored blocks (j, i) with j > i, so those are reduced over their
        # columns instead.
        mirrored = [
            max_along_axis1[_flat_index(j, i)] for j in range(i + 1, num_blocks)
        ]
        per_block_row.append(jnp.concatenate(stored + mirrored, axis=-1))

    # Combine the per-block-row partial maxima into the global ones.
    return jnp.max(jnp.stack(per_block_row), axis=0)
def times_vector(mat, vec):
    """Scales each row of the block-concatenated symmetric matrix by a vector.

    Each value in the vector is multiplied by a row of the full matrix. That
    is, the vector is broadcast and multiplied element-wise. Note this would be
    the transpose of full_mat * vec if full_mat represented the full symmetric
    matrix.

    Args:
      mat: The symmetric matrix represented as the concatenated blocks. It is
        unpacked as a 2D (rows, cols) array.
      vec: The vector, having the same dimension as the materialized matrix.

    Returns:
      The scaled matrix, in the same concatenated block layout as mat.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)

    def _scaled_block_row(i):
        col_start = rows * ((i + 1) * i) // 2
        col_stop = rows * ((i + 1) * (i + 2)) // 2
        mat_block = mat[Ellipsis, col_start:col_stop]
        vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
        return jnp.einsum("...ij,...i->ij", mat_block, vec_block)

    scaled = [_scaled_block_row(i) for i in range(num_blocks)]
    return jnp.concatenate(scaled, axis=-1)
|