import torch
import transformers
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from typing import Dict, List, Any


class PreTrainedPipeline:
    def __init__(self, path=""):
        # The checkpoint is hardcoded, so the `path` argument passed in by the
        # serving runtime is ignored.
        path = "oleksandrfluxon/mpt-7b-chat-4bit"
        print("===> path", path)

        with torch.autocast('cuda'):
            config = transformers.AutoConfig.from_pretrained(
                path,
                trust_remote_code=True
            )
            # config.attn_config['attn_impl'] = 'triton'
            config.init_device = 'cuda:0'  # For fast initialization directly on GPU!
            config.max_seq_len = 4096  # (input + output) tokens can now be up to 4096

            print("===> loading model")
            model = transformers.AutoModelForCausalLM.from_pretrained(
                path,
                config=config,
                # torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto",
                load_in_8bit=True  # Quantize the weights to 8-bit precision on load
            )
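            # Note: despite the "4bit" repository name, load_in_8bit quantizes
            # the weights to 8-bit via bitsandbytes; a true 4-bit load would use
            # the load_in_4bit flag available in newer transformers releases.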
            # model.to('cuda')
            print("===> model loaded")

            # MPT checkpoints ship without a tokenizer of their own; the
            # EleutherAI/gpt-neox-20b tokenizer is the one MPT was trained with.
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                'EleutherAI/gpt-neox-20b', padding_side="left"
            )

            max_memory = get_balanced_memory(
                model,
                max_memory=None,
                no_split_module_classes=["MPTBlock"],
                dtype='float16',
                low_zero=False
            )
            device_map = infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=["MPTBlock"],
                dtype='float16'
            )
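            # infer_auto_device_map returns a plain dict mapping module names to
            # devices, e.g. {"transformer.blocks.0": 0, "transformer.blocks.1": 1}
            # (illustrative shape, not output from a real run), balanced so that
            # no MPTBlock is split across devices. dispatch_model then moves each
            # submodule to its assigned device and installs hooks that shuttle
            # activations between devices at the block boundaries. Since
            # device_map="auto" above already placed the model, this explicit
            # re-dispatch is belt-and-braces rather than strictly required.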
            model = dispatch_model(model, device_map=device_map)

            self.pipeline = transformers.pipeline('text-generation', model=model, tokenizer=tokenizer)
            print("===> init finished")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to generate from
            parameters (:obj:`dict`): generation kwargs forwarded to the pipeline
        Return:
            A :obj:`list` of :obj:`dict` containing the generated text
            (see the __main__ block below for an example payload)
        """
        # get inputs
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        date = data.pop("date", None)  # accepted in the payload but unused
        print("===> inputs", inputs)
        print("===> parameters", parameters)

        with torch.autocast('cuda'):
            result = self.pipeline(inputs, **parameters)
            print("===> result", result)
            return result
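

# A minimal local smoke test, assuming a CUDA machine with enough GPU memory
# for the 7B checkpoint; on a hosted endpoint the serving runtime instantiates
# PreTrainedPipeline itself. The prompt and generation parameters below are
# illustrative, not part of the handler contract.
if __name__ == "__main__":
    pipe = PreTrainedPipeline()
    payload = {
        "inputs": "Explain what a device map does in one sentence.",
        "parameters": {"max_new_tokens": 64, "do_sample": True},
    }
    result = pipe(payload)
    print(result)  # e.g. [{"generated_text": "..."}]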