# Replacement forward methods that wrap each block's _forward in gradient
# (activation) checkpointing: intermediate activations are recomputed during
# the backward pass instead of being stored, trading extra compute for lower
# VRAM usage. (On recent PyTorch versions you may want to pass use_reentrant
# explicitly to silence the deprecation warning.)
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)

def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)

def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)
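
# Usage sketch (an assumption, not part of the original file): judging by the
# function names, these are presumably meant to be monkey-patched over the
# forward methods of the corresponding classes. The module paths below assume
# the CompVis stable-diffusion "ldm" codebase.
import ldm.modules.attention as attention
import ldm.modules.diffusionmodules.openaimodel as openaimodel

attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
openaimodel.AttentionBlock.forward = AttentionBlock_forward
openaimodel.ResBlock.forward = ResBlock_forward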