Hukuna's picture
Upload 221 files
ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for linear algebra.
Performs linear algebra computations.
This module contains additional pytorch layers for linear algebra operations,
such as a more parallelization-friendly implementation of eigenvalue estimation.
"""
import torch
def eig_power_iteration(A, num_iterations=50, eps=1e-5):
    """Estimate largest magnitude eigenvalue and associated eigenvector.

    This uses a simple power iteration algorithm to estimate leading
    eigenvalues, which can often be considerably faster than torch's built-in
    eigenvalue routines. All steps are differentiable and small constants are
    added to any division to preserve the stability of the gradients. For more
    information on power iteration, see
    https://en.wikipedia.org/wiki/Power_iteration.

    The iteration starts from a random vector drawn with `torch.randn`, so the
    result depends on the state of torch's random number generator. The start
    vector is created on the device and in the dtype of `A`.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Must be at least 1. Default: 50.
        eps (float, optional): Small number to prevent division by zero.
            Default: 1E-5.

    Returns:
        lam (tensor): Batch of estimated highest-magnitude eigenvalues with
            shape `(...)`.
        v (tensor): Associated eigenvector with shape `(..., num_dims)`.

    Raises:
        ValueError: If `num_iterations` is smaller than 1, since at least one
            matrix-vector product is needed to form the estimate.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    # Random start column vector per matrix, shape (..., num_dims, 1). It must
    # share A's dtype, otherwise matmul rejects e.g. float64 inputs.
    v = torch.randn(A.shape[:-1], device=A.device, dtype=A.dtype).unsqueeze(-1)

    for _ in range(num_iterations):
        v_prev = v
        Av = torch.matmul(A, v)
        # Renormalize each step; eps keeps the division finite if Av vanishes.
        v = Av / (Av.norm(p=2, dim=-2, keepdim=True) + eps)

    # Compute eigenvalue from the last step: (v_prev^T A v_prev) divided by
    # |v_prev^T v|. The absolute value keeps the sign of the numerator, so
    # negative leading eigenvalues are reported as negative.
    v_prev_row = v_prev.transpose(-1, -2)
    numerator = torch.matmul(v_prev_row, Av)
    denominator = torch.abs(torch.matmul(v_prev_row, v)) + eps
    lam = numerator / denominator

    # Drop the trailing singleton dimensions: (..., 1, 1) -> (...) for the
    # eigenvalue and (..., num_dims, 1) -> (..., num_dims) for the vector.
    lam = lam[..., 0, 0]
    v = v.squeeze(-1)
    return lam, v
def eig_leading(A, num_iterations=50):
    """Estimate largest positive eigenvalue and associated eigenvector.

    This estimates the *most positive* eigenvalue of each matrix in a batch of
    matrices by using two consecutive power iterations with spectral shifting.
    The first pass estimates the largest-magnitude eigenvalue `lam_1`; the
    second pass runs on `A + |lam_1| * I`, and the shift is subtracted from
    its result.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Default: 50.

    Returns:
        lam (tensor): Estimated most positive eigenvalue with shape `(...)`.
        v (tensor): Associated eigenvectors with shape `(..., num_dims)`.
    """
    # Size the identity from the input instead of assuming 4x4 matrices.
    num_dims = A.size(-1)

    # First pass gets largest magnitude; its eigenvector is not needed.
    lam_1, _ = eig_power_iteration(A, num_iterations)

    # Shifting by |lam_1| moves the spectrum to be non-negative (assuming real
    # eigenvalues and a converged first pass), so the largest-magnitude
    # eigenvalue of the shifted matrix is the most positive one of A.
    lam_1_abs = torch.abs(lam_1)
    identity = torch.eye(num_dims, device=A.device, dtype=A.dtype)
    # (..., 1, 1) * (num_dims, num_dims) broadcasts over the batch dimensions.
    A_shift = A + lam_1_abs[..., None, None] * identity
    lam_2, vec = eig_power_iteration(A_shift, num_iterations)

    # Shift back to original spectra
    lam = lam_2 - lam_1_abs
    return lam, vec