"""Residual caching for stacks of transformer blocks.

The first block of every step is always executed; if its output residual is
close enough to the residual seen on the previous step, the cached residuals
of the remaining blocks are reused instead of running them again.
"""
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict

import torch

from pipeline import are_two_tensors_similar


@dataclasses.dataclass
class CacheContext:
    """Per-run storage for cached tensors, keyed by buffer name."""

    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )

    def get_incremental_name(self, name=None):
        # Returns "name_0", "name_1", ... for successive calls with the same name.
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    @torch.compiler.disable
    def get_buffer(self, name):
        return self.buffers.get(name)

    @torch.compiler.disable
    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()


@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling get_buffer"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling set_buffer"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context


@contextlib.contextmanager
def cache_context(cache_context):
    # Temporarily installs `cache_context` as the active context and restores
    # the previous one on exit, even if an exception is raised.
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context
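

# Usage sketch (illustrative, not part of the original module): buffers written
# via `set_buffer`/`get_buffer` live on whichever CacheContext is currently
# active, so a caller typically wraps each step (or a whole run) in the context
# manager above. `cached_blocks` below is a hypothetical CachedTransformerBlocks
# instance as defined later in this file:
#
#     with cache_context(create_cache_context()):
#         hidden_states, encoder_hidden_states = cached_blocks(
#             hidden_states, encoder_hidden_states
#         )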


@torch.compiler.disable
def are_two_tensors_similar_old(t1, t2, *, threshold, parallelized=False):
    # Relative-difference check: mean(|t1 - t2|) / mean(|t1|) below the threshold.
    # The version used by get_can_use_cache below is imported from `pipeline`.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


@torch.compiler.disable
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    # Add the residuals cached on the previous step onto the first block's outputs,
    # approximating the effect of running the remaining blocks.
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before it is applied"
    hidden_states = hidden_states_residual + hidden_states

    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before it is applied"
    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    # The cache is reusable when the first block's residual is close to the one
    # recorded on the previous step; `parallelized` is forwarded to the
    # similarity check imported from `pipeline`.
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return can_use_cache


class CachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        if self.residual_diff_threshold <= 0.0:
            # Caching disabled: run every block as usual.
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        # Always run the first transformer block; its residual decides whether
        # the cached residuals from the previous step can be reused.
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_transformer_block(
            hidden_states, encoder_hidden_states, *args, **kwargs
        )
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            # Skip the remaining blocks and reuse the cached residuals.
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            # Run the remaining blocks and refresh the cached residuals.
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self.transformer_blocks[1:]:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            encoder_hidden_states, hidden_states = hidden_states.split(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )

        # Flatten and restore the original shapes to get contiguous tensors.
        hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        # Residuals relative to the inputs of the remaining blocks; these are the
        # tensors cached for reuse on later steps.
        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
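

# --- Illustrative usage sketch (not part of the original module). ---
# Assumptions: `DummyBlock` is a toy stand-in for a real double-stream
# transformer block that takes and returns (hidden_states, encoder_hidden_states);
# the shapes and the 0.1 threshold are arbitrary.
if __name__ == "__main__":
    class DummyBlock(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = torch.nn.Linear(dim, dim)

        def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
            return self.proj(hidden_states), self.proj(encoder_hidden_states)

    dim = 16
    blocks = torch.nn.ModuleList([DummyBlock(dim) for _ in range(3)])
    cached_blocks = CachedTransformerBlocks(blocks, residual_diff_threshold=0.1)

    hidden_states = torch.randn(1, 8, dim)
    encoder_hidden_states = torch.randn(1, 4, dim)

    # All buffer reads/writes happen inside a cache context; reusing the same
    # context across steps is what allows the residual cache to be hit on the
    # second pass.
    with cache_context(create_cache_context()):
        for _ in range(2):
            out_hidden, out_encoder = cached_blocks(hidden_states, encoder_hidden_states)
    print(out_hidden.shape, out_encoder.shape)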