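# Re-exports the MPT layer building blocks (attention, transformer block,
# shared embedding, fully connected and feed-forward layers, and norms) so
# they can be imported directly from ``llmfoundry.models.layers``.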
from llmfoundry.models.layers.attention import (
    ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

__all__ = [
    'scaled_multihead_dot_product_attention',
    'flash_attn_fn',
    'triton_flash_attn_fn',
    'MultiheadAttention',
    'MultiQueryAttention',
    'attn_bias_shape',
    'build_attn_bias',
    'build_alibi_bias',
    'ATTN_CLASS_REGISTRY',
    'MPTMLP',
    'MPTBlock',
    'NORM_CLASS_REGISTRY',
    'LPLayerNorm',
    'FC_CLASS_REGISTRY',
    'SharedEmbedding',
    'FFN_CLASS_REGISTRY',
    'build_ffn',
]