File size: 79,099 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 |
import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging
# Optional quantization backends: the names below are only bound when the corresponding package is installed,
# so any code using them must sit behind the same availability check.
if is_quanto_available():
    quanto_version = version.parse(importlib.metadata.version("quanto"))
    # The quanto symbols are imported only when the installed version is at least 0.2.0
    if quanto_version >= version.parse("0.2.0"):
        from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4

if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer

# Module-level logger shared by the classes in this file
logger = logging.get_logger(__name__)
class Cache(torch.nn.Module):
    """
    Abstract base for every key/value cache. How the cached tensors are stored is left to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stores the new `key_states` and `value_states` of layer `layer_idx` in the cache.

        Parameters:
            key_states (`torch.Tensor`):
                Key states to add to the cache.
            value_states (`torch.Tensor`):
                Value states to add to the cache.
            layer_idx (`int`):
                Index of the layer the states belong to.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Extra, subclass-specific arguments. They exist so that new kinds of cache can be implemented.

        Return:
            A tuple with the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns how many positions are currently cached, optionally for a specific layer."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum number of positions the cache can hold, if it has such a limit."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Returns how much of the cache remains usable once `new_seq_length` new positions are added."""
        cache_limit = self.get_max_length()
        cached_length = self.get_seq_length(layer_idx)
        # An unbounded cache never evicts anything, so everything already stored stays usable
        if cache_limit is None:
            return cached_length
        # A bounded cache that would overflow has to evict entries to make room for the new inputs,
        # which leaves only `cache_limit - new_seq_length` usable positions
        would_overflow = cached_length + new_seq_length > cache_limit
        return cache_limit - new_seq_length if would_overflow else cached_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Re-indexes the batch dimension of every cached layer with the beam indices picked by beam search."""
        for layer_idx, layer_keys in enumerate(self.key_cache):
            # Keys and values are handled separately: `beam_idx` is moved to the device of each tensor
            self.key_cache[layer_idx] = layer_keys.index_select(0, beam_idx.to(layer_keys.device))
            layer_values = self.value_cache[layer_idx]
            self.value_cache[layer_idx] = layer_values.index_select(0, beam_idx.to(layer_values.device))

    @property
    def seen_tokens(self):
        """Deprecated accessor for `_seen_tokens`; yields `None` when the attribute was never set."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        return getattr(self, "_seen_tokens", None)
@dataclass
class CacheConfig:
    """
    Base class for cache configs
    """

    cache_implementation: None

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a CacheConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values. Keys that do not match an
                attribute of the constructed instance are ignored.

        Returns:
            CacheConfig: Instance of CacheConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # Only override attributes that already exist on the freshly built config
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    # Adapted from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
            writer.write(json_string)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    # Adapted from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """Allows `dict(obj)` for situations where obj may be a dict or a cache config."""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Attribute values that JSON cannot represent natively (for instance a `torch.dtype`) are rendered with
        `str()`. This string backs `__repr__`, which must never raise just because such a value is stored.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        # `default=str` is a deliberate, display-oriented fallback; it is only consulted for non-JSON values
        return json.dumps(self.__dict__, indent=2, default=str) + "\n"

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
@dataclass
class QuantizedCacheConfig(CacheConfig):
    """
    Configuration class for quantized cache settings.

    Attributes:
        backend (`str`, *optional*, defaults to `"quanto"`):
            Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
        nbits (`Optional[int]`, *optional*, defaults to 4):
            Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
        axis_key (`int`, *optional*, defaults to 0):
            Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
        axis_value (`int`, *optional*, defaults to 0):
            Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
        q_group_size (`Optional[int]`, *optional*, defaults to 64):
            Size of the quantization group, should be a divisor of the model's hidden dimension.
        residual_length (`Optional[int]`, *optional*, defaults to 128):
            Length of the residual cache which will always be stored in original precision.
        compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
        device (`str`, *optional*, defaults to `"cpu"`):
            Device on which to perform computations, should be same as the model's device.
    """

    def __init__(
        self,
        backend: str = "quanto",
        nbits: Optional[int] = 4,
        axis_key: Optional[int] = 0,
        axis_value: Optional[int] = 0,
        q_group_size: Optional[int] = 64,
        residual_length: Optional[int] = 128,
        compute_dtype: Optional[torch.dtype] = torch.float16,
        device: Optional[str] = "cpu",
    ):
        self.backend = backend
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        self.compute_dtype = compute_dtype
        self.device = device

    def validate(self):
        """
        Validates if the arguments passed are correct.

        Only backend-independent sanity checks are performed here (`nbits`, `q_group_size`, `residual_length`
        and the two axes).

        Raises:
            ValueError: If one of the checked attributes holds a value outside of its accepted set.
        """
        incorrect_arg_msg = (
            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value} "
            "but found {found_value}"
        )
        # Check that the values are reasonable in general (nbits, axis)
        # Later in QuantizedCache init we check if they are supported for that particular backend
        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="nbits",
                    correct_value="one of [1, 2, 3, 4, 8]",
                    found_value=self.nbits,
                ),
            )
        if self.q_group_size <= 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="q_group_size",
                    correct_value="a positive integer",
                    found_value=self.q_group_size,
                ),
            )
        # Zero is accepted here, hence "non-negative" rather than "positive"
        if self.residual_length < 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="residual_length",
                    correct_value="a non-negative integer",
                    found_value=self.residual_length,
                ),
            )
        if self.axis_key not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_key",
                    correct_value="`1` or `0` or `-1`",
                    found_value=self.axis_key,
                ),
            )
        if self.axis_value not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_value",
                    correct_value="`1` or `0` or `-1`",
                    found_value=self.axis_value,
                ),
            )
