File size: 1,099 Bytes
de4ade4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.layers.attention import (
    ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

__all__ = [
    'scaled_multihead_dot_product_attention',
    'flash_attn_fn',
    'triton_flash_attn_fn',
    'MultiheadAttention',
    'MultiQueryAttention',
    'attn_bias_shape',
    'build_attn_bias',
    'build_alibi_bias',
    'ATTN_CLASS_REGISTRY',
    'MPTMLP',
    'MPTBlock',
    'NORM_CLASS_REGISTRY',
    'LPLayerNorm',
    'FC_CLASS_REGISTRY',
    'SharedEmbedding',
    'FFN_CLASS_REGISTRY',
    'build_ffn',
]