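# Interactive command-line chat demo for the moss2-2_5b-chat checkpoint.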
from transformers import AutoModel, AutoTokenizer, StoppingCriteria
import torch
import argparse

class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once the most recently generated token matches any of the given stop token ids."""

    def __init__(self, eos_sequence=[137625, 137632, 2]):
        # Token ids that should terminate generation (model-specific end-of-turn / EOS ids).
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check the last generated token of every sequence in the batch.
        last_ids = input_ids[:, -1].tolist()
        return any(eos_id in last_ids for eos_id in self.eos_sequence)


def test_model(ckpt):
    # Load the model and tokenizer from the checkpoint; trust_remote_code is needed
    # because the checkpoint ships custom modeling/tokenization code.
    model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    # Chat template: wrap the user message and leave the prompt open for the model's reply.
    init_prompt = "<|im_start|>user\n{input_message}<|end_of_user|>\n<|im_start|>"
    while True:
        # The prompt is rebuilt from scratch on each iteration, so every exchange is single-turn.
        history = ""
        print(">>> Let's start chatting <<<")
        input_message = input()
        input_prompt = init_prompt.format(input_message=input_message)
        history += input_prompt
        input_ids = tokenizer.encode(history, return_tensors="pt")
        output = model.generate(
            input_ids,
            top_p=1.0,
            max_new_tokens=300,
            stopping_criteria=[EosListStoppingCriteria()],
        ).squeeze()
        # Drop the prompt tokens and the trailing stop token before decoding.
        output_str = tokenizer.decode(output[input_ids.shape[1]:-1])
        print(output_str)
        print(">>>>>>>><<<<<<<<<<")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, help="path to the checkpoint", required=True)
    args = parser.parse_args()
    test_model(args.ckpt)
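
To try the script, pass --ckpt the path (or Hub id) of the moss2-2_5b-chat checkpoint; the file name below is only a placeholder for wherever the snippet above is saved:

    python chat_demo.py --ckpt /path/to/moss2-2_5b-chat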