# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
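"""Re-exported layer building blocks: attention variants and bias helpers,
MPT transformer blocks, shared embeddings, fully connected and feed-forward
registries, and normalization layers."""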
from llmfoundry.models.layers.attention import (
    ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

__all__ = [
    'scaled_multihead_dot_product_attention',
    'flash_attn_fn',
    'triton_flash_attn_fn',
    'MultiheadAttention',
    'MultiQueryAttention',
    'attn_bias_shape',
    'build_attn_bias',
    'build_alibi_bias',
    'ATTN_CLASS_REGISTRY',
    'MPTMLP',
    'MPTBlock',
    'NORM_CLASS_REGISTRY',
    'LPLayerNorm',
    'FC_CLASS_REGISTRY',
    'SharedEmbedding',
    'FFN_CLASS_REGISTRY',
    'build_ffn',
]
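
# Usage sketch: the *_CLASS_REGISTRY dicts map config strings to layer
# classes, so model configs can select implementations by name. The registry
# keys below are assumptions for illustration; inspect the registries for the
# exact keys available in your installed version.
#
#   from llmfoundry.models.layers import FC_CLASS_REGISTRY, NORM_CLASS_REGISTRY
#
#   # Look up a norm class by name and instantiate it (key assumed).
#   norm = NORM_CLASS_REGISTRY['low_precision_layernorm'](normalized_shape=768)
#
#   # Look up a fully connected layer class the same way (key assumed).
#   linear = FC_CLASS_REGISTRY['torch'](768, 3072)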