A trick to run it on a hoard of smaller GPUs

#1
by lancercat - opened
import torch
from accelerate import dispatch_model,infer_auto_device_map,load_checkpoint_and_dispatch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteriaList,
    StoppingCriteria,
)

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the newest tokens equal a fixed sequence of token ids.

    Only the first row of the batch (``input_ids[0]``) is inspected, so this
    criterion is meant for batch size 1.
    """

    def __init__(self, stops=None):
        """
        Args:
            stops: token ids that end generation when they appear as the
                trailing tokens of ``input_ids[0]``. Any iterable of ints is
                accepted (list, tuple, ...). ``None`` or empty means "never stop".
        """
        super().__init__()
        # Always store a fresh list. This avoids the shared-mutable-default
        # pitfall, and it lets a tuple argument compare equal to the list
        # produced by ``tolist()`` in ``__call__``.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the last ``len(self.stops)`` ids of row 0 match ``self.stops``.

        ``**kwargs`` mirrors the ``StoppingCriteria`` base signature, so extra
        keyword arguments from the caller are accepted and ignored.
        """
        if not self.stops:
            # Without this guard, -len([]) == 0 would compare the whole sequence
            # against []. Be explicit: an empty stop sequence never stops.
            return False
        id_list = input_ids[0].tolist()
        return id_list[-len(self.stops):] == self.stops


class model:
    """Causal LM whose weights are sharded across several GPUs with accelerate.

    The architecture is instantiated under ``init_empty_weights()``. A device
    map is then inferred from a per-GPU memory budget, and
    ``load_checkpoint_and_dispatch`` loads the checkpoint onto those devices.
    """

    DEFAULT_MODEL_PATH = "/run/media/xxx/modelzoo/cllm/vicuna-chinese-replication-beta/"
    # Per-GPU memory budget. Keys are integer GPU indices: newer accelerate
    # versions reject string device names such as "cuda:0" here.
    DEFAULT_MAX_MEMORY = {0: "14GiB", 1: "20GiB"}
    # Keep each Llama decoder layer (which contains LlamaAttention) on a single
    # device, matching the module class transformers marks as non-splittable
    # for Llama.
    NO_SPLIT_MODULES = ["LlamaDecoderLayer"]

    def __init__(self, model_path=DEFAULT_MODEL_PATH, max_memory=None, dtype=None):
        """Load the config, tokenizer and sharded weights from ``model_path``.

        Args:
            model_path: directory holding the config, tokenizer files and checkpoint.
            max_memory: ``{gpu_index: budget}`` mapping passed to
                ``infer_auto_device_map``. Defaults to ``DEFAULT_MAX_MEMORY``.
            dtype: dtype used both for planning placement and for loading the
                weights. Defaults to ``torch.float16``.
        """
        if dtype is None:
            dtype = torch.float16
        if max_memory is None:
            # Copy so the class-level default can never be mutated by callees.
            max_memory = dict(self.DEFAULT_MAX_MEMORY)

        config = AutoConfig.from_pretrained(model_path)

        # Build the model skeleton without materialising real weights.
        with init_empty_weights():
            skeleton = AutoModelForCausalLM.from_config(config)
        # accelerate's big-model workflow ties weights before inferring the device map.
        skeleton.tie_weights()

        device_map = infer_auto_device_map(
            skeleton,
            max_memory=max_memory,
            # This must be passed here. load_checkpoint_and_dispatch only uses
            # its own no_split_module_classes when device_map is a string
            # such as "auto".
            no_split_module_classes=self.NO_SPLIT_MODULES,
            dtype=dtype,
        )
        self.model = load_checkpoint_and_dispatch(
            skeleton,
            model_path,
            device_map=device_map,
            no_split_module_classes=self.NO_SPLIT_MODULES,
            dtype=dtype,
        )
        self.llama_tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Ids 2277, 29937 are the tokenisation of "###" per the original author;
        # generation halts once they are emitted.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277, 29937])])

        # NOTE(author): the prompt template used with this class was based on the
        # Vicuna template question and ChatGPT's answer to it; it can probably be
        # tuned further.

    def generate_llama(self, text, template, max_new_tokens=256):
        """Fill ``template`` with ``text``, sample a continuation and return only the new text.

        Args:
            text: inserted into ``template`` via ``str.format`` (one ``{}`` slot).
            template: prompt template string.
            max_new_tokens: upper bound on the number of generated tokens.

        Returns:
            The decoded continuation (prompt excluded), with every "###" removed
            and leading/trailing newlines stripped.
        """
        context = template.format(text)
        input_ids = self.llama_tokenizer(context, return_tensors="pt").input_ids.to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            do_sample=True,
            top_p=0.8,
            stopping_criteria=self.stopping_criteria,
            max_new_tokens=max_new_tokens,
        )
        # For a decoder-only model, generate() returns prompt + continuation.
        # Drop the prompt by token count. The old str.replace(context, "")
        # approach silently kept the prompt whenever decode(encode(context))
        # != context.
        new_ids = output_ids[:, input_ids.shape[-1]:]
        decoded = self.llama_tokenizer.batch_decode(new_ids, skip_special_tokens=True)[0]
        return decoded.replace("###", "").strip("\n")

if __name__ == "__main__":
    # Placeholder prompt. Replace it with your own template. Its single "{}"
    # slot receives the text passed to generate_llama (filled via str.format).
    PROMPT_TEMPLATE = "{}"

    mod = model()
    print(mod.generate_llama("blah", template=PROMPT_TEMPLATE))

Edit: My use case is rewriting my own Chinese into proper Chinese, which is why the class was originally called a translator.

You need to confirm your account before you can post a new comment.

Sign up or log in to comment