A trick to run it on a hoard of smaller GPUs
by lancercat
import torch
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
class StoppingCriteriaSub(StoppingCriteria):
    '''Stops generation once the last len(stops) tokens of input_ids match the stops list.'''

    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops or []

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the generated sequence against the stop ids
        id_list = input_ids[0].tolist()
        return id_list[-len(self.stops):] == self.stops
class model:
    def __init__(this):
        model_path = "/run/media/xxx/modelzoo/cllm/vicuna-chinese-replication-beta/"
        config = AutoConfig.from_pretrained(model_path)
        # Build the model skeleton on the meta device so no weights are loaded yet
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)
        model.tie_weights()
        # Shard the checkpoint across the two GPUs according to the per-device memory budget,
        # keeping each attention module on a single device
        this.model = load_checkpoint_and_dispatch(
            model, model_path,
            device_map=infer_auto_device_map(
                model, max_memory={"cuda:0": "14GiB", "cuda:1": "20GiB"},
                no_split_module_classes=["LlamaAttention"], dtype=torch.float16),
            dtype=torch.float16,
        )
        this.llama_tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Token ids of "###", used as the stop sequence
        this.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277, 29937])])
    # The template is based on the Vicuna prompt (a template question plus ChatGPT's answer to it);
    # it can probably be tuned further.
    def generate_llama(this, text, template, max_new_tokens=256):
        '''Generate a completion with the dispatched Llama model.'''
        context = template.format(text)
        input_ids = this.llama_tokenizer(context, return_tensors="pt").input_ids.to(this.model.device)
        output_ids = this.model.generate(input_ids, do_sample=True, top_p=0.8,
                                         stopping_criteria=this.stopping_criteria,
                                         max_new_tokens=max_new_tokens)
        decode_string = this.llama_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        # Strip the prompt and the stop marker from the decoded text
        return decode_string.replace(context, "").replace("###", "").strip("\n")
if __name__ == "__main__":
    mod = model()
    # torch.save(model, "/run/media/xxx/modelzoo/cllm/vicuna-chinese-replication-beta/mess.pt")
    print(mod.generate_llama("blah", template=INSERT_YOUR_TEMPLATE))
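
The stop ids [2277, 29937] are what the comment says "###" maps to with this tokenizer; if you swap in a different checkpoint it is worth double-checking them, since sentencepiece merges depend on the surrounding text. A minimal sketch for doing that (assuming the same tokenizer directory as above):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/run/media/xxx/modelzoo/cllm/vicuna-chinese-replication-beta/")
# What the hard-coded stop ids decode to
print(tok.convert_ids_to_tokens([2277, 29937]))
# How "###" is tokenized after some preceding text; check it in a realistic
# position rather than at the start of a string
print(tok("answer ###", add_special_tokens=False).input_ids)
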
Edit: My use case is translating my Chinese into proper Chinese, which is why it was called a translator.
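
For anyone wondering what to put in place of INSERT_YOUR_TEMPLATE: a Vicuna-style prompt whose turns are delimited with "###" works with the stopping criteria above. The wording below is only an illustration (not the exact prompt I use), with a {} slot that generate_llama fills via template.format(text):

TEMPLATE = (
    "A chat between a user and an AI assistant. "
    "The assistant rewrites the user's draft in fluent Chinese.\n"
    "### Human: {}\n"
    "### Assistant:"
)

print(mod.generate_llama("blah", template=TEMPLATE))

Generation then stops as soon as the model starts the next "###" turn.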