File size: 134 Bytes
616f571
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from dataclasses import dataclass

import torch


@dataclass()
class Cache:
    """Record pairing two tensors: ``key_states`` and ``value_states``.

    This is a plain mutable record. The ``dataclass`` decorator generates
    ``__init__`` (positional order: ``key_states``, ``value_states``),
    ``__repr__`` and ``__eq__``. No validation is performed on either tensor.
    """

    # Key-side tensor. NOTE(review): shape/dtype are not constrained here;
    # presumably this is one half of an attention key/value cache, so confirm
    # the expected layout against the callers.
    key_states: torch.Tensor
    # Value-side tensor. Presumably it should match ``key_states`` in its
    # leading dimensions, but nothing in this class enforces that.
    value_states: torch.Tensor