import torch
from torch import nn
from .feed_forward import FeedForward
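# Pick the best available attention implementation: the patched variant if it
# exists, otherwise the memory-efficient one, otherwise the default CrossAttention.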
try:
    from .cross_attention import PatchedCrossAttention as CrossAttention
except ImportError:
    try:
        from .memory_efficient_cross_attention import MemoryEfficientCrossAttention as CrossAttention
    except ImportError:
        from .cross_attention import CrossAttention
from ..util import checkpoint
from ...patches import router
class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim, n_heads, d_head, dropout=0.0, context_dim=None,
        gated_ff=True, checkpoint=True, disable_self_attn=False,
    ):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # is a self-attention if not self.disable_self_attn
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # is self-attn if context is none
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads,
            dim_head=d_head, dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None)
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
class PatchedBasicTransformerBlock(nn.Module):
    def __init__(
        self, dim, n_heads, d_head, dropout=0.0, context_dim=None,
        gated_ff=True, checkpoint=True, disable_self_attn=False,
    ):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # is a self-attention if not self.disable_self_attn
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # is self-attn if context is none
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads,
            dim_head=d_head, dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        # Delegate the block body to the patch router (see ...patches.router),
        # which supplies the actual forward implementation.
        return router.basic_transformer_forward(self, x, context)
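

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming this module is imported as part of
    # its package (the relative imports above require that). All sizes below
    # are illustrative assumptions, not values taken from this repository:
    # a 320-dim block with 8 heads of 40 dims each, conditioned on a 768-dim
    # context of 77 tokens. checkpoint=False avoids the gradient-checkpoint
    # path for a plain forward pass.
    block = BasicTransformerBlock(
        dim=320, n_heads=8, d_head=40, context_dim=768, checkpoint=False
    )
    x = torch.randn(1, 4096, 320)        # (batch, tokens, dim)
    context = torch.randn(1, 77, 768)    # (batch, context tokens, context_dim)
    out = block(x, context=context)
    assert out.shape == x.shape          # the block preserves the input shape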