### Load model from Hugging Face
import os
import json

import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

import wandb
import peft
import loralib as lora
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from accelerate import Accelerator, DeepSpeedPlugin
""" | |
extra requirements: | |
pip install icetk | |
""" | |
checkpoint = "/model/chatglm-6b"
datafile = 'datasets/merge.json'
out_dir = 'outs/chatglm-6b'

use_wandb = True
mixed_precision = 'bf16'
accumulate_step = 8
log_interval = 100
Per_GPU_BATCH_SIZE = 2
MAX_LENGTH = 256  # has a huge impact on VRAM: e.g. length 968 fits batch size 1, length 256 fits batch size 4
config = LoraConfig(
    peft_type="LORA",
    r=32,
    lora_alpha=32,
    target_modules=["q", "k", "v"],
    lora_dropout=0.1,
)
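# Note: peft matches `target_modules` by name suffix against submodule names, so
# "q", "k", "v" here will hit the `linear_q` / `linear_k` / `linear_v` layers of
# the QKV_layer class defined below, not ChatGLM's original fused `query_key_value`.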
LR = 2e-5
NUM_EPOCHS = 3
warm_up_ratio = 0.1

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
if use_wandb:
    wandb.init(
        project="LoRA",
        name=f"{checkpoint}-{datafile}",
        config=None,
    )
else:
    wandb.init(mode='disabled')

os.makedirs(out_dir, exist_ok=True)

# Tokenizers have no device placement, so no device_map is passed here.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True,
)
# Known issue: the tokenizer mishandles literal '[MASK]'/'[gMASK]' strings in the
# training text; AlpacaDataset below works around this by rewriting them.
### Dataset
EOS_ID = 150005  # end-of-sequence (<eop>) token id in this checkpoint's tokenizer
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
with open(datafile, 'r') as f:
    content = json.load(f)

pairs = []
for line in content:
    if line['input'] == '':
        prompt = PROMPT_DICT['prompt_no_input'].format_map(line)
    else:
        prompt = PROMPT_DICT['prompt_input'].format_map(line)
    completion = line['output'] + '</s>'
    # Character-count filter, used as a cheap proxy for token length.
    if len(prompt) + len(completion) < MAX_LENGTH:
        pairs.append({'prompt': prompt, 'completion': completion})
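# For illustration, a made-up record (not taken from merge.json) such as
#   {"instruction": "Name three primary colors.", "input": "", "output": "Red, blue, and yellow."}
# is rendered with 'prompt_no_input' into:
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   ### Instruction:
#   Name three primary colors.
#
#   ### Response:
# and paired with the completion "Red, blue, and yellow.</s>".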
class AlpacaDataset(Dataset):
    def __init__(self, pairs, tokenizer) -> None:
        super().__init__()
        self.pairs = pairs
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        if self.pairs[index]['completion'][-4:] == '</s>':
            prompt = self.tokenizer.encode(self.pairs[index]['prompt'])
            completion = self.tokenizer.encode(self.pairs[index]['completion'][:-4], add_special_tokens=False)
            completion += [EOS_ID]
        else:
            prompt = self.tokenizer.encode(self.pairs[index]['prompt'])
            completion = self.tokenizer.encode(self.pairs[index]['completion'], add_special_tokens=False)
        # The tokenizer should always append the special mask token (id 150001,
        # [gMASK] in this tokenizer) to the prompt; if literal '[MASK]'/'[gMASK]'
        # strings in the text broke the encoding, rewrite them and re-encode.
        if 150001 not in prompt:
            prompt = self.pairs[index]['prompt'].replace('[MASK]', '//MASK//').replace('[gMASK]', '//gMASK//')
            completion = self.pairs[index]['completion'].replace('[MASK]', '//MASK//').replace('[gMASK]', '//gMASK//')
            prompt = self.tokenizer.encode(prompt)
            completion = self.tokenizer.encode(completion, add_special_tokens=False)
            if 150001 not in prompt:
                raise ValueError(f'special token 150001 still missing from encoded prompt at index {index}')
        return {'prompt': prompt, 'completion': completion}

    def __len__(self):
        return len(self.pairs)
def collate_fn(batch):
    input_ids = []
    labels = []
    position_ids = []
    # Build everything on CPU; accelerator.prepare() moves batches to each
    # process's device (the original hard-coded 'cuda:0', which breaks on
    # non-zero ranks under DDP).
    _max_length = max(len(obj['prompt']) + len(obj['completion']) for obj in batch)
    # Prefix-LM mask: bidirectional attention over the prompt, causal over the completion.
    attention_mask = torch.ones((len(batch), _max_length, _max_length))
    attention_mask.tril_()
    for i, obj in enumerate(batch):
        context_length = obj['prompt'].index(150004)  # 150004 = <sop>; tokens before it form the context
        attention_mask[i, :, :context_length] = 1
        to_pad = _max_length - len(obj['prompt']) - len(obj['completion'])
        input_ids.append(obj['prompt'] + obj['completion'] + [tokenizer.pad_token_id] * to_pad)
        # ChatGLM's 2D position ids: absolute positions in row 0, block positions
        # (zeros over the context, then 1..n over the completion) in row 1.
        position_ids.append(torch.stack(
            [torch.arange(0, _max_length),
             torch.concat([torch.zeros(context_length - 1),
                           torch.arange(0, _max_length - context_length + 1)])]).long()
        )
        # Loss is computed only on completion tokens; prompt and padding get -100.
        labels.append(torch.tensor([-100] * len(obj['prompt']) + obj['completion'] + [-100] * to_pad).long())
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()  # True = masked out
    return {'input_ids': torch.tensor(input_ids).long(),
            'attention_mask': attention_mask,
            'labels': torch.stack(labels),
            'position_ids': torch.stack(position_ids)}
train_dataset = AlpacaDataset(pairs, tokenizer=tokenizer)
train_dataloader = DataLoader(dataset=train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=Per_GPU_BATCH_SIZE)

# Sanity check: one full pass over the dataloader to surface any
# tokenization/collation errors before loading the model.
for step, batch in enumerate(t := tqdm.tqdm(train_dataloader)):
    pass
model = AutoModel.from_pretrained(
    checkpoint,
    trust_remote_code=True,
)

deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step, deepspeed_plugin=deepspeed_plugin)
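# This script is meant to be started through the Accelerate launcher so that
# DeepSpeed ZeRO stage 2 (which shards optimizer states and gradients across
# ranks) is wired up, e.g.:
#   accelerate launch train.py
# (the file name is an assumption; use whatever name this script is saved under).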
device = accelerator.device

### Insert LoRA into the model
class QKV_layer(torch.nn.Module):
    """Splits ChatGLM's fused query_key_value Linear into three separate
    Linears so that peft can attach LoRA to q/k/v individually."""
    def __init__(self, in_features, out_features):
        super(QKV_layer, self).__init__()
        self.linear_q = torch.nn.Linear(in_features, out_features // 3)
        self.linear_k = torch.nn.Linear(in_features, out_features // 3)
        self.linear_v = torch.nn.Linear(in_features, out_features // 3)

    def update(self, target_layer):
        # Copy the q/k/v thirds of the fused layer's weight and bias.
        third = target_layer.out_features // 3
        self.linear_q.weight.data = target_layer.weight[:third, :].data
        self.linear_q.bias.data = target_layer.bias[:third].data
        self.linear_k.weight.data = target_layer.weight[third:third * 2, :].data
        self.linear_k.bias.data = target_layer.bias[third:third * 2].data
        self.linear_v.weight.data = target_layer.weight[third * 2:, :].data
        self.linear_v.bias.data = target_layer.bias[third * 2:].data

    def forward(self, x):
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        return torch.concat([q, k, v], dim=-1)
for key, module in model.named_modules():
    if key.endswith('attention'):
        # Note: LoraModel(config, model) is the signature of early peft releases.
        if isinstance(module.query_key_value, peft.tuners.lora.LoraModel):
            # Already wrapped (e.g. when re-running this cell): rebuild the
            # LoRA wrapper around the underlying layer.
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value.model)
        else:
            # Split the fused query_key_value layer into three linear layers
            # for LoRA; a merged linear would also work.
            qkv_layer = QKV_layer(module.query_key_value.in_features, module.query_key_value.out_features)
            qkv_layer.update(module.query_key_value)
            module.query_key_value = qkv_layer
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value)
lora.mark_only_lora_as_trainable(model)

# Count parameters. (The original reused a filter(...) generator, which is
# exhausted after one pass, yielding a non-trainable count of 0 and a
# division-by-zero in the print.)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print('trainable_params:{} ({:.2f}%), non_trainable_params:{}'.format(
    trainable_params, trainable_params / total_params * 100, total_params - trainable_params
))
### Training
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(len(train_dataloader) / accumulate_step * warm_up_ratio),
    num_training_steps=int(len(train_dataloader) / accumulate_step) * NUM_EPOCHS,
)
# Preparing the scheduler alongside the optimizer makes it advance only when
# the optimizer actually steps (i.e. once per accumulate_step batches), which
# matches the step counts computed above.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)
model.to(device).train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for step, batch in enumerate(t := tqdm.tqdm(train_dataloader)):
        # accelerator.accumulate() makes optimizer.step() a no-op except on
        # every accumulate_step-th batch.
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss_detach = outputs.loss.detach().cpu().float()
            t.set_postfix(loss=loss_detach.item())
            total_loss += loss_detach
            loss = outputs.loss
            if accelerator.is_main_process:
                if step % log_interval == 0:
                    wandb.log({
                        'train/loss': loss_detach.item(),
                    })
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # Save only the LoRA weights per epoch, not the full model.
        peft_model_id = f"finetune_{epoch}"
        accelerator.save(lora.lora_state_dict(accelerator.unwrap_model(model)), f'{out_dir}/{peft_model_id}.pt')
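"""
To reuse a saved checkpoint at inference time, the same LoRA wrapping has to
be rebuilt before the weights are loaded, since lora_state_dict() stores only
the LoRA tensors. A minimal sketch, assuming this script's QKV_layer split and
LoraModel wrapping are repeated on a freshly loaded model (the epoch number
below is an example, not a recommendation):

    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    # ... repeat the QKV_layer split + LoraModel wrapping from this script ...
    lora_state = torch.load(f'{out_dir}/finetune_2.pt')
    model.load_state_dict(lora_state, strict=False)  # only LoRA tensors are present
    model.half().cuda().eval()
"""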