class DynamicCache(Cache):
    """
    Default cache for generative models: it simply grows as new tokens are produced.

    Key and Value states are kept in two lists holding one tensor per layer. Each tensor is expected to have the
    shape `[batch_size, num_heads, seq_len, head_dim]`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = DynamicCache()
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self) -> None:
        super().__init__()
        # Running count of the tokens pushed through layer 0; `generate` relies on it
        self._seen_tokens = 0
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Legacy-style indexing: `past_key_value[0][0].shape[2]` keeps working as a way of reading the sequence
        length.
        """
        if layer_idx >= len(self):
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
        return (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __iter__(self):
        """
        Legacy-style iteration: `for x in past_key_value:` walks over the per-layer (keys, values) pairs.
        """
        yield from ((self.key_cache[layer_idx], self.value_cache[layer_idx]) for layer_idx in range(len(self)))

    def __len__(self):
        """
        Legacy-style length: `len(past_key_value)` is the number of layers currently stored in the cache.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stores the new `key_states` and `value_states` of layer `layer_idx` in the cache.

        Parameters:
            key_states (`torch.Tensor`):
                Key states to add to the cache.
            value_states (`torch.Tensor`):
                Value states to add to the cache.
            layer_idx (`int`):
                Index of the layer the states belong to.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Extra arguments for cache subclasses. `DynamicCache` does not use any of them.

        Return:
            A tuple with the updated key and value states.
        """
        # Tokens are only counted once per forward pass, when the first layer is updated
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if layer_idx < len(self.key_cache):
            # Known layer: extend its tensors along the sequence dimension
            self.key_cache[layer_idx] = torch.cat((self.key_cache[layer_idx], key_states), dim=-2)
            self.value_cache[layer_idx] = torch.cat((self.value_cache[layer_idx], value_states), dim=-2)
        else:
            # First time this layer is seen: the incoming states become its cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns how many positions are cached for `layer_idx` (0 when that layer has no cache yet)."""
        # TODO: deprecate this function in favor of `cache_position`
        has_no_entry = len(self.key_cache) <= layer_idx
        return 0 if has_no_entry else self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns `None`, as a `DynamicCache` has no maximum sequence length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Exports this `DynamicCache` as a legacy cache, i.e. a tuple of per-layer (keys, values) tuples.
        Used for backward compatibility."""
        return tuple((self.key_cache[layer_idx], self.value_cache[layer_idx]) for layer_idx in range(len(self)))

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Builds a `DynamicCache` holding the same states as a legacy-format cache. Used for backward
        compatibility."""
        cache = cls()
        if past_key_values is None:
            return cache
        for layer_idx in range(len(past_key_values)):
            key_states, value_states = past_key_values[layer_idx]
            cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Truncates the cached keys and values to `max_length` tokens. A negative `max_length` removes that many
        tokens from the end instead. This is used in assisted decoding and contrastive search."""
        # A negative value is relative to the current length
        if max_length < 0:
            max_length = self.get_seq_length() + max_length

        # Nothing to do when the cache is already short enough
        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx, layer_keys in enumerate(self.key_cache):
            self.key_cache[idx] = layer_keys[..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
        """Cuts this cache along the batch dimension into several `DynamicCache` objects of at most `split_size`
        rows each. This will be used by `_split_model_inputs()` in `generation.utils`"""

        def _slice_batch(start: int) -> "DynamicCache":
            stop = start + split_size
            piece = DynamicCache()
            piece._seen_tokens = self._seen_tokens
            piece.key_cache = [layer[start:stop] for layer in self.key_cache]
            piece.value_cache = [layer[start:stop] for layer in self.value_cache]
            return piece

        return [_slice_batch(start) for start in range(0, full_batch_size, split_size)]

    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """Inverse of `batch_split()`: concatenates the given caches along the batch dimension. This will be used
        by `stack_model_outputs` in `generation.utils`"""
        merged = cls()
        num_layers = len(splits[0])
        for idx in range(num_layers):
            stacked_keys = torch.cat([part.key_cache[idx] for part in splits], dim=0)
            stacked_values = torch.cat([part.value_cache[idx] for part in splits], dim=0)
            merged.update(stacked_keys, stacked_values, idx)
        return merged

    def batch_repeat_interleave(self, repeats: int):
        """Repeats every batch entry of the cache `repeats` times, in place. Used in contrastive search."""
        for layer_idx, layer_keys in enumerate(self.key_cache):
            self.key_cache[layer_idx] = layer_keys.repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Keeps only the batch entries selected by `indices`, in place. Used in contrastive search."""
        for layer_idx, layer_keys in enumerate(self.key_cache):
            self.key_cache[layer_idx] = layer_keys[indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        """Creates an empty cache together with its prefetch stream.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        # Device each layer's key states were on when the layer was first cached (filled in by `update`)
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        # Indices beyond the layers cached so far are silently ignored
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        # Nothing is evicted while the cache holds two layers or fewer
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            # The modulo makes layer 0 evict the last layer
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            # Wait for any pending prefetch before reading the tensors of this layer
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            # NOTE(review): `self.beam_idx` is not reset to None anywhere in this class, so the saved indices are
            # applied on every access until `reorder_cache` stores new ones - confirm this is intended.
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # New layer: remember the device of its key states so `prefetch_layer` can move it back later
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            # Going through `self[layer_idx]` evicts the previous layer, applies any saved beam indices and
            # prefetches the next layer before the new states are appended
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None

    to_legacy_cache = None
class QuantizedCache(DynamicCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.

    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.

    It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
    """

    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
        # A single parent initialization is enough: it creates `key_cache`, `value_cache` and `_seen_tokens`
        super().__init__()
        self._quantized_key_cache: List[torch.Tensor] = []
        self._quantized_value_cache: List[torch.Tensor] = []

        self.nbits = cache_config.nbits
        self.residual_length = cache_config.residual_length
        self.q_group_size = cache_config.q_group_size
        self.axis_key = cache_config.axis_key
        self.axis_value = cache_config.axis_value
        self.compute_dtype = cache_config.compute_dtype
        self.device = cache_config.device

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        The first states received for a layer are quantized right away. Later states are appended to the
        original-precision residual cache of that layer until it reaches `residual_length`, at which point
        everything is re-quantized and the residual cache is emptied.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `QuantizedCache`.

        Return:
            A tuple containing the full (dequantized + residual + new) key and value states in original precision.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
            self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
            # Empty 1-D placeholders mark a residual cache that holds nothing yet; each one follows the
            # dtype/device of the states it stands in for
            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            self.value_cache.append(torch.zeros(0, dtype=value_states.dtype, device=value_states.device))
            keys_to_return, values_to_return = key_states, value_states
        else:
            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

            keys_to_return = torch.cat(keys_to_return, dim=-2)
            values_to_return = torch.cat(values_to_return, dim=-2)
            # `dim() == 4` distinguishes a populated residual cache from the 1-D empty placeholder
            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                # Residual cache is full: fold everything into the quantized storage and reset the residual
                self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
                self._quantized_value_cache[layer_idx] = self._quantize(
                    values_to_return.contiguous(), axis=self.axis_value
                )
                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=value_states.dtype, device=value_states.device)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
        # this part of code otherwise fails when used to verify attn_weight shape in some models
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def _quantize(self, tensor, axis):
        """Quantizes a key/value using a defined quantization method."""
        raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")

    def _dequantize(self, q_tensor):
        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
        raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
class QuantoQuantizedCache(QuantizedCache):
    """
    Quantized cache that delegates quantization to the `quanto` backend. Only 2-bit and 4-bit quantization
    (`nbits` of `2` or `4`) are accepted.

    Parameters:
        cache_config (`QuantizedCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and
            group size. It is forwarded unchanged to the parent constructor.

    Example:

        ```python
        >>> # Run pip install quanto first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = QuantizedCacheConfig(nbits=4)
        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        # Refuse to run against a `quanto` release older than the minimum supported one.
        quanto_version = version.parse(importlib.metadata.version("quanto"))
        minimum_version = version.parse("0.2.0")
        if quanto_version < minimum_version:
            raise ImportError(
                f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
                f"Please upgrade quanto with `pip install -U quanto`"
            )

        # Validate the settings read from the config by the parent class.
        allowed_nbits = (2, 4)
        allowed_axes = (0, -1)
        if self.nbits not in allowed_nbits:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
        if self.axis_key not in allowed_axes:
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
        if self.axis_value not in allowed_axes:
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        if self.nbits == 4:
            self.qtype = qint4
        else:
            self.qtype = qint2
        # Hardcoded: `MaxOptimizer` is the only optimizer for per-channel quantization
        self.optimizer = MaxOptimizer()

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with the configured qtype and group size."""
        group_size = self.q_group_size
        scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, group_size)
        return AffineQuantizer.apply(tensor, self.qtype, axis, group_size, scale, zeropoint)

    def _dequantize(self, qtensor):
        """Return the dequantized counterpart of a tensor produced by `_quantize`."""
        restored = qtensor.dequantize()
        return restored
