Spaces:
Running
Running
File size: 8,154 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import math
from typing import ContextManager, Sequence, TypeVar
import numpy as np
import torch
# Threshold used by knn_graph: structural distances between valid coordinate
# pairs are asserted to be below this value, and it is the base offset added to
# the sequence-distance fallback so that fallback edges sort after real ones.
# NOTE(review): units are presumably those of the input coordinates — confirm.
MAX_SUPPORTED_DISTANCE = 1e6
# Type variable bound to Sequence, so helpers annotated with it (e.g.
# slice_python_object_as_numpy) are typed as returning the container type they
# were given.
TSequence = TypeVar("TSequence", bound=Sequence)
def slice_python_object_as_numpy(
obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
"""
Slice a python object (like a list, string, or tuple) as if it was a numpy object.
Example:
>>> obj = "ABCDE"
>>> slice_python_object_as_numpy(obj, [1, 3, 4])
"BDE"
>>> obj = [1, 2, 3, 4, 5]
>>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
[1, 2, 3]
"""
if isinstance(idx, int):
idx = [idx]
if isinstance(idx, np.ndarray) and idx.dtype == bool:
sliced_obj = [obj[i] for i in np.where(idx)[0]]
elif isinstance(idx, slice):
sliced_obj = obj[idx]
else:
sliced_obj = [obj[i] for i in idx]
match obj, sliced_obj:
case str(), list():
sliced_obj = "".join(sliced_obj)
case _:
sliced_obj = obj.__class__(sliced_obj) # type: ignore
return sliced_obj # type: ignore
def rbf(values, v_min, v_max, n_bins=16):
    """
    Expand ``values`` into radial-basis-function features along a new last axis.

    ``n_bins`` centres are placed with ``torch.linspace`` from ``v_min`` to
    ``v_max`` and every centre shares the width ``(v_max - v_min) / n_bins``.

    Args:
        values: Input tensor of any shape.
        v_min: Position of the first centre.
        v_max: Position of the last centre.
        n_bins: Number of centres (size of the appended axis).

    Returns:
        Tensor with one extra trailing axis of size ``n_bins`` holding
        ``exp(-((value - centre) / width) ** 2)``.
    """
    centres = torch.linspace(
        v_min, v_max, n_bins, device=values.device, dtype=values.dtype
    )
    # Lead with one singleton axis per input dimension so the centres
    # broadcast against ``values.unsqueeze(-1)``.
    broadcast_shape = [1] * values.dim() + [-1]
    centres = centres.view(broadcast_shape)
    width = (v_max - v_min) / n_bins
    scaled_offset = (values.unsqueeze(-1) - centres) / width
    return scaled_offset.pow(2).neg().exp()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """
    Index ``data`` along ``dim`` with ``inds``, separately per batch position.

    For each of the first ``no_batch_dims`` axes of ``data`` an ``arange``
    index is built and reshaped so that it broadcasts against ``inds``; the
    remaining axes are taken whole except ``dim``, which is indexed by ``inds``.

    Args:
        data: Tensor to gather from.
        inds: Integer index tensor; its leading axes must broadcast against
            the batch axes of ``data``.
        dim: Axis of ``data`` to index with ``inds``. Non-negative values are
            counted from the front of ``data``; negative values from the end.
        no_batch_dims: Number of leading axes of ``data`` treated as batch axes.

    Returns:
        The gathered tensor produced by advanced indexing.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Build the batch index on the data's device to avoid a per-call
        # host-to-device copy when ``data`` lives on an accelerator.
        r = torch.arange(s, device=data.device)
        # Shape (1, ..., 1, s, 1, ..., 1) with ``s`` at position ``i`` so the
        # index broadcasts against ``inds``.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    # ``remaining_dims`` only covers the non-batch axes, hence the offset for
    # non-negative ``dim``; negative ``dim`` already counts from the end.
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: using a list for multidimensional indexing is
    # deprecated in recent PyTorch releases.
    return data[tuple(ranges)]
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """
    Gather entries of ``s`` along its second-to-last axis using ``edges``.

    A singleton axis is inserted at position ``-3`` of ``s`` and the result is
    handed to ``batched_gather`` with ``dim=-2`` and ``len(s.shape) - 1``
    batch axes.
    NOTE(review): presumably ``s`` is (..., L, D) node features and ``edges``
    is (..., L, K) neighbour indices, giving (..., L, K, D) — confirm.
    """
    expanded = s.unsqueeze(-3)
    batch_axes = len(s.shape) - 1
    return batched_gather(expanded, edges, -2, no_batch_dims=batch_axes)
def knn_graph(
    coords: torch.Tensor,
    coord_mask: torch.Tensor,
    padding_mask: torch.Tensor,
    sequence_id: torch.Tensor | None,
    *,
    no_knn: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pick, for every position, its nearest neighbours by structural distance,
    falling back to sequence distance where coordinates are unavailable.

    Args:
        coords: Coordinates with the position axis at ``-2`` (length ``L``)
            and the coordinate components on the last axis. NaNs are replaced
            via ``nan_to_num`` before distances are computed.
        coord_mask: Boolean tensor over positions; a pair of positions uses
            its structural distance only when both entries are True.
        padding_mask: Boolean tensor over positions; any pair involving a True
            entry is excluded (its distance is set to infinity).
        sequence_id: Optional per-position ids; pairs whose ids differ are
            excluded as well.
            NOTE(review): the ``unsqueeze(…, 1)`` / ``unsqueeze(…, 2)`` calls
            presumably expect a 2-D (batch, L) tensor — confirm.
        no_knn: Maximum number of neighbours to return per position.

    Returns:
        ``(chosen_edges, chosen_mask)``: neighbour indices sorted by ascending
        distance, truncated to ``min(no_knn, L)`` on the last axis, and a
        boolean tensor of the same shape that is False where the selected
        neighbour was an excluded (infinite-distance) pair.
    """
    L = coords.shape[-2]
    # Cannot pick more neighbours than there are positions.
    num_by_dist = min(no_knn, L)
    device = coords.device
    coords = coords.nan_to_num()
    # From here on ``coord_mask`` is pairwise and INVERTED: True where at least
    # one of the two positions lacks valid coordinates.
    coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
    # True where either position of the pair is padding.
    padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
    if sequence_id is not None:
        # Also exclude pairs that belong to different sequences.
        padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze(
            sequence_id, 2
        )
    # Pairwise norm of coordinate differences over the last axis.
    dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
    arange = torch.arange(L, device=device)
    # Absolute difference of position indices.
    seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()
    # We only support up to a certain distance, above that, we use sequence distance
    # instead. This is so that when a large portion of the structure is masked out,
    # the edges are built according to sequence distance.
    max_dist = MAX_SUPPORTED_DISTANCE
    # Structural distances of valid pairs must stay below the fallback offset.
    # NOTE(review): torch._assert_async is a private torch API — confirm it is
    # available in the supported torch versions.
    torch._assert_async((dists[~coord_mask] < max_dist).all())
    struct_then_seq_dist = (
        seq_dists.to(dists.dtype)
        .mul(1e2)
        .add(max_dist)
        # Keep the sequence-based value where the pair lacks coordinates,
        # otherwise take the structural distance.
        .where(coord_mask, dists)
        # Excluded pairs sort last and are flagged via isfinite() below.
        .masked_fill(padding_pairwise_mask, torch.inf)
    )
    dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
    # This is a L x L tensor, where we index by rows first,
    # and columns are the edges we should pick.
    chosen_edges = edges[..., :num_by_dist]
    chosen_mask = dists[..., :num_by_dist].isfinite()
    return chosen_edges, chosen_mask
def stack_variable_length_tensors(
sequences: Sequence[torch.Tensor],
constant_value: int | float = 0,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
"""Automatically stack tensors together, padding variable lengths with the
value in constant_value. Handles an arbitrary number of dimensions.
Examples:
>>> tensor1, tensor2 = torch.ones([2]), torch.ones([5])
>>> stack_variable_length_tensors(tensor1, tensor2)
tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
>>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3])
>>> stack_variable_length_tensors(tensor1, tensor2)
tensor of shape [2, 5, 4]
"""
batch_size = len(sequences)
shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
if dtype is None:
dtype = sequences[0].dtype
device = sequences[0].device
array = torch.full(shape, constant_value, dtype=dtype, device=device)
for arr, seq in zip(array, sequences):
arrslice = tuple(slice(dim) for dim in seq.shape)
arr[arrslice] = seq
return array
def unbinpack(
    tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
    """
    Split bin-packed rows back into one row per sequence and re-pad them.

    Args:
        tensor (Tensor): [B, L, ...]
        sequence_id: Per-position sequence ids for each row of ``tensor``, or
            None, in which case ``tensor`` is returned unchanged. For each row
            the ids ``0 .. max_id`` are extracted in order.
        pad_value: Fill value used when the extracted pieces are stacked.

    Returns:
        Tensor: [B_unbinpacked, L_unbinpack, ...]
    """
    if sequence_id is None:
        return tensor

    # Number of sequences packed into each row is (largest id in the row) + 1.
    per_row_counts = sequence_id.max(dim=-1).values + 1
    pieces = [
        tensor[row, row_ids == seq_idx]
        for row, (row_ids, count) in enumerate(zip(sequence_id, per_row_counts))
        for seq_idx in range(count)
    ]
    return stack_variable_length_tensors(pieces, pad_value)
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]:
    """
    Build an autocast context manager that disables downcasting by AMP.

    On ``'cpu'`` autocast is created with ``enabled=False``; on ``'cuda'`` it
    is created with ``dtype=torch.float32``.

    Args:
        device_type: The device type ('cpu' or 'cuda')

    Returns:
        An autocast context manager with the specified behavior.

    Raises:
        ValueError: for any other device type.
    """
    # Reject anything we do not know how to configure up front.
    if device_type not in ("cpu", "cuda"):
        raise ValueError(f"Unsupported device type: {device_type}")
    if device_type == "cpu":
        return torch.amp.autocast(device_type, enabled=False)
    return torch.amp.autocast(device_type, dtype=torch.float32)
def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]:
"""Merge overlapping ranges into sorted, non-overlapping segments.
Args:
ranges: collection of ranges to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping ranges merged from the inputs, sorted by position.
"""
ranges = sorted(ranges, key=lambda r: r.start)
merge_gap_max = merge_gap_max if merge_gap_max is not None else 0
assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}"
merged = []
for r in ranges:
if not merged:
merged.append(r)
else:
last = merged[-1]
if last.stop + merge_gap_max >= r.start:
merged[-1] = range(last.start, max(last.stop, r.stop))
else:
merged.append(r)
return merged
def list_nan_to_none(l: list) -> list:
if l is None:
return None # type: ignore
elif isinstance(l, float):
return None if math.isnan(l) else l # type: ignore
elif isinstance(l, list):
return [list_nan_to_none(x) for x in l]
else:
# Don't go into other structures.
return l
def list_none_to_nan(l: list) -> list:
    """
    Recursively replace None with ``math.nan`` inside (nested) lists.

    Only ``list`` containers are traversed; every other value is returned
    unchanged.
    """
    if isinstance(l, list):
        return [list_none_to_nan(item) for item in l]
    return math.nan if l is None else l  # type: ignore
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
if x is None:
return None
if convert_none_to_nan:
x = list_none_to_nan(x)
return torch.tensor(x)
def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
if x is None:
return None
x = x.tolist()
if convert_nan_to_none:
x = list_nan_to_none(x)
return x
|