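"""Inference script for the Nikaido Mistral-7B LoRA chat model.

Loads the Mistral-7B-Instruct base model, attaches the fine-tuned PEFT
adapter, formats a dialogue history into a Mistral-style [INST] prompt,
and samples the next character's line.
"""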
import json
import random

import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

torch.cuda.empty_cache()

base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.1"
new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.3-vn_v2"

# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token, tokenizer.pad_token_id)
tokenizer.padding_side = "right"

# Reload the base model
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},  # place the entire model on GPU 0
)
model = PeftModel.from_pretrained(base_model_reload, new_model)
# model = model.merge_and_unload()

model.config.use_cache = True
model.eval()

def dialogue(role, content):
    # Build a single OpenAI-style chat message
    return {"role": role, "content": content}


TRAIN_DSET = "iroseka_dataset.jsonl"  # training set, reused as the few-shot example pool
try:
    with open(TRAIN_DSET, "r", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
except FileNotFoundError:
    print("Few-shot data not found, skipping...")
    examples = []

def format_chat_history(example, few_shot=0):
    user_msgs = []
    # Collect the content of every message in this round. (The original
    # intent was user messages only, but the role filter is disabled, so
    # the full history becomes one user turn.)
    for msg in example["messages"]:
        # if msg["role"] == "user":
        user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)),  # join the history into a single user message
        # example["messages"][-1],  # the last message is the bot's response
    ]

    if few_shot > 0 and examples:
        # Randomly sample whole training examples to prepend as few-shot
        # demonstrations (the `and examples` guard avoids a ValueError from
        # random.sample when the dataset file was not found).
        few_shot_data = random.sample(examples, few_shot)
        for few_shot_example in few_shot_data:
            few_shot_msgs = []
            for msg in few_shot_example["messages"]:
                if msg["role"] == "user":
                    few_shot_msgs.append(msg["content"])
            messages = [
                dialogue("user", "\n".join(few_shot_msgs)),
                few_shot_example["messages"][-1]
            ] + messages

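    # Render the message list with the tokenizer's chat template. For Mistral
    # this yields roughly "<s>[INST] <user turn> [/INST]<assistant turn></s>"
    # per round; tokenize=False returns the string rather than token ids.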
    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds

def format_chat_history_v2(example, few_shot):
    # TODO: implement few-shot prompting here (few_shot is accepted but currently unused)
    user_msg = []
    user_msg.append("<s>")
    for msg in example["messages"]:
        # Mistral instruct format, e.g. [INST] What is your favourite condiment? [/INST]
        user_msg.append(f"[INST] {msg['content']} [/INST]")
    # user_msg.append("</s>")
    if "next_speaker" in example:
        user_msg.append(f"[INST] {example['next_speaker']}: ")
    return " ".join(user_msg)
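# Example output shape (hypothetical two-line history with next_speaker set):
#   <s> [INST] speaker_a: line 1 [/INST] [INST] speaker_b: line 2 [/INST] [INST] next_speaker:
# The final [INST] is deliberately left unterminated so the model continues
# the text as that speaker.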

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None):
        super().__init__()
        # Keep stop sequences on the GPU so they compare directly against generated ids
        # (stops=None instead of a mutable default list; the unused `encounters`
        # parameter from the original sketch is dropped)
        self.stops = [stop.to("cuda") for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Fire once any sequence in the batch ends with one of the stop sequences
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all(stop == seq[-len(stop):]).item():
                    return True
        return False

stop_words = ["[/INST]"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
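# Note: stopping_criteria is built here but left commented out in model.generate
# below; pass stopping_criteria=stopping_criteria there to stop early at "[/INST]".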

def inference(chat_history):
    # chat_history: dict with a "messages" key holding the dialogue history in OpenAI format
    formatted = format_chat_history_v2(chat_history, few_shot=1)
    print(formatted)
    model_inputs = tokenizer(
        [formatted],
        return_tensors="pt",
        add_special_tokens=False,  # the prompt already starts with a literal <s>; avoid a doubled BOS
    )
    print(model_inputs)
    model_inputs = model_inputs.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            do_sample=True,
            top_p=1,
            temperature=0.3,
            use_cache=True,
            pad_token_id=tokenizer.pad_token_id,  # pad token was set to eos above; some models ship without one
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            output_scores=True,
            output_attentions=False,
            output_hidden_states=False,
            max_new_tokens=256,
            # Alternatives kept for experimentation:
            # max_length=1024,
            # top_k=50, penalty_alpha=0.6,  # contrastive search
            # num_beams=9, num_beam_groups=3, diversity_penalty=0.5,  # diverse beam search
            # repetition_penalty=1.0,
            # num_return_sequences=3,
            # stopping_criteria=stopping_criteria,
        )
        # print(outputs)
        text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        def postprocess(t):
            # Keep only the final reply: split on [/INST], strip stray [INST]
            # markers, drop empty segments, and return the last piece.
            t = t.split("[/INST]")
            t = [x.replace("[INST]", "").strip() for x in t]
            t = [x for x in t if x != ""]
            return t[-1]
        # text = [postprocess(t) for t in text]  # currently disabled; the raw decode is returned

    return text

if __name__ == "__main__":
    # Smoke test: a short scene (kept in Japanese, the model's training
    # language), asking the model to continue as 真紅 (Shinku).
    chat_history = {
        "messages": [
            # dialogue("system", ""),
            dialogue("user", "傍白: 真紅の言葉が胸の中に滑り込んでくる。"),  # Aside: "Shinku's words slip into my chest."
            dialogue("user", "悠馬: っ"),  # Yuuma: a wordless gasp
            dialogue("user", "傍白: 限界だった。"),  # Aside: "I was at my limit."
            dialogue("user", "悠馬: 真紅，大好きです。これからもずっと一緒にいてください。"),  # Yuuma: "Shinku, I love you. Please stay with me forever."
        ],
        "next_speaker": "真紅"
    }
    print(inference(chat_history))
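    # Note: inference() returns batch_decode's one-element list, and the string
    # still contains the prompt text because the postprocess step is disabled.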