File size: 5,600 Bytes
c7a96cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch

from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded

# Probe the GPU and, when it is able to run the Flash Attention kernels, import
# the Flash Attention model classes. FLASH_ATTENTION records whether those
# classes are available to the rest of this module.
try:
    if not torch.cuda.is_available():
        FLASH_ATTENTION = False
    else:
        major, minor = torch.cuda.get_device_capability()
        # Only compute capabilities 7.5, 8.x and 9.0 are accepted.
        is_sm75 = (major, minor) == (7, 5)
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = (major, minor) == (9, 0)

        supported = any((is_sm75, is_sm8x, is_sm90))
        if not supported:
            # Raised as ImportError so that the handler below treats an
            # unsupported GPU the same way as a failed kernel import.
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_neox import (
            FlashNeoX,
            FlashNeoXSharded,
        )
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
except ImportError:
    # Not fatal: warn (with the exception attached to the log record) and
    # mark the Flash Attention models as unavailable.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

# Names exported by `from <this module> import *`.
__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

# The Flash Attention classes are only imported when FLASH_ATTENTION is true,
# so they are exported conditionally. `__all__` must hold the *names* as
# strings: putting the class objects in it makes a star-import raise TypeError.
if FLASH_ATTENTION:
    __all__.extend(
        [
            "FlashNeoX",
            "FlashNeoXSharded",
            "FlashSantacoder",
            "FlashSantacoderSharded",
            "FlashLlama",
            "FlashLlamaSharded",
        ]
    )

# Error template used when a model can only be served with the Flash Attention
# kernels; the single `{}` placeholder receives the model description.
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image "
    "(ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with "
    "`cd server && make install install-flash-attention`"
)

# Turn gradient tracking off globally.
torch.set_grad_enabled(False)

# Allow TF32 in cuDNN. PyTorch already defaults this flag to True; it is set
# explicitly here.
torch.backends.cudnn.allow_tf32 = True

# Allow TF32 in matmul as well. PyTorch 1.12 and later default this flag to
# False, so it has to be opted into.
torch.backends.cuda.matmul.allow_tf32 = True


def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
    """Instantiate the model implementation matching ``model_id``.

    Galactica and BigCode checkpoints are recognised from substrings of
    ``model_id``. Every other model is dispatched on the ``model_type`` of the
    configuration loaded with ``AutoConfig.from_pretrained``. Where a Flash
    Attention implementation exists it is chosen whenever ``FLASH_ATTENTION``
    is true.

    Args:
        model_id: Identifier forwarded to ``AutoConfig.from_pretrained`` and to
            the selected model class.
        revision: Optional revision, forwarded the same way.
        sharded: Whether a sharded implementation must be selected.
        quantize: Forwarded to the selected model class as ``quantize``.

    Returns:
        An instance of the selected model class.

    Raises:
        NotImplementedError: If a sharded Santacoder or Llama model is
            requested while the Flash Attention models are unavailable.
        ValueError: If ``sharded`` is set for a model type without a dedicated
            sharded implementation, or if the model type is not supported.
    """
    # These two families are identified by model id, before any configuration
    # is loaded.
    if "facebook/galactica" in model_id:
        galactica_cls = GalacticaSharded if sharded else Galactica
        return galactica_cls(model_id, revision, quantize=quantize)

    if "bigcode" in model_id:
        if not sharded:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(model_id, revision, quantize=quantize)
        # Sharded Santacoder only exists as a Flash Attention implementation.
        if not FLASH_ATTENTION:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return FlashSantacoderSharded(model_id, revision, quantize=quantize)

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "bloom":
        bloom_cls = BLOOMSharded if sharded else BLOOM
        return bloom_cls(model_id, revision, quantize=quantize)

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
        return neox_cls(model_id, revision, quantize=quantize)

    if model_type == "llama":
        if not sharded:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
            return llama_cls(model_id, revision, quantize=quantize)
        # Sharded Llama only exists as a Flash Attention implementation.
        if not FLASH_ATTENTION:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
            )
        return FlashLlamaSharded(model_id, revision, quantize=quantize)

    if model_type == "opt":
        opt_cls = OPTSharded if sharded else OPT
        return opt_cls(model_id, revision, quantize=quantize)

    if model_type == "t5":
        t5_cls = T5Sharded if sharded else Seq2SeqLM
        return t5_cls(model_id, revision, quantize=quantize)

    # Generic fallbacks below have no sharded variant.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, quantize=quantize)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, quantize=quantize)

    raise ValueError(f"Unsupported model type {model_type}")