from typing import Optional

import torch
from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
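
# The Flash Attention kernels require an NVIDIA GPU with compute capability
# 7.5 (Turing), 8.x (Ampere/Ada), or 9.0 (Hopper). Probe the device up front
# and fall back to the standard implementations when the import fails.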
try:
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major}.{minor} is not supported"
            )

        from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]
if FLASH_ATTENTION:
    # __all__ entries must be names (strings), not the class objects themselves
    __all__.append("FlashNeoX")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashSantacoder")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("FlashLlamaSharded")

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)
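# The placeholder is filled with the model family at the call site, e.g.
# FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama").
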
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
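    """Resolve and instantiate the concrete Model implementation for `model_id`.

    Special-cased model families (Galactica, BigCode) are matched on the repo id
    first; everything else dispatches on `config.model_type`, falling back to the
    generic AutoModel mappings.
    """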
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "bigcode" in model_id:
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "bloom":
        if sharded:
            return BLOOMSharded(model_id, revision, quantize=quantize)
        else:
            return BLOOM(model_id, revision, quantize=quantize)

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
            return neox_cls(model_id, revision, quantize=quantize)
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
            return neox_cls(model_id, revision, quantize=quantize)

    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
                return FlashLlamaSharded(model_id, revision, quantize=quantize)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
            return llama_cls(model_id, revision, quantize=quantize)

    if model_type == "opt":
        if sharded:
            return OPTSharded(model_id, revision, quantize=quantize)
        else:
            return OPT(model_id, revision, quantize=quantize)

    if model_type == "t5":
        if sharded:
            return T5Sharded(model_id, revision, quantize=quantize)
        else:
            return Seq2SeqLM(model_id, revision, quantize=quantize)

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, quantize=quantize)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, quantize=quantize)

    raise ValueError(f"Unsupported model type {model_type}")