Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2022 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""JAX Ops for symmetric matrices used by the Shampoo optimizer.""" | |
import functools | |
from typing import Any, List, Optional, Sequence, Union | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from flax import struct | |
from jax import lax | |
@struct.dataclass
class SlicedSymmetricMatrix:
  """A symmetric matrix represented by lower-triangular block row slices.

  For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
  by the block rows a and [b, c].

  The matrix may be batched, in which case each entry of block_rows may have
  dimension greater than 2. The last two dimensions represent the rows and cols.
  """

  # Block row k holds rows [k * block_size, (k + 1) * block_size) and columns
  # [0, (k + 1) * block_size) of the full matrix, i.e. shape
  # (..., block_size, (k + 1) * block_size).
  block_rows: List[jnp.ndarray]
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
  """Returns mat1 * mat2^T for two matrices (possibly batched).

  The rows and columns are the last two dimensions for each matrix. The
  product is computed as a tensor contraction over the given axes.

  Args:
    mat1: First matrix.
    mat2: Second matrix.
    axes: The axes over which to apply the product, in the format accepted by
      jnp.tensordot (e.g. ((1,), (1,))).
    precision: JAX precision to use for the multiplication.

  Returns:
    The contraction of mat1 and mat2 over axes.
  """
  contracted = jnp.tensordot(mat1, mat2, axes, precision=precision)
  return contracted
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
  """Returns the blocked slices representing a symmetric contraction.

  Specifically, the output is a contraction of the input mat with itself, in the
  specified axes. Every axis of mat except one must appear in axes; the single
  remaining axis is treated as the row axis. Block row i of the result is the
  contraction of rows [i * block_size, (i + 1) * block_size) against rows
  [0, (i + 1) * block_size), so only the lower-triangular blocks are computed.

  Args:
    mat: The matrix for which we will compute a contraction with itself.
    block_size: The size of row blocks to compute.
    axes: Axes to use for the contraction.
    precision: The precision to use in each computation.

  Returns:
    A SlicedSymmetricMatrix holding num_rows // block_size block rows.

  Raises:
    ValueError: Raised when an axis is out of range for mat, when axes contains
      duplicates, when axes does not leave exactly one uncontracted axis, when
      block_size is not positive, or when the specified block size does not
      evenly divide the number of rows of the input mat.
  """
  rank = len(mat.shape)

  def _make_axis_positive(ax):
    # Explicit raise instead of assert so validation survives `python -O`.
    if not -rank <= ax < rank:
      raise ValueError(
          f"Axis {ax} is out of range for an input of rank {rank}."
      )
    return ax + rank if ax < 0 else ax

  positive_axes = [_make_axis_positive(ax) for ax in axes]
  # Duplicate axes (e.g. (1, -1) on a rank-2 input) would otherwise slip past
  # the remaining-axis check below.
  if len(set(positive_axes)) != len(positive_axes):
    raise ValueError(
        f"axes must not contain duplicates. Instead got axes={axes}."
    )
  remaining_axes = set(range(rank)) - set(positive_axes)
  if len(remaining_axes) != 1:
    raise ValueError(
        "axes must leave exactly one uncontracted (row) axis. "
        f"Instead got axes={axes} for an input of rank {rank}."
    )
  remaining_ax = remaining_axes.pop()
  num_rows = mat.shape[remaining_ax]
  if block_size <= 0:
    raise ValueError(
        f"block_size must be positive. Instead got block_size={block_size}."
    )
  if num_rows % block_size != 0:
    raise ValueError(
        "The row dimension must be divisible by block_size. "
        f"Instead got row dimension={num_rows} and block_size={block_size}."
    )
  block_rows = []
  for i in range(num_rows // block_size):
    # Rows of block i along the row axis; all other axes are taken in full.
    start_indices = [0] * rank
    start_indices[remaining_ax] = i * block_size
    slice_sizes = list(mat.shape)
    slice_sizes[remaining_ax] = block_size
    # Rows [0, (i + 1) * block_size): everything up to and including block i,
    # which yields the lower-triangular part of block row i.
    slice_sizes_full = list(mat.shape)
    slice_sizes_full[remaining_ax] = (i + 1) * block_size
    block_rows.append(
        product_with_transpose(
            lax.dynamic_slice(
                mat, start_indices=start_indices, slice_sizes=slice_sizes
            ),
            lax.dynamic_slice(
                mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
            ),
            axes=(axes, axes),
            precision=precision,
        )
    )
  return SlicedSymmetricMatrix(block_rows=block_rows)
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
  """Returns the concatenated slices representing mat*mat^T.

  The lower-triangular block rows produced by sliced_transposed_product are
  joined along the last axis into a single array.

  Args:
    mat: The matrix for which we will compute mat*mat^T. It does not need to be
      square, and may be batched.
    block_size: The size of row blocks to compute.
    axes: Axes to use for the contraction.
    precision: The precision to use in each computation.

  Returns:
    The block rows concatenated along the last axis.

  Raises:
    ValueError: Raised when the specified block size does not evenly divide
      the number of rows of the input mat.
  """
  rows = sliced_transposed_product(
      mat=mat, block_size=block_size, axes=axes, precision=precision
  ).block_rows
  return jnp.concatenate(rows, axis=-1)
def materialize_matrix(symmetric_matrix):
  """Returns a materialized symmetric matrix.

  Each block row supplies the lower-triangular and diagonal blocks of its row.
  The upper-triangular blocks of row r are the transposes of block r of every
  later block row.

  Args:
    symmetric_matrix: the matrix represented by lower-triangular block slices.

  Returns:
    The full (possibly batched) symmetric matrix assembled with jnp.block.
  """
  rows = symmetric_matrix.block_rows
  size = rows[0].shape[-2]
  count = len(rows)

  def _tile(row, col):
    # Square block number `col` (along the last axis) of a block row.
    return row[Ellipsis, col * size : (col + 1) * size]

  grid = []
  for r in range(count):
    lower = [_tile(rows[r], c) for c in range(r + 1)]
    upper = [
        jnp.swapaxes(_tile(rows[c], r), -1, -2) for c in range(r + 1, count)
    ]
    grid.append(lower + upper)
  return jnp.block(grid)
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks=None,
):
  """Returns a materialized symmetric matrix from concatenated slices.

  Block row k occupies square blocks [k * (k + 1) / 2, (k + 1) * (k + 2) / 2)
  of the concatenated array, i.e. exactly k + 1 blocks.

  Args:
    block_rows_concat: The matrix represented as the concatenated
      lower-triangular blocks.
    num_blocks: The number of block-rows used to represent the symmetric matrix.
      If not specified, it is inferred from the shape of block_rows_concat.

  Returns:
    The full symmetric matrix.
  """
  if num_blocks is None:
    num_blocks = find_num_blocks(block_rows_concat)
  block_size = block_rows_concat.shape[-2]
  # Slice each block row to exactly its k + 1 blocks. The end offset is the
  # triangular number T(k + 1) = (k + 1) * (k + 2) / 2 with no extra block.
  block_rows = [
      block_rows_concat[
          Ellipsis,
          (k * (k + 1)) // 2 * block_size : ((k + 1) * (k + 2)) // 2 * block_size,
      ]
      for k in range(num_blocks)
  ]
  return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
  """Implements the blocked equivalent of SYRK.

  Specifically, the symmetric matrix (represented using lower-triangular block
  rows) is updated using the sliced product of mat.

  Args:
    symmetric_matrix: The symmetric matrix to update.
    mat: The matrix to use for the update = mat * mat^T. The number of rows
      should match that of symmetric_matrix.
    alpha: The weight for the update.
    beta: The weight for the original symmetric matrix.
    axes: Axes to use for the contraction of the update.
    precision: The precision forwarded to sliced_transposed_product.

  Returns:
    The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.

  Raises:
    ValueError: If mat does not produce the same number of block rows as
      symmetric_matrix holds.
  """
  block_size = symmetric_matrix.block_rows[0].shape[-2]
  sym_prod = sliced_transposed_product(
      mat=mat, block_size=block_size, axes=axes, precision=precision
  )
  # zip() below would silently truncate on a mismatch and return a matrix
  # with missing block rows, so check the counts explicitly.
  if len(sym_prod.block_rows) != len(symmetric_matrix.block_rows):
    raise ValueError(
        "mat must have the same number of rows as symmetric_matrix. Instead "
        f"got {len(sym_prod.block_rows)} block rows from mat and "
        f"{len(symmetric_matrix.block_rows)} block rows in symmetric_matrix."
    )
  return SlicedSymmetricMatrix(
      block_rows=[
          update * alpha + row * beta
          for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
      ]
  )
def num_blocks_from_total_blocks(total_blocks):
  """Returns the number of blocks (i.e. block rows) from the total blocks.

  This inverts the triangular-number map x -> x * (x + 1) / 2. For example, the
  matrix M = [[A, B^T], [B, C]] may be represented using a total of 3 blocks
  ([A, B, C]), which corresponds to 2 block rows.

  Args:
    total_blocks: The total blocks used to represent the matrix.

  Returns:
    The number of block rows, as a numpy int32 scalar.

  Raises:
    ValueError: If total_blocks is not of the form x * (x + 1) / 2.
  """
  # Positive root of x**2 + x - 2 * total_blocks = 0, rounded to an integer.
  root = (np.sqrt(8 * total_blocks + 1) - 1) / 2
  candidate = np.round(root).astype(np.int32)
  triangular = (candidate * (candidate + 1)) / 2
  if triangular != total_blocks:
    raise ValueError(
        f"total_blocks={total_blocks} does not correspond to "
        "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
    )
  return candidate
def find_num_blocks(block_rows_concat):
  """Returns the number of (row) blocks representing the concatenated matrix.

  For example, an input with dimensions [256, 2560] represents 10 square blocks,
  which matches 4 lower-triangular block rows (1+2+3+4). So this function will
  return 4.

  Only static shape arithmetic and ordinary numpy functions are used, so the
  returned value is static.

  Args:
    block_rows_concat: The concatenated block array.

  Returns:
    The number of block rows.

  Raises:
    ValueError: When the dimensions of the matrix do not correspond to a lower
      triangular block representation.
  """
  shape = block_rows_concat.shape
  # Number of square blocks; may be non-integral for malformed inputs, which
  # num_blocks_from_total_blocks then rejects.
  square_blocks = shape[-1] / shape[-2]
  return num_blocks_from_total_blocks(square_blocks)
def slice_symmetric_matrix(
    mat,
    block_size,
):
  """Returns sliced row blocks.

  Block row i keeps rows [i * block_size, (i + 1) * block_size) and columns
  [0, (i + 1) * block_size) of mat, i.e. the lower-triangular blocks including
  the diagonal.

  Args:
    mat: A symmetric matrix. Its last two dimensions must be equal.
    block_size: The size of the row slices.

  Returns:
    A SlicedSymmetricMatrix with num_rows // block_size block rows.

  Raises:
    ValueError: If mat is not square, or block_size does not evenly divide the
      number of rows.
  """
  num_rows, num_cols = mat.shape[-2], mat.shape[-1]
  if num_rows != num_cols:
    raise ValueError("mat is not square.")
  if num_rows % block_size != 0:
    raise ValueError(
        "block size does not evenly divide rows. "
        f"num_rows={num_rows}, block_size={block_size}"
    )
  pieces = []
  for i in range(num_rows // block_size):
    row_stop = (i + 1) * block_size
    pieces.append(mat[Ellipsis, row_stop - block_size : row_stop, :row_stop])
  return SlicedSymmetricMatrix(block_rows=pieces)
def slice_symmetric_matrix_concat(
    mat,
    block_size,
):
  """Returns the concatenated sliced row blocks.

  Args:
    mat: A symmetric matrix.
    block_size: The size of the row slices.

  Returns:
    The lower-triangular block rows of mat joined along the last axis.
  """
  pieces = slice_symmetric_matrix(mat=mat, block_size=block_size).block_rows
  return jnp.concatenate(pieces, axis=-1)
def sliced_matrix_diag(mat):
  """Returns the diagonal of the symmetric matrix.

  The diagonal block of block row i is the last square block of that row, so
  it ends at offset rows * (i + 1) * (i + 2) / 2 in the concatenated array.

  Args:
    mat: The symmetric matrix represented in concatenated block form. Must be
      2D.

  Returns:
    A 1D array with the diagonal entries of the full matrix.
  """
  rows, cols = mat.shape
  num_blocks = num_blocks_from_total_blocks(cols // rows)
  diagonals = []
  for i in range(num_blocks):
    stop = rows * ((i + 1) * (i + 2)) // 2
    diagonals.append(jnp.diag(mat[Ellipsis, stop - rows : stop]))
  return jnp.concatenate(diagonals, axis=-1)
def diag_as_concat(diag, block_size):
  """Returns the representation of a diagonal matrix in symmetric block form.

  Block row i consists of i zero blocks followed by the diagonal block built
  from diag[i * block_size : (i + 1) * block_size].

  Args:
    diag: The 1D array for the diagonals.
    block_size: The size of blocks to use. Must divide the length of diag.

  Returns:
    An array of shape (block_size, n * (n + 1) / 2 * block_size), where
    n = len(diag) // block_size, with the same dtype as diag.
  """
  assert len(diag.shape) == 1  # diag must be 1D.
  assert len(diag) % block_size == 0
  count = len(diag) // block_size
  pieces = []
  for i in range(count):
    offset = i * block_size
    pieces += [
        jnp.zeros(shape=(block_size, offset), dtype=diag.dtype),
        jnp.diag(diag[offset : offset + block_size]),
    ]
  return jnp.concatenate(pieces, axis=-1)
def row_abs_maxes(mat):
  """Returns the max of the absolute values of the rows of the full matrix.

  For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
  mat = [1, 6, 2] with block_size = 1. In this case the function returns the
  absolute row maxes of the original symmetric matrix, [6, 6].

  Because the full matrix is symmetric, the max of row r equals the max of
  column r. Each block row i of the full matrix contributes a partial column
  max for every column, and the overall max is taken across block rows.

  Args:
    mat: The symmetric matrix represented as the concatenated blocks. Must be
      2D.

  Returns:
    A 1D array holding the absolute max of every row of the full matrix.
  """
  block_size, width = mat.shape
  total_blocks = width // block_size
  abs_blocks = [
      jnp.abs(mat[Ellipsis, t * block_size : (t + 1) * block_size])
      for t in range(total_blocks)
  ]
  # Per stored block: max down each column (axis 0) and along each row (axis 1).
  col_max = [jnp.max(blk, axis=0) for blk in abs_blocks]
  row_max = [jnp.max(blk, axis=1) for blk in abs_blocks]
  num_blocks = num_blocks_from_total_blocks(total_blocks)
  partial_maxes = []
  for i in range(num_blocks):
    # Stored blocks (i, 0..i) of block row i: their column maxes directly.
    first = (i * (i + 1)) // 2
    pieces = col_max[first : first + i + 1]
    # Upper-triangular blocks (i, j), j > i, are transposes of stored block
    # (j, i), found at offset j * (j + 1) / 2 + i; their row maxes are the
    # column maxes needed here.
    pieces += [
        row_max[(j * (j + 1)) // 2 + i] for j in range(i + 1, num_blocks)
    ]
    partial_maxes.append(jnp.concatenate(pieces, axis=-1))
  return jnp.max(jnp.stack(partial_maxes), axis=0)
def times_vector(mat, vec):
  """Returns the symmetric block-concatenated matrix multiplied by a vector.

  Specifically, each value in the vector is multiplied by a row of the full
  matrix. That is, the vector is broadcast and multiplied element-wise. Note
  this would be the transpose of full_mat * vec if full_mat represented the full
  symmetric matrix.

  Leading batch dimensions (if any) are preserved in the output, matching the
  batched-matrix convention of SlicedSymmetricMatrix.

  Args:
    mat: The symmetric matrix represented as the concatenated blocks. The last
      two dimensions are (block_size, total_blocks * block_size).
    vec: The vector, having the same dimension as the materialized matrix.

  Returns:
    An array with the same concatenated layout as mat, where row a of block
    row i is scaled by vec[..., i * block_size + a].
  """
  rows, cols = mat.shape[-2:]
  num_blocks = num_blocks_from_total_blocks(cols // rows)
  multiplied = []
  for i in range(num_blocks):
    mat_block = mat[
        Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
    ]
    vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
    # The output spec keeps "..." so batch dimensions are not summed away.
    multiplied.append(jnp.einsum("...ij,...i->...ij", mat_block, vec_block))
  return jnp.concatenate(multiplied, axis=-1)