Hukuna's picture
Upload 221 files
ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for comparing and mapping point clouds via optimal transport.
This module contains minimalist implementations of basic optimal transport
routines which can be used to, for example, measure similarities between
point clouds of different shapes by computing optimal mappings between them.
For more information see the excellent book by Peyré and Cuturi,
https://arxiv.org/pdf/1803.00567.pdf
"""
import numpy as np
import torch
def optimize_couplings_sinkhorn(C, scale=1.0, iterations=10):
    """Solve entropy regularized optimal transport via Sinkhorn iteration.

    This method uses the log-domain for numerical stability. Both marginals
    are uniform: every row of the coupling targets a total mass of `1 / I`
    and every column targets a total mass of `1 / J`.

    Args:
        C (Tensor): Batch of cost matrices with shape `(B, I, J)`.
        scale (float, optional): Entropy regularization parameter for
            rescaling the cost matrix.
        iterations (int, optional): Number of Sinkhorn iterations.

    Returns:
        T (Tensor): Couplings map with shape `(B, I, J)`.
    """
    log_T = -C * scale

    # Initialize normalizers. `new_zeros` inherits dtype and device from
    # `log_T`, so the target log-marginals are not rounded to float32 when
    # the cost matrix is double precision, and lower-precision inputs are
    # not silently promoted.
    B, I, J = log_T.shape
    log_u = log_T.new_zeros((B, I))
    log_v = log_T.new_zeros((B, J))
    log_a = log_u - np.log(I)
    log_b = log_v - np.log(J)

    # Iterate normalizers: alternately rescale rows, then columns, so that
    # the column marginal is matched exactly after the final update.
    for _ in range(iterations):
        log_u = log_a - torch.logsumexp(log_T + log_v.unsqueeze(1), 2)
        log_v = log_b - torch.logsumexp(log_T + log_u.unsqueeze(2), 1)

    # Apply the row and column normalizers to the kernel and leave log-space.
    log_T = log_T + log_v.unsqueeze(1) + log_u.unsqueeze(2)
    T = torch.exp(log_T)
    return T
def optimize_couplings_gw(
    D_a, D_b, scale=200.0, iterations_outer=30, iterations_inner=10,
):
    """Gromov-Wasserstein Optimal Transport.

    Alternates between linearizing the Gromov-Wasserstein objective around
    the current coupling and re-solving the resulting entropy regularized
    transport problem with Sinkhorn iteration. Both point sets carry uniform
    mass.

    https://arxiv.org/pdf/1905.07645.pdf

    Args:
        D_a (Tensor): Distance matrix describing objects in set `a` with
            shape `(B, I, I)`.
        D_b (Tensor): Distance matrix describing objects in set `b` with
            shape `(B, J, J)`.
        scale (float, optional): Entropy regularization parameter for
            rescaling the cost matrix.
        iterations_outer (int, optional): Number of outer GW iterations.
        iterations_inner (int, optional): Number of inner Sinkhorn iterations.

    Returns:
        T_gw (Tensor): Couplings map with shape `(B, I, J)`.
        D_gw (Tensor): Square root of the absolute transport cost under the
            final coupling, with shape `(B,)`.
    """
    # Uniform mass over the elements of each set.
    num_a, num_b = D_a.shape[1], D_b.shape[1]
    mass_a = torch.ones_like(D_a[:, :, 0]) / num_a
    mass_b = torch.ones_like(D_b[:, :, 0]) / num_b

    # Coupling-independent part of the cost: mass-weighted squared distances
    # of each set, broadcast over the other set's axis.
    sq_a = torch.einsum("bij,bj->bi", D_a ** 2, mass_a)
    sq_b = torch.einsum("bij,bj->bi", D_b ** 2, mass_b)
    C_ab = sq_a[:, :, None] + sq_b[:, None, :]

    def linearized_cost(coupling):
        # Cost matrix of the transport problem linearized at `coupling`.
        cross = torch.einsum("bik,bkl,blj->bij", D_a, coupling, D_b)
        return C_ab - 2.0 * cross

    # Start from the independent (outer product) coupling and refine it.
    T_gw = torch.einsum("bi,bj->bij", mass_a, mass_b)
    for _ in range(iterations_outer):
        T_gw = optimize_couplings_sinkhorn(
            linearized_cost(T_gw), scale, iterations=iterations_inner
        )

    # Compute cost under the final coupling; abs() guards the square root
    # against small negative totals.
    D_gw = (T_gw * linearized_cost(T_gw)).sum([-1, -2]).abs().sqrt()
    return T_gw, D_gw