class HQQQuantizedCache(QuantizedCache):
    """
    Quantized cache that delegates quantization to the `HQQ` backend. `nbits` must be one of `1`, `2`, `3`, `4`
    or `8`, and both quantization axes must be `0` or `1`.

    Parameters:
        cache_config (`QuantizedCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and
            group size. It is forwarded unchanged to the parent constructor.

    Example:

        ```python
        >>> # Run pip install hqq first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        # Validate the settings read from the config by the parent class.
        allowed_nbits = (1, 2, 3, 4, 8)
        allowed_axes = (0, 1)
        if self.nbits not in allowed_nbits:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )
        if self.axis_key not in allowed_axes:
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
        if self.axis_value not in allowed_axes:
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis`; returns the `(quantized tensor, metadata)` pair from the backend."""
        quantize_kwargs = {
            "axis": axis,
            "device": self.device,
            "compute_dtype": self.compute_dtype,
            "nbits": self.nbits,
            "group_size": self.q_group_size,
        }
        packed, meta = self.quantizer.quantize(tensor, **quantize_kwargs)
        meta["compute_dtype"] = self.compute_dtype
        # Move to device and cast to dtype
        self.quantizer.cuda(packed, meta=meta, device=self.device)
        return packed, meta

    def _dequantize(self, qtensor):
        """Rebuild a tensor from the `(quantized tensor, metadata)` pair produced by `_quantize`."""
        packed, meta = qtensor
        return self.quantizer.dequantize(packed, meta)
