from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import os


base_model_id = "mistralai/Mistral-7B-Instruct-v0.1" 

WORK = "vn_v2"
new_model_id = f"kmichiru/Nikaido-7B-mistral-instruct-v0.3-{WORK}"

# DSET = {
#     "train": f"dataset_iroseka/{WORK}_dataset.jsonl",
#     "eval": f"dataset_iroseka/{WORK}_validations.jsonl"
# }

DSET = {
    "train": f"dataset_iroseka/{WORK}_train.jsonl",
    "eval": f"dataset_iroseka/{WORK}_val.jsonl"
}
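
# Each JSONL line above is expected to hold one chat record in the "messages"
# format consumed by the formatters below; the record here is a made-up
# illustration of that shape, not real data:
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "speaker: line of dialogue"},
#               {"role": "assistant", "content": "speaker: reply"}]}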


dataset = load_dataset("json", data_files=DSET)
# model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
# max_length = 1024
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

def dialogue(role, content):
    return {
        "role": role,
        "content": content
    }

def format_chat_history(example):
    # v1 formatter (not used below): merge all user turns into one message, keep the
    # final assistant reply, and render with the tokenizer's chat template.
    user_msgs = []
    for msg in example["messages"]:
        if msg["role"] == "user":
            user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)), # join user messages together
        example["messages"][-1], # the last message is the bot's response
    ]
    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds

def prep_speaker(msg: str):
    # Normalize one "speaker: content" line: convert full-width spaces, split on the
    # first half-width colon (assumed present in every message), and fall back to
    # "傍白" (narration/aside) when the speaker field is empty.
    msg = msg.replace("\u3000", " ")  # replace full-width spaces with ASCII spaces
    speaker, content = msg.split(":", 1)
    speaker = speaker.strip()
    content = content.strip()
    if len(speaker) == 0:
        speaker = "傍白"

    return f"{speaker}: {content}"
    

def format_chat_history_v2(example):
    # v2 formatter (the one used below): wrap every non-system turn, speaker prefix
    # included, in Mistral-style [INST] ... [/INST] tags and prepend the BOS token as text.
    user_msg = []
    user_msg.append("<s>")
    for msg in example["messages"]:
        # [INST] What is your favourite condiment? [/INST]
        if msg["role"] != "system":
            user_msg.append(f"[INST] {prep_speaker(msg['content'])} [/INST]")
    # user_msg.append("</s>")
    return " ".join(user_msg)


print(format_chat_history_v2(dataset['train'][0]))

def generate_and_tokenize_prompt(prompt, max_length=2048):
    # Format one record and tokenize it to a fixed length (truncate/pad to 2048 tokens).
    result = tokenizer(
        format_chat_history_v2(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Causal-LM labels mirror the inputs; the collator below (mlm=False) rebuilds them
    # from input_ids and masks padding with -100.
    result["labels"] = result["input_ids"]
    return result

tokenized_dataset = dataset.map(generate_and_tokenize_prompt)
print(tokenized_dataset['train'][0])

# # stats data length
# def plot_data_lengths(tokenized_dataset):
#     lengths = []
#     for split in tokenized_dataset:
#         lengths += [len(x['input_ids']) for x in tokenized_dataset[split]]
#     print(f"Max length: {max(lengths)}")
#     print(f"Min length: {min(lengths)}")
#     print(f"Mean length: {sum(lengths)/len(lengths)}")
#     print(f"Median length: {sorted(lengths)[len(lengths)//2]}")

# plot_data_lengths(tokenized_dataset)
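
# Padding-aware variant of the length check above (a sketch): with padding="max_length"
# every input_ids list is exactly 2048 long, so the true sequence length is the number
# of 1s in attention_mask.
def print_true_lengths(tokenized):
    lengths = []
    for split in tokenized:
        lengths += [sum(x["attention_mask"]) for x in tokenized[split]]
    print(f"Max length: {max(lengths)}")
    print(f"Min length: {min(lengths)}")
    print(f"Mean length: {sum(lengths) / len(lengths):.1f}")

# print_true_lengths(tokenized_dataset)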

# Add LoRA adapters to the model's attention and MLP layers
from peft import LoraConfig, get_peft_model
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable%: {100 * trainable_params / all_param}"
    )
model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
# model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=64,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, peft_config)
print_trainable_parameters(model)
print(model)

import wandb, os
# wandb.login()

wandb_project = "NikaidoLM"
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project

import transformers
from datetime import datetime

project = wandb_project
base_model_name = "mistral"
run_name = base_model_name + "-" + project
output_name = f"{run_name}-{WORK}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
output_dir = "/scratch/generalvision/mowentao/mistral-out/" + output_name

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=500,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        num_train_epochs=3,
        weight_decay=5e-4,
        # max_steps=10_000,
        learning_rate=1e-4, # Want a small lr for finetuning
        bf16=True,
        optim="paged_adamw_32bit",
        logging_steps=100,             # Log training loss every 100 steps
        logging_dir=output_dir,        # Directory for storing logs
        save_strategy="steps",         # Save checkpoints on a step schedule
        save_steps=500,                # Save a checkpoint every 500 steps
        evaluation_strategy="steps",   # Evaluate on a step schedule
        eval_steps=100,                # Evaluate every 100 steps
        do_eval=True,                  # Run evaluation during training
        report_to="wandb",             # Comment this out if you don't want to use Weights & Biases
        run_name=output_name,         # Name of the W&B run (optional)
        lr_scheduler_type="cosine",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
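
# To resume an interrupted run from the latest checkpoint in output_dir, call
# trainer.train(resume_from_checkpoint=True) instead of the plain trainer.train() below
# (a standard Trainer feature, noted here for reference).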

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
trainer.model.save_pretrained(new_model_id)
wandb.finish()
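
# Inference sketch (a hedged example, not part of the original training run): reload the
# saved LoRA adapter on top of the base model, re-enable the KV cache that was disabled
# above, and generate one reply. The demo record is a made-up placeholder in the same
# shape as the training data; it assumes the adapter directory written by save_pretrained above.
from peft import PeftModel

infer_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
infer_model = PeftModel.from_pretrained(infer_model, new_model_id).eval()
infer_model.config.use_cache = True
# infer_model = infer_model.to("cuda")  # optional: move to GPU if one is available

demo_record = {"messages": [{"role": "user", "content": "A: こんにちは"}]}
prompt = format_chat_history_v2(demo_record)
inputs = tokenizer(prompt, return_tensors="pt").to(infer_model.device)
with torch.no_grad():
    output_ids = infer_model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))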