File size: 13,964 Bytes
89c0b51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 |
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import torch
from protenix.openfold_local.utils.tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[slice, ...]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.

    Args:
        start: Multi-dimensional (row-major) index of the first element of
            the range.
        end: Multi-dimensional index of the last element of the range
            (inclusive).
        dims: Sizes of the dimensions that start/end index into.
        start_edges: Used by the recursion. Entry i is truthy iff start[i:]
            are all 0. Computed from start when None.
        end_edges: Used by the recursion. Entry i is truthy iff
            end[k] == dims[k] - 1 for every k >= i. Computed from end and
            dims when None.

    Returns:
        A list of tuples of slices, ordered so that indexing a tensor with
        each tuple, flattening the leading len(dims) dimensions of every
        piece, and concatenating the pieces in order yields the elements
        from start to end in row-major order.
    """
    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(l):
        # In place, replace l[i] with the product l[i] * l[i + 1] * ... * l[-1],
        # i.e. a right-to-left running "all". Booleans become 0/1 ints here.
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]

    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if len(start) == 0:
        return [tuple()]
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if s == e:
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    # First dimension at which start and end differ
    divergence_idx = len(path)

    # start == end, and we're done
    if divergence_idx == len(dims):
        return [tuple(path)]

    def upper():
        # Ragged top: inside row start[divergence_idx], cover everything from
        # start[divergence_idx + 1:] to the last element of that row.
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1 :],
                [d - 1 for d in dims[divergence_idx + 1 :]],
                dims[divergence_idx + 1 :],
                start_edges=start_edges[divergence_idx + 1 :],
                end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
            )
        ]

    def lower():
        # Ragged bottom: inside row end[divergence_idx], cover everything from
        # the first element of that row through end[divergence_idx + 1:].
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1 :]],
                end[divergence_idx + 1 :],
                dims[divergence_idx + 1 :],
                start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
                end_edges=end_edges[divergence_idx + 1 :],
            )
        ]

    # NOTE(review): start_edges[divergence_idx] / end_edges[divergence_idx]
    # also require start[divergence_idx] == 0 / end[divergence_idx] ==
    # dims[divergence_idx] - 1, not just complete start/end rows. For example,
    # start=(0, 0), end=(1, 4), dims=(4, 5) yields two slices where one
    # (rows 0:2) would suffice. The cover is still correct, just not always
    # the shortest.

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif start_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(
            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
        )
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(
                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
            )
        slices.extend(lower())

    return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.

    Args:
        t: Tensor whose first ``no_batch_dims`` dimensions are batch
            dimensions. It need not be contiguous; e.g. it may come from
            ``Tensor.expand`` and have stride-0 batch dimensions.
        flat_start: First flattened batch index to include.
        flat_end: One past the last flattened batch index to include
            (exclusive, like a Python slice). Assumes
            ``flat_start < flat_end <= prod(t.shape[:no_batch_dims])``.
        no_batch_dims: Number of leading batch dimensions of ``t``.

    Returns:
        A tensor of shape ``(flat_end - flat_start,) + t.shape[no_batch_dims:]``.
    """
    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )
    sliced_tensors = [t[s] for s in slices]

    # Use reshape, not view: a piece cut from a non-contiguous input (e.g. an
    # expanded tensor whose broadcast dims have stride 0) cannot always be
    # flattened in place, and view would raise. reshape returns a view
    # whenever view would have succeeded, and otherwise copies only this
    # chunk-sized piece.
    trailing_shape = t.shape[no_batch_dims:]
    return torch.cat([s.reshape((-1,) + trailing_shape) for s in sliced_tensors])
def _assign_chunk(
    dst: Any,
    src: Any,
    start: int,
    end: int,
    add_into: bool,
) -> None:
    """Write one layer-output chunk into its slot of the preallocated output.

    ``dst`` and ``src`` are matching pytrees (arbitrarily nested dicts, lists
    and tuples with torch.Tensor leaves). For every tensor leaf, ``src`` is
    copied into ``dst[start:end]`` along the first dimension, or added onto
    it when ``add_into`` is True.

    Raises:
        ValueError: If ``src`` contains a node that is not a dict, list,
            tuple or torch.Tensor.
    """
    src_type = type(src)
    if src_type is dict:
        for key, dst_value in dst.items():
            _assign_chunk(dst_value, src[key], start, end, add_into)
    elif src_type is list or src_type is tuple:
        for dst_value, src_value in zip(dst, src):
            _assign_chunk(dst_value, src_value, start, end, add_into)
    elif isinstance(src, torch.Tensor):
        if add_into:
            dst[start:end] += src
        else:
            dst[start:end] = src
    else:
        raise ValueError("Not supported")