class SinkCache(Cache):
    """
    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Re-rotation (cos, sin) pairs, memoised per number of newly added tokens
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        """Split the last dimension of `x` in two halves `(x1, x2)` and return `(-x2, x1)` concatenated."""
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply the rotation described by `cos`/`sin` to `key_states`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the `(cos, sin)` pair used to re-rotate the kept keys when the window shifts by
        `key_states.shape[-2]` positions. The result is memoised in `cos_sin_rerotation_cache`, keyed only by that
        shift length, so later calls with the same length reuse the first computed pair.
        """
        num_new_tokens = key_states.shape[-2]
        if num_new_tokens not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + num_new_tokens :]
            shifted_cos = cos[self.num_sink_tokens : -num_new_tokens]
            original_sin = sin[self.num_sink_tokens + num_new_tokens :]
            shifted_sin = sin[self.num_sink_tokens : -num_new_tokens]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[num_new_tokens] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[num_new_tokens]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. May be omitted (or `None`) for models that do not use RoPE.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` is optional in the signature: treat a missing dict as "no RoPE information provided"
        # instead of failing on `None.get(...)`.
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated; the rest pass through untouched
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. When `None`,
            `config.max_position_embeddings` is used instead.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        # Configs that do not define `num_key_value_heads` at all are treated like configs where it is `None`
        num_key_value_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads = config.num_attention_heads if num_key_value_heads is None else num_key_value_heads

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            # One allocation per layer, always with the resolved `self.dtype`
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            #     breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            #     it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know how where to write in the cache. When it is missing, the new states overwrite the whole
                layer cache via `copy_`.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` is optional in the signature: a missing dict means "no `cache_position` given"
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one.
            # `Tensor.to` is not in-place, so its result has to be kept.
            cache_position = cache_position.to(device=k_out.device)

            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states that were seen by the model.

        The value is the result of a tensor reduction (`.any(dim=-1).sum()`), i.e. a 0-dim tensor rather than a
        Python `int`.
        """
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
class SlidingWindowCache(StaticCache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.max_cache_len - 1`,
    if true (which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
    we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.

    The `to_shift` is only true once we are above the window size. Thus with `max_cache_len == 64`:

    indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these, then we always write at cache_position (clamped to the window size)

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
            It must expose a non-`None` `sliding_window` attribute.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. The effective cache length is
            `min(config.sliding_window, max_cache_len)`.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache

        >>> # The model config must define `sliding_window`, otherwise the cache constructor raises a `ValueError`
        >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

        >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        # Validate before touching the parent: `StaticCache.__init__` has required arguments, so it must only be
        # called once, with the final (window-clamped) `max_cache_len`.
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        max_cache_len = min(config.sliding_window, max_cache_len)
        super().__init__(
            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the sliding-window cache of layer `layer_idx` with the new `key_states` and `value_states`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`):
                Must contain `cache_position`, the positions of the new states; it is read unconditionally.

        Return:
            A tuple with the key and value states to attend to. When more positions than the window size are
            passed at once, the full input states are returned and only their last `max_cache_len` positions
            are stored.
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        to_shift = cache_position >= self.max_cache_len - 1
        # Rolls the window by one slot once the last position is reached, identity otherwise
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        # `Tensor.to` is not in-place, so its result has to be kept for the device move to take effect
        cache_position = cache_position.to(device=k_out.device)
        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_length(self) -> Optional[int]:
        """Returns `None`: the window size is fixed, so there is no limit on the sequence length."""
        # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
        return None

    def reset(self):
        """Zeroes every layer's key/value tensors in place, keeping the tensor objects themselves."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.

    Parameters:
        self_attention_cache (`Cache`):
            The cache holding the decoder self-attention key/value states.
        cross_attention_cache (`Cache`):
            The cache holding the cross-attention key/value states.

    Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

        >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
        >>> self_attention_cache = DynamicCache()
        >>> cross_attention_cache = DynamicCache()
        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        super().__init__()
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        # Per-layer flag: `True` when the cross-attention cache already holds states for that layer
        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache.key_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Returns the `(self-attn keys, self-attn values, cross-attn keys, cross-attn values)` tuple of the layer and
        raises `KeyError` when `layer_idx` is not smaller than `len(self)`.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.key_cache[layer_idx],
                self.self_attention_cache.value_cache[layer_idx],
                self.cross_attention_cache.key_cache[layer_idx],
                self.cross_attention_cache.value_cache[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            # Each layer entry is the self-attention pair followed by the cross-attention pair
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                # Entries beyond the first two are the cross-attention key/value states
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.self_attention_cache.key_cache) <= layer_idx:
            return 0
        # Let the wrapped cache report its own length rather than counting non-zero key rows here: each cache
        # class knows how its storage maps to a sequence length.
        return self.self_attention_cache.get_seq_length(layer_idx)

    def reset(self):
        """
        Resets every wrapped cache that exposes a `.reset()` method and clears the `is_updated` flags.

        Raises `ValueError` when neither wrapped cache has a `.reset()` method.
        """
        self_has_reset = hasattr(self.self_attention_cache, "reset")
        cross_has_reset = hasattr(self.cross_attention_cache, "reset")
        if not self_has_reset and not cross_has_reset:
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        if self_has_reset:
            self.self_attention_cache.reset()
        if cross_has_reset:
            self.cross_attention_cache.reset()
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        """Raises `ValueError` unless both wrapped caches are `DynamicCache` instances; `method` names the caller."""
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.
        Only the self-attention cache is cropped."""
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
        """Split the current instance into a list of `EncoderDecoderCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        self.check_dynamic_cache(self.batch_split.__name__)
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        self_attention_cache = DynamicCache()
        cross_attention_cache = DynamicCache()
        for idx in range(len(splits[0])):
            # Stitch the splits back together along the batch dimension, layer by layer
            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
            self_attention_cache.update(layer_keys, layer_values, idx)

            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
            cross_attention_cache.update(layer_keys, layer_values, idx)
        return cls(self_attention_cache, cross_attention_cache)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)
class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window
    attention and global attention in every other layer. Buffers are pre-allocated per layer: even-indexed layers get a
    sliding-window buffer of length `min(config.sliding_window, max_cache_len)` and odd-indexed layers get a
    full-length buffer of `max_cache_len`. The update logic is analogous to ["SlidingWindowCache"] for sliding window
    attention and ["StaticCache"] for global attention. For more information, see the documentation of each
    subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`, *optional*, defaults to `"cpu"`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_key_values = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
        """Allocate zero-filled key/value buffers for every layer.

        Raises:
            ValueError: if `config` has no `sliding_window` attribute or it is `None`.
        """
        super().__init__()
        if getattr(config, "sliding_window", None) is None:
            raise ValueError(
                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )
        # Even-indexed layers are the sliding-window ones. The flags are computed as plain Python
        # bools first so the allocation loop below never has to read a value back from `device`.
        layer_is_sliding = [not bool(i % 2) for i in range(config.num_hidden_layers)]
        self.is_sliding = torch.tensor(layer_is_sliding, dtype=torch.bool, device=device)
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        for sliding in layer_is_sliding:
            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            cache_shape = sliding_cache_shape if sliding else global_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        """Write new states into a sliding-window layer and return the states to attend over.

        `max_cache_len` here is the length of this layer's buffer (its window), as passed by `update()`.
        `key_states` / `value_states` are sliced as 4-D tensors with the sequence on dim 2.
        """
        if cache_position.shape[0] > max_cache_len:
            # More incoming positions than the window holds: only the last `max_cache_len` tokens are stored.
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        # `slicing` is [1, 2, ..., max_cache_len], built with ones + cumsum.
        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        # If the last position has reached the end of the window, `indices` is [1, ..., W-1, 0] (rotate left by
        # one, dropping the oldest slot); otherwise it is the identity [0, ..., W-1].
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len

        # Indexing with a tensor yields new tensors, so the writes below do not touch the cache buffers yet.
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        """Write new states at `cache_position` into a full-length layer buffer and return the buffers.

        `max_cache_len` is unused; it is accepted so both update functions share one call signature.
        """
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update the cache of layer `layer_idx` with new key/value states.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Must contain `cache_position`. If it contains a truthy `sliding_window`, the sliding-window update
                is used for this layer; otherwise the static update is used.

        Return:
            A tuple `(key_states, value_states)` to be used by the attention computation.

        Raises:
            ValueError: if `cache_position` is not provided in `cache_kwargs`.
        """
        # The signature allows `cache_kwargs=None`; treat it as an empty dict rather than crashing on `.get`.
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
        if cache_position is None:
            raise ValueError("`cache_position` must be passed in `cache_kwargs` to update a `HybridCache`.")
        sliding_window = cache_kwargs.get("sliding_window")

        # Move this layer's buffers to the device of the incoming states.
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        update_fn = self._sliding_update if sliding_window else self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],  # sequence length of this layer's buffer
        )

    def get_max_length(self) -> Optional[int]:
        """Return `max_cache_len`, the length the full-length layer buffers were allocated with."""
        # Sliding-window layers hold at most `min(config.sliding_window, max_cache_len)` positions,
        # so `max_cache_len` is the upper bound across all layers.
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        """Always returns `None`: this implementation does not report a cached sequence length."""
        return None

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_keys, layer_values in zip(self.key_cache, self.value_cache):
            # In-place ops prevent breaking the static address
            layer_keys.zero_()
            layer_values.zero_()
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
dtype (*optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Attributes:
dtype: (`torch.dtype`):
The default `dtype` used to initializing the cache.
intermediate_size: (`int`):
Model's intermediate_size taken from config.
ssm_state_size: (`int`):
Model's state_size taken from config.
conv_kernel_size: (`int`):
Model's convolution kernel size taken from config
conv_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
ssm_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
Example:
```python
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv = outputs.past_key_values
```
"""
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Optional[str] = None,
**kwargs,
):
self.dtype = dtype
self.max_batch_size = max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=dtype,
)
torch._dynamo.mark_static_address(self.conv_states)
torch._dynamo.mark_static_address(self.ssm_states)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
return self.ssm_states[layer_idx]
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
|