"""iLoRA / model / model_interface.py

Fork with bug fixes of https://github.com/AkaliKong/iLoRA (commit 9f13819), by MingLi.
"""
import inspect
import torch
import importlib
from torch import nn
from torch.nn import functional as F
import torch.optim.lr_scheduler as lrs
import pytorch_lightning as pl
from transformers import LlamaForCausalLM, LlamaTokenizer
import random
from pandas.core.frame import DataFrame
import os.path as op
import os
from optims import LinearWarmupCosineLRScheduler
import numpy as np
from .peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel, MoeLoraConfig, MoeLoraModel
import pickle
from .router.nlpr import LambdaLayer, ResidualBlock, GateFunction, NLPRecommendationRouter, build_router
# from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
class MInterface(pl.LightningModule):
    def __init__(self, **kwargs):
super().__init__()
self.save_hyperparameters()
self.load_llm(self.hparams.llm_path)
if self.hparams.router == 'share':
self.router = build_router()
self.load_rec_model(self.hparams.rec_model_path)
self.load_projector()
self.gradient_storage = {}
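        # hparams consumed in this module (inferred from usage below; values come
        # from the launch script): llm_path, rec_model_path, rec_size, rec_embed,
        # router ('share' builds one gating router shared across layers),
        # llm_tuning ('lora' | 'freeze' | 'freeze_lora' | 'moelora'),
        # lora_r, lora_alpha, lora_dropout, num_moe, gating, peft_dir, peft_config,
        # lr, lr_scheduler, lr_decay_min_lr, lr_warmup_start_lr, weight_decay,
        # loss, save, model_name, output_dir, capture_dir.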
def forward(self, batch):
targets = batch["tokens"].input_ids.masked_fill(
batch["tokens"].input_ids == self.llama_tokenizer.pad_token_id, -100
) # [batch_size, max_len]
        # also mask prompt positions (token_type_ids == 0) so the loss covers only the answer span
        targets = targets.masked_fill((batch["tokens"].token_type_ids == 0)[:,1:], -100)
# targets = targets.masked_fill((batch["tokens"].token_type_ids == 0)[:,:], -100)
input_embeds, user_embeds = self.wrap_emb(batch)
        if self.hparams.router == 'share':
            # one shared router maps the user embedding to per-expert gate weights
            gate_weights = self.router(user_embeds)
outputs = self.llama_model(
inputs_embeds=input_embeds,
attention_mask=batch["tokens"].attention_mask,
return_dict=True,
labels=targets,
use_cache=False,
user_embeds=user_embeds,
gate_weights=gate_weights
)
return outputs
outputs = self.llama_model(
inputs_embeds=input_embeds,
attention_mask=batch["tokens"].attention_mask,
return_dict=True,
labels=targets,
use_cache=False,
user_embeds=user_embeds
)
return outputs
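    # Batch schema assumed throughout (inferred from usage in this file):
    #   batch["tokens"]       tokenizer output: input_ids / attention_mask / token_type_ids
    #   batch["seq"], batch["len_seq"]    user history item ids and their lengths
    #   batch["cans"], batch["len_cans"]  candidate item ids and their lengths
    #   batch["item_id"]      target item id
    #   batch["flag"]         when truthy, the projector is frozen for the step
    #   batch["correct_answer"], batch["cans_name"]  text labels used during eval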
    def generate(self, batch, temperature=0.8, do_sample=False, num_beams=1,
                 max_gen_length=64, min_gen_length=1, repetition_penalty=1.0,
                 length_penalty=1.0, num_return_sequences=1):
input_embeds, user_embeds = self.wrap_emb(batch)
if self.hparams.router == 'share':
gate_weights = self.router(user_embeds)
generate_ids = self.llama_model.generate(
inputs_embeds=input_embeds,
attention_mask=batch["tokens"].attention_mask,
temperature=temperature,
do_sample=do_sample,
num_beams=num_beams,
max_new_tokens=max_gen_length,
min_new_tokens=min_gen_length,
pad_token_id=self.llama_tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
num_return_sequences=num_return_sequences,
user_embeds=user_embeds,
                gate_weights=gate_weights
)
output_text=self.llama_tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
outputs=[text.strip() for text in output_text]
return outputs
        # no shared router in this configuration: mirror forward() and let the
        # model's internal gating derive expert weights from user_embeds
        # (self.router only exists when hparams.router == 'share')
        generate_ids = self.llama_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=batch["tokens"].attention_mask,
            temperature=temperature,
            do_sample=do_sample,
            num_beams=num_beams,
            max_new_tokens=max_gen_length,
            min_new_tokens=min_gen_length,
            pad_token_id=self.llama_tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            user_embeds=user_embeds
        )
output_text=self.llama_tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
outputs=[text.strip() for text in output_text]
return outputs
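    # Minimal usage sketch (mirrors validation_step below; the batch comes from the
    # project's dataloader and is not constructed here):
    #   texts = model.generate(batch, num_beams=1, max_gen_length=64)
    #   pred = texts[0].strip().split("\n")[0]  # first line holds the predicted item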
    def capture_and_store_gradients(self):
        # accumulate gradients of all LoRA parameters; flushed to disk every 10
        # global steps (no call site in this file; e.g. hook it into on_after_backward)
for name, param in self.llama_model.named_parameters():
if "lora" in name and param.grad is not None:
if name not in self.gradient_storage:
self.gradient_storage[name] = []
self.gradient_storage[name].append(param.grad.clone().detach())
if self.trainer.global_step % 10 == 0:
self.save_gradients_to_file()
def save_gradients_to_file(self):
directory = self.hparams.capture_dir
if not os.path.exists(directory):
os.makedirs(directory)
file_path = os.path.join(directory, f'gradients_step_{self.trainer.global_step}.pkl')
with open(file_path, 'wb') as f:
pickle.dump(self.gradient_storage, f)
self.gradient_storage = {}
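    # Offline analysis sketch for the pickles written above (path and step number
    # are illustrative only):
    #   import pickle
    #   with open('captures/gradients_step_10.pkl', 'rb') as f:
    #       grads = pickle.load(f)  # {param_name: [grad tensors captured so far]}
    #   norms = {k: [g.norm().item() for g in v] for k, v in grads.items()}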
def training_step(self, batch, batch_idx):
if self.scheduler:
self.scheduler.step(self.trainer.global_step, self.current_epoch, self.trainer.max_steps)
if batch["flag"]:
for name, param in self.projector.named_parameters():
param.requires_grad = False
else:
for name, param in self.projector.named_parameters():
param.requires_grad = True
out = self(batch)
loss = self.configure_loss(out)
self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True)
self.log('global_step_num', self.trainer.global_step, on_step=True, on_epoch=True, prog_bar=True)
return loss
def on_validation_epoch_start(self):
self.val_content={
"generate":[],
"real":[],
"cans":[],
}
@torch.no_grad()
def validation_step(self, batch, batch_idx):
generate_output = self.generate(batch)
output=[]
for i,generate in enumerate(generate_output):
real=batch['correct_answer'][i]
cans=batch['cans_name'][i]
generate=generate.strip().split("\n")[0]
output.append((generate,real,cans))
return output
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
for generate,real,cans in outputs:
self.val_content["generate"].append(generate)
self.val_content["real"].append(real)
self.val_content["cans"].append(cans)
def on_validation_epoch_end(self):
df=DataFrame(self.val_content)
if not os.path.exists(self.hparams.output_dir):
os.makedirs(self.hparams.output_dir)
df.to_csv(op.join(self.hparams.output_dir, 'valid.csv'))
prediction_valid_ratio,hr=self.calculate_hr1(self.val_content)
metric=hr*prediction_valid_ratio
self.log('val_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True)
self.log('val_hr', hr, on_step=False, on_epoch=True, prog_bar=True)
self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True)
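        # note: metric = hr * valid_ratio = (correct/valid) * (valid/total) = correct/total,
        # i.e. HR@1 measured over *all* evaluated samples, not just the valid ones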
def on_test_epoch_start(self):
self.test_content={
"generate":[],
"real":[],
"cans":[],
}
@torch.no_grad()
def test_step(self, batch, batch_idx):
generate_output = self.generate(batch)
output=[]
for i,generate in enumerate(generate_output):
real=batch['correct_answer'][i]
cans=batch['cans_name'][i]
generate=generate.strip().split("\n")[0]
output.append((generate,real,cans))
return output
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
for generate,real,cans in outputs:
self.test_content["generate"].append(generate)
self.test_content["real"].append(real)
self.test_content["cans"].append(cans)
def on_test_epoch_end(self):
df=DataFrame(self.test_content)
if not os.path.exists(self.hparams.output_dir):
os.makedirs(self.hparams.output_dir)
df.to_csv(op.join(self.hparams.output_dir, 'test.csv'))
prediction_valid_ratio,hr=self.calculate_hr1(self.test_content)
metric=hr*prediction_valid_ratio
self.log('test_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True)
self.log('test_hr', hr, on_step=False, on_epoch=True, prog_bar=True)
self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True)
def configure_optimizers(self):
if hasattr(self.hparams, 'weight_decay'):
weight_decay = self.hparams.weight_decay
else:
weight_decay = 0
        # gating parameters are excluded from the optimizer: when router == 'share',
        # the shared router supplies the gate weights instead
        optimizer = torch.optim.SGD([
            {'params': self.projector.parameters(), 'lr': self.hparams.lr, 'weight_decay': weight_decay},
            {'params': self.router.parameters(), 'lr': self.hparams.lr * 0.3, 'weight_decay': weight_decay},
            {'params': [p for n, p in self.llama_model.named_parameters() if "gating" not in n], 'lr': self.hparams.lr},
            # {'params': [p for n, p in self.llama_model.named_parameters() if "gating" in n], 'lr': self.hparams.lr * 1, 'weight_decay': weight_decay}
            # {'params': self.llama_model.parameters(), 'lr': self.hparams.lr},
        ])
for i, param_group in enumerate(optimizer.param_groups):
print(f"Initial LR for group {i}: {param_group['lr']}")
total_params = sum(p.numel() for p in param_group['params'])
print(f"Parameter Group {i}: {total_params} parameters")
        if self.hparams.lr_scheduler is None:
            self.scheduler = None  # training_step checks this attribute, so define it explicitly
            return optimizer
else:
max_step = self.trainer.max_steps
warmup_steps = max_step // 20
print(f'max_step: {max_step}')
print(f'warmup_steps: {warmup_steps}')
if self.hparams.lr_scheduler == 'cosine':
init_lr_list = [
self.hparams.lr,
self.hparams.lr * 0.3,
self.hparams.lr * 1
]
min_lr_list = [
self.hparams.lr_decay_min_lr,
self.hparams.lr_decay_min_lr * 0.3,
self.hparams.lr_decay_min_lr * 1
]
warmup_start_lr_list = [
self.hparams.lr_warmup_start_lr,
self.hparams.lr_warmup_start_lr * 0.3,
self.hparams.lr_warmup_start_lr * 1
]
self.scheduler = LinearWarmupCosineLRScheduler(
optimizer=optimizer,
max_step=max_step,
min_lr_list=min_lr_list,
init_lr_list=init_lr_list,
warmup_steps=warmup_steps,
warmup_start_lr_list=warmup_start_lr_list
)
else:
self.scheduler = None
raise ValueError('Invalid lr_scheduler type!')
return optimizer
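    # The three entries of init_lr_list / min_lr_list / warmup_start_lr_list are
    # positional: index i must correspond to optimizer param group i
    # (projector, router, LLM), so keep the two orderings in sync.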
def configure_loss(self, out, labels=None):
loss = self.hparams.loss.lower()
if loss == 'lm':
return out.loss
else:
raise ValueError("Invalid Loss Type!")
def on_save_checkpoint(self, checkpoint):
if self.hparams.save == 'part':
checkpoint.pop('optimizer_states')
to_be_removed = []
for key, value in checkpoint['state_dict'].items():
try:
if not self.get_parameter(key).requires_grad:
to_be_removed.append(key)
except AttributeError:
to_be_removed.append(key)
for key in to_be_removed:
checkpoint['state_dict'].pop(key)
elif self.hparams.save == 'all':
pass
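    # load_llm wires up the backbone in one of four llm_tuning modes:
    #   'lora'         8-bit LLaMA with trainable LoRA adapters
    #   'freeze'       backbone fully frozen, no adapters
    #   'freeze_lora'  LoRA adapters attached but kept frozen
    #   'moelora'      MoE-LoRA adapters (num_moe experts, optional per-layer gating)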
def load_llm(self, llm_path):
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llm_path, use_fast=False)
        # eos is set as a fallback pad token, then a dedicated '[PAD]' token is
        # registered and takes precedence
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.llama_tokenizer.padding_side = "right"
self.llama_tokenizer.add_special_tokens({'additional_special_tokens': ['[PH]','[HistoryEmb]','[CansEmb]','[ItemEmb]']})
self.llama_model = LlamaForCausalLM.from_pretrained(llm_path, device_map="auto",load_in_8bit=True)
self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
if self.hparams.llm_tuning == 'lora':
if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
else:
if self.hparams.peft_config:
peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
else:
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=self.hparams.lora_r,
lora_alpha=self.hparams.lora_alpha,
lora_dropout=self.hparams.lora_dropout,
target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
self.peft_config = peft_config
self.llama_model = get_peft_model(self.llama_model, peft_config)
self.llama_model.print_trainable_parameters()
elif self.hparams.llm_tuning == 'freeze':
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
elif self.hparams.llm_tuning == 'freeze_lora':
if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
else:
if self.hparams.peft_config:
peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
else:
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=self.hparams.lora_r,
lora_alpha=self.hparams.lora_alpha,
lora_dropout=self.hparams.lora_dropout,
target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
self.peft_config = peft_config
self.llama_model = get_peft_model(self.llama_model, peft_config)
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
self.llama_model.print_trainable_parameters()
elif self.hparams.llm_tuning == 'moelora':
if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
else:
if self.hparams.peft_config:
peft_config = MoeLoraConfig(**MoeLoraConfig.from_json_file(self.hparams.peft_config))
else:
peft_config = MoeLoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=self.hparams.lora_r,
lora_alpha=self.hparams.lora_alpha,
lora_dropout=self.hparams.lora_dropout,
target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
num_moe=self.hparams.num_moe,
gating=self.hparams.gating)
self.peft_config = peft_config
self.llama_model = get_peft_model(self.llama_model, peft_config)
"""for name, param in self.llama_model.named_parameters():
if "gating" not in name:
param.requires_grad = False"""
self.llama_model.print_trainable_parameters()
else:
raise NotImplementedError()
print('Loading LLAMA Done')
def load_projector(self):
name = self.hparams.model_name
camel_name = ''.join([i.capitalize() for i in name.split('_')])
try:
Model = getattr(importlib.import_module(
'.'+name, package=__package__), camel_name)
        except (ImportError, AttributeError):
            raise ValueError(
                f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!')
self.projector = self.instancialize(Model, rec_size=self.hparams.rec_size, llm_size=self.llama_model.config.hidden_size)
def instancialize(self, Model, **other_args):
        class_args = inspect.getfullargspec(Model.__init__).args[1:]
inkeys = self.hparams.keys()
args1 = {}
for arg in class_args:
if arg in inkeys:
args1[arg] = getattr(self.hparams, arg)
args1.update(other_args)
        # args1 now holds the constructor args found in hparams, overridden by other_args
return Model(**args1)
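    # e.g. with hparams.model_name == 'mlp_projector' (an illustrative name), this
    # imports model/mlp_projector.py, resolves class MlpProjector, and instantiates
    # it with rec_size/llm_size plus any other constructor args found in hparams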
def load_rec_model(self, rec_model_path):
print('Loading Rec Model')
self.rec_model = torch.load(rec_model_path, map_location="cpu")
self.rec_model.eval()
for name, param in self.rec_model.named_parameters():
param.requires_grad = False
        print('Loading Rec Model Done')
def encode_items(self, seq):
if self.hparams.rec_embed=="SASRec":
item_rec_embs=self.rec_model.cacu_x(seq)
elif self.hparams.rec_embed in ['Caser','GRU']:
item_rec_embs=self.rec_model.item_embeddings(seq)
item_txt_embs=self.projector(item_rec_embs)
return item_txt_embs
    def encode_users(self, seq, len_seq):
        if self.hparams.rec_embed=="SASRec":
            user_rec_embs=self.rec_model.cacul_h(seq, len_seq)
        elif self.hparams.rec_embed in ['Caser','GRU']:
            user_rec_embs=self.rec_model.item_embeddings(seq)
        # the rec-space embedding is returned directly (it feeds the gating router,
        # not the LLM input); the projected text-space version was computed but unused
        return user_rec_embs
def embed_tokens(self, token_ids):
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
# batch -> embeds
def wrap_emb(self, batch):
input_embeds = self.llama_model.get_input_embeddings()(batch["tokens"].input_ids)
his_token_id=self.llama_tokenizer("[HistoryEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
cans_token_id=self.llama_tokenizer("[CansEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
item_token_id=self.llama_tokenizer("[ItemEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
his_item_embeds = self.encode_items(batch["seq"])
cans_item_embeds = self.encode_items(batch["cans"])
item_embeds=self.encode_items(batch["item_id"])
user_embeds=self.encode_users(batch["seq"], batch["len_seq"])
        # overwrite each placeholder token's embedding in place with the matching projected item embedding
        for i in range(len(batch["len_seq"])):
if (batch["tokens"].input_ids[i]==his_token_id).nonzero().shape[0]>0:
idx_tensor=(batch["tokens"].input_ids[i]==his_token_id).nonzero().view(-1)
for idx, item_emb in zip(idx_tensor,his_item_embeds[i,:batch["len_seq"][i].item()]):
input_embeds[i,idx]=item_emb
if (batch["tokens"].input_ids[i]==cans_token_id).nonzero().shape[0]>0:
idx_tensor=(batch["tokens"].input_ids[i]==cans_token_id).nonzero().view(-1)
for idx, item_emb in zip(idx_tensor,cans_item_embeds[i,:batch["len_cans"][i].item()]):
input_embeds[i,idx]=item_emb
if (batch["tokens"].input_ids[i]==item_token_id).nonzero().shape[0]>0:
idx=(batch["tokens"].input_ids[i]==item_token_id).nonzero().item()
input_embeds[i,idx]=item_embeds[i]
return input_embeds, user_embeds
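    # Prompt assumption behind wrap_emb: the template contains one '[HistoryEmb]'
    # placeholder per history item, one '[CansEmb]' per candidate, and a single
    # '[ItemEmb]'; each placeholder's embedding is replaced by the corresponding
    # projected rec-model embedding before the LLM sees the sequence.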
def calculate_hr1(self,eval_content):
correct_num=0
valid_num=0
total_num=0
for i,generate in enumerate(eval_content["generate"]):
real=eval_content["real"][i]
cans=eval_content["cans"][i]
total_num+=1
            generate=generate.strip().lower()
            real=real.strip().lower()
            cans=[item.strip().lower() for item in cans]
gen_cans_list=[]
for cans_item in cans:
if cans_item in generate:
gen_cans_list.append(cans_item)
if len(gen_cans_list)==1:
valid_num+=1
if real == gen_cans_list[0]:
correct_num+=1
valid_ratio=valid_num/total_num
if valid_num>0:
hr1=correct_num/valid_num
else:
hr1=0
return valid_ratio,hr1
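    # Worked example: cans = ['heat', 'speed'] and generate = 'i would pick speed'
    # matches exactly one candidate, so the prediction counts as valid, and it is a
    # hit iff real == 'speed'; outputs matching zero or several candidates are invalid.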