def chunk_layer(
    layer: Callable,
    inputs: dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
    _out: Any = None,
    _add_into_out: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions (size-1
            batch dimensions are broadcast).
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
            Must be at least 1.
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
        _out:
            Optional preallocated output pytree, structured like the layer
            output, whose leaves have the batch dimensions as their leading
            dimensions. Chunk results are written into it through views, so
            it is modified in place.
        _add_into_out:
            If True, every chunk result is added onto the existing contents
            of the output instead of overwriting them.
    Returns:
        The reassembled output of the layer on the inputs, with the
        flattened batch dimension reshaped back to the (broadcast) batch
        dimensions of the inputs.
    Raises:
        ValueError: If ``inputs`` is empty, if ``chunk_size`` is smaller
            than 1, or if the layer output contains a node that is not a
            dict, list, tuple or tensor.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")
    # A zero step would divide by zero; a negative one would silently skip
    # every chunk.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        if not low_mem:
            # With positive sizes, the sum equals the count only when every
            # batch dim is 1. Such inputs stay unexpanded (leading dim 1
            # after flattening) and are handed whole to every chunk.
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
    prepped_outputs = None
    if _out is not None:
        reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
        prepped_outputs = tensor_tree_map(reshape_fn, _out)

    flat_batch_dim = math.prod(orig_batch_dims)

    out = prepped_outputs
    for i in range(0, flat_batch_dim, chunk_size):
        chunk_end = i + chunk_size

        # Chunk the input
        if not low_mem:
            # Bind the bounds now rather than closing over the loop variable.
            def select_chunk(t, start=i, stop=chunk_end):
                return t[start:stop] if t.shape[0] != 1 else t
        else:
            select_chunk = partial(
                _chunk_slice,
                flat_start=i,
                flat_end=min(flat_batch_dim, chunk_end),
                no_batch_dims=len(orig_batch_dims),
            )
        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space (slicing clamps the final,
        # possibly shorter, chunk)
        _assign_chunk(out, output_chunk, i, chunk_end, _add_into_out)

    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)
    return out
class ChunkSizeTuner:
    """Empirically finds a large chunk size that a function can run with.

    The tuned value is cached and reused for as long as the tensor shapes
    and non-tensor values passed to ``tune_chunk_size`` stay the same.
    """

    def __init__(
        self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs I've run the model on.
        max_chunk_size=512,
    ):
        self.max_chunk_size = max_chunk_size
        self.cached_chunk_size = None
        self.cached_arg_data = None

    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
        """Binary-search for the largest candidate chunk size ``fn`` runs with.

        Candidates are ``min_chunk_size`` followed by every power of two in
        ``(min_chunk_size, max_chunk_size]``; the largest power of two is
        bumped by 4. A candidate is viable if ``fn(*args, chunk_size=c)``
        finishes under ``torch.no_grad()`` without raising RuntimeError.
        ``min_chunk_size`` is assumed viable and is never run. Viability is
        assumed to be monotone (if a candidate works, every smaller one
        works too).
        """
        logging.info("Tuning chunk size...")

        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        # math.log2 is exact for powers of two, unlike math.log(x, 2).
        candidates = [2**l for l in range(int(math.log2(self.max_chunk_size)) + 1)]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates
        # Only bump a power-of-two candidate; never turn min_chunk_size into
        # an untested value.
        # NOTE(review): the reason for the +4 is not visible here; kept as-is.
        if len(candidates) > 1:
            candidates[-1] += 4

        def is_viable(chunk_size):
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        # Invariant: candidates[lo] is (assumed) viable and every index above
        # hi is known to fail, so no candidate is ever run twice.
        lo = 0
        hi = len(candidates) - 1
        # Probe the largest candidate first: if it fits, one run suffices.
        i = hi
        while i > lo:
            if is_viable(candidates[i]):
                lo = i
            else:
                hi = i - 1
            i = (lo + hi + 1) // 2

        return candidates[lo]

    def _compare_arg_caches(self, ac1, ac2):
        """Return True iff two argument caches are identical.

        Caches are pytrees of lists / tuples / dicts whose leaves are tensor
        shapes or plain values (as built by ``tune_chunk_size``). Each node's
        own type decides how it is compared. Nodes of different types,
        sequences of different lengths and dicts with different keys
        compare unequal.
        """
        if type(ac1) is not type(ac2):
            return False
        if type(ac1) is list or type(ac1) is tuple:
            return len(ac1) == len(ac2) and all(
                self._compare_arg_caches(a1, a2) for a1, a2 in zip(ac1, ac2)
            )
        if type(ac1) is dict:
            return ac1.keys() == ac2.keys() and all(
                self._compare_arg_caches(ac1[k], ac2[k]) for k in ac1
            )
        # Leaves: torch.Size shapes or plain values.
        return bool(ac1 == ac2)

    def tune_chunk_size(
        self,
        representative_fn: Callable,
        args: Tuple[Any],
        min_chunk_size: int,
    ) -> int:
        """Return a chunk size for ``representative_fn``, tuning only when needed.

        Args:
            representative_fn: Called as ``fn(*args, chunk_size=c)`` while
                tuning.
            args: Positional arguments for ``representative_fn``.
            min_chunk_size: Smallest chunk size to consider; assumed viable.

        Returns:
            The cached chunk size if ``args`` have the same tensor shapes and
            non-tensor values as on the call that produced it, otherwise a
            freshly tuned (and newly cached) one.
        """
        # Keep only tensor shapes (plus non-tensor values) as the cache key.
        remove_tensors = lambda a: a.shape if isinstance(a, torch.Tensor) else a
        arg_data = tree_map(remove_tensors, args, object)

        # Reuse the cached value only if a previous call saw identical shapes
        # and values; otherwise (first call, or changed args) re-tune.
        # NOTE(review): min_chunk_size is not part of the cache key, so a
        # cached value smaller than a later min_chunk_size is still returned.
        consistent = self.cached_arg_data is not None and self._compare_arg_caches(
            self.cached_arg_data, arg_data
        )

        if not consistent:
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        return self.cached_chunk_size
|