import inspect
import importlib
import os
import os.path as op
import pickle
import random

import numpy as np
import torch
import torch.optim.lr_scheduler as lrs
from torch import nn
from torch.nn import functional as F

import pytorch_lightning as pl
from pandas.core.frame import DataFrame
from transformers import LlamaForCausalLM, LlamaTokenizer

from optims import LinearWarmupCosineLRScheduler
from .peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    LoraConfig,
    TaskType,
    PeftModel,
    MoeLoraConfig,
    MoeLoraModel,
)
from .router.nlpr import (
    LambdaLayer,
    ResidualBlock,
    GateFunction,
    NLPRecommendationRouter,
    build_router,
)

class MInterface(pl.LightningModule):
    def __init__(self, **kargs):
        super().__init__()
        self.save_hyperparameters()
        self.load_llm(self.hparams.llm_path)

        # The shared router maps user embeddings to gate weights over the
        # MoE-LoRA experts.
        if self.hparams.router == 'share':
            self.router = build_router()

        self.load_rec_model(self.hparams.rec_model_path)
        self.load_projector()
        self.gradient_storage = {}
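    # Example construction (illustrative; the keyword names mirror the
    # self.hparams fields used throughout this class, and the values are
    # placeholders):
    #   model = MInterface(llm_path='path/to/llama', rec_model_path='rec.pt',
    #                      model_name='mlp_projector', rec_size=64,
    #                      router='share', llm_tuning='moelora', lr=1e-3)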
    def forward(self, batch):
        # Build LM labels: ignore pad tokens and prompt tokens
        # (token_type_id == 0) by setting them to -100.
        targets = batch["tokens"].input_ids.masked_fill(
            batch["tokens"].input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        # The mask must match the targets' shape, so no slicing here.
        targets = targets.masked_fill(batch["tokens"].token_type_ids == 0, -100)

        input_embeds, user_embeds = self.wrap_emb(batch)

        if self.hparams.router == 'share':
            gate_weights = self.router(user_embeds)
            outputs = self.llama_model(
                inputs_embeds=input_embeds,
                attention_mask=batch["tokens"].attention_mask,
                return_dict=True,
                labels=targets,
                use_cache=False,
                user_embeds=user_embeds,
                gate_weights=gate_weights,
            )
            return outputs

        outputs = self.llama_model(
            inputs_embeds=input_embeds,
            attention_mask=batch["tokens"].attention_mask,
            return_dict=True,
            labels=targets,
            use_cache=False,
            user_embeds=user_embeds,
        )
        return outputs
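    # Illustrative sketch of the masking above (toy values; pad_token_id = 2):
    #   input_ids      = [[   1,  887, 4966,    2,    2]]
    #   token_type_ids = [[   0,    0,    1,    1,    1]]   # 0 = prompt, 1 = answer
    #   targets        = [[-100, -100, 4966, -100, -100]]   # only answer tokens scored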
    def generate(self, batch, temperature=0.8, do_sample=False, num_beams=1,
                 max_gen_length=64, min_gen_length=1, repetition_penalty=1.0,
                 length_penalty=1.0, num_return_sequences=1):
        input_embeds, user_embeds = self.wrap_emb(batch)
        # Only the shared-router configuration passes explicit gate weights;
        # otherwise the model falls back to its internal gating.
        gate_kwargs = {}
        if self.hparams.router == 'share':
            gate_kwargs['gate_weights'] = self.router(user_embeds)
        generate_ids = self.llama_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=batch["tokens"].attention_mask,
            temperature=temperature,
            do_sample=do_sample,
            num_beams=num_beams,
            max_new_tokens=max_gen_length,
            min_new_tokens=min_gen_length,
            pad_token_id=self.llama_tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            user_embeds=user_embeds,
            **gate_kwargs,
        )
        output_text = self.llama_tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return [text.strip() for text in output_text]
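    # Example call (illustrative): greedy decoding over a prepared batch.
    #   texts = model.generate(batch, do_sample=False, num_beams=1,
    #                          max_gen_length=64)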
    def capture_and_store_gradients(self):
        # Accumulate gradients of the LoRA parameters only.
        for name, param in self.llama_model.named_parameters():
            if "lora" in name and param.grad is not None:
                if name not in self.gradient_storage:
                    self.gradient_storage[name] = []
                self.gradient_storage[name].append(param.grad.clone().detach())

        # Flush the buffer to disk every 10 global steps.
        if self.trainer.global_step % 10 == 0:
            self.save_gradients_to_file()

    def save_gradients_to_file(self):
        directory = self.hparams.capture_dir
        if not os.path.exists(directory):
            os.makedirs(directory)
        file_path = os.path.join(directory, f'gradients_step_{self.trainer.global_step}.pkl')
        with open(file_path, 'wb') as f:
            pickle.dump(self.gradient_storage, f)
        self.gradient_storage = {}
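    # Offline inspection sketch (illustrative; the path is a placeholder and
    # the layout matches save_gradients_to_file above):
    #   with open('capture_dir/gradients_step_10.pkl', 'rb') as f:
    #       grads = pickle.load(f)              # {param_name: [grad tensors]}
    #   for name, history in grads.items():
    #       print(name, [g.norm().item() for g in history])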
    def training_step(self, batch, batch_idx):
        # The custom scheduler is stepped manually rather than through
        # Lightning's scheduler machinery; it is None when no lr_scheduler
        # was configured.
        if self.scheduler is not None:
            self.scheduler.step(self.trainer.global_step, self.current_epoch, self.trainer.max_steps)
            self.log('lr', self.scheduler.optimizer.param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True)
        # batch["flag"] marks batches on which the projector stays frozen.
        if batch["flag"]:
            for name, param in self.projector.named_parameters():
                param.requires_grad = False
        else:
            for name, param in self.projector.named_parameters():
                param.requires_grad = True
        out = self(batch)
        loss = self.configure_loss(out)
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('global_step_num', self.trainer.global_step, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    def on_validation_epoch_start(self):
        self.val_content = {
            "generate": [],
            "real": [],
            "cans": [],
        }

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        generate_output = self.generate(batch)
        output = []
        for i, generate in enumerate(generate_output):
            real = batch['correct_answer'][i]
            cans = batch['cans_name'][i]
            # Keep only the first generated line as the model's answer.
            generate = generate.strip().split("\n")[0]
            output.append((generate, real, cans))
        return output

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        for generate, real, cans in outputs:
            self.val_content["generate"].append(generate)
            self.val_content["real"].append(real)
            self.val_content["cans"].append(cans)

    def on_validation_epoch_end(self):
        df = DataFrame(self.val_content)
        if not os.path.exists(self.hparams.output_dir):
            os.makedirs(self.hparams.output_dir)
        df.to_csv(op.join(self.hparams.output_dir, 'valid.csv'))
        prediction_valid_ratio, hr = self.calculate_hr1(self.val_content)
        # valid_ratio * HR@1 == correct predictions / all predictions.
        metric = hr * prediction_valid_ratio
        self.log('val_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_hr', hr, on_step=False, on_epoch=True, prog_bar=True)
        self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True)
    def on_test_epoch_start(self):
        self.test_content = {
            "generate": [],
            "real": [],
            "cans": [],
        }

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        generate_output = self.generate(batch)
        output = []
        for i, generate in enumerate(generate_output):
            real = batch['correct_answer'][i]
            cans = batch['cans_name'][i]
            generate = generate.strip().split("\n")[0]
            output.append((generate, real, cans))
        return output

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        for generate, real, cans in outputs:
            self.test_content["generate"].append(generate)
            self.test_content["real"].append(real)
            self.test_content["cans"].append(cans)

    def on_test_epoch_end(self):
        df = DataFrame(self.test_content)
        if not os.path.exists(self.hparams.output_dir):
            os.makedirs(self.hparams.output_dir)
        df.to_csv(op.join(self.hparams.output_dir, 'test.csv'))
        prediction_valid_ratio, hr = self.calculate_hr1(self.test_content)
        metric = hr * prediction_valid_ratio

        self.log('test_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_hr', hr, on_step=False, on_epoch=True, prog_bar=True)
        self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True)
    def configure_optimizers(self):
        weight_decay = getattr(self.hparams, 'weight_decay', 0)
        # Three parameter groups: projector, router (at 0.3x the base LR), and
        # the LLM's weights excluding the per-layer gating networks. This
        # grouping assumes the shared router is enabled (router == 'share').
        optimizer = torch.optim.SGD([
            {'params': self.projector.parameters(), 'lr': self.hparams.lr, 'weight_decay': weight_decay},
            {'params': self.router.parameters(), 'lr': self.hparams.lr * 0.3, 'weight_decay': weight_decay},
            {'params': [p for n, p in self.llama_model.named_parameters() if "gating" not in n], 'lr': self.hparams.lr},
        ])

        for i, param_group in enumerate(optimizer.param_groups):
            total_params = sum(p.numel() for p in param_group['params'])
            print(f"Initial LR for group {i}: {param_group['lr']}")
            print(f"Parameter Group {i}: {total_params} parameters")

        if self.hparams.lr_scheduler is None:
            self.scheduler = None
            return optimizer

        max_step = self.trainer.max_steps
        warmup_steps = max_step // 20
        print(f'max_step: {max_step}')
        print(f'warmup_steps: {warmup_steps}')
        if self.hparams.lr_scheduler == 'cosine':
            # Per-group schedules, in the same order as the optimizer groups
            # above (projector, router at 0.3x, LLM).
            init_lr_list = [
                self.hparams.lr,
                self.hparams.lr * 0.3,
                self.hparams.lr,
            ]
            min_lr_list = [
                self.hparams.lr_decay_min_lr,
                self.hparams.lr_decay_min_lr * 0.3,
                self.hparams.lr_decay_min_lr,
            ]
            warmup_start_lr_list = [
                self.hparams.lr_warmup_start_lr,
                self.hparams.lr_warmup_start_lr * 0.3,
                self.hparams.lr_warmup_start_lr,
            ]
            self.scheduler = LinearWarmupCosineLRScheduler(
                optimizer=optimizer,
                max_step=max_step,
                min_lr_list=min_lr_list,
                init_lr_list=init_lr_list,
                warmup_steps=warmup_steps,
                warmup_start_lr_list=warmup_start_lr_list,
            )
        else:
            self.scheduler = None
            raise ValueError('Invalid lr_scheduler type!')
        return optimizer
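    # Shape of the schedule above, per parameter group g (illustrative; the
    # exact formula lives in optims.LinearWarmupCosineLRScheduler, this is
    # the usual linear-warmup + cosine-decay form):
    #   step <  warmup_steps:
    #       lr = warmup_start_lr[g] + step / warmup_steps * (init_lr[g] - warmup_start_lr[g])
    #   step >= warmup_steps:
    #       lr = min_lr[g] + 0.5 * (init_lr[g] - min_lr[g]) * (1 + cos(pi * step / max_step))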
    def configure_loss(self, out, labels=None):
        loss = self.hparams.loss.lower()
        if loss == 'lm':
            return out.loss
        else:
            raise ValueError("Invalid Loss Type!")
    def on_save_checkpoint(self, checkpoint):
        if self.hparams.save == 'part':
            # Keep only trainable parameters in the checkpoint.
            checkpoint.pop('optimizer_states')
            to_be_removed = []
            for key, value in checkpoint['state_dict'].items():
                try:
                    if not self.get_parameter(key).requires_grad:
                        to_be_removed.append(key)
                except AttributeError:
                    to_be_removed.append(key)
            for key in to_be_removed:
                checkpoint['state_dict'].pop(key)
        elif self.hparams.save == 'all':
            pass
    def load_llm(self, llm_path):
        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llm_path, use_fast=False)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        # A dedicated '[PAD]' token overrides the eos-as-pad fallback above.
        self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llama_tokenizer.padding_side = "right"
        # Placeholder tokens that wrap_emb later replaces with rec embeddings.
        self.llama_tokenizer.add_special_tokens({'additional_special_tokens': ['[PH]', '[HistoryEmb]', '[CansEmb]', '[ItemEmb]']})
        self.llama_model = LlamaForCausalLM.from_pretrained(llm_path, device_map="auto", load_in_8bit=True)
        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        if self.hparams.llm_tuning == 'lora':
            if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
            else:
                if self.hparams.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
                else:
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                             inference_mode=False,
                                             r=self.hparams.lora_r,
                                             lora_alpha=self.hparams.lora_alpha,
                                             lora_dropout=self.hparams.lora_dropout,
                                             target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
                self.peft_config = peft_config
                self.llama_model = get_peft_model(self.llama_model, peft_config)
            self.llama_model.print_trainable_parameters()
        elif self.hparams.llm_tuning == 'freeze':
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        elif self.hparams.llm_tuning == 'freeze_lora':
            if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
            else:
                if self.hparams.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
                else:
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                             inference_mode=False,
                                             r=self.hparams.lora_r,
                                             lora_alpha=self.hparams.lora_alpha,
                                             lora_dropout=self.hparams.lora_dropout,
                                             target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
                self.peft_config = peft_config
                self.llama_model = get_peft_model(self.llama_model, peft_config)
            # Freeze everything, including the LoRA weights.
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
            self.llama_model.print_trainable_parameters()
        elif self.hparams.llm_tuning == 'moelora':
            if self.hparams.peft_dir:
                self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
            else:
                if self.hparams.peft_config:
                    peft_config = MoeLoraConfig(**MoeLoraConfig.from_json_file(self.hparams.peft_config))
                else:
                    peft_config = MoeLoraConfig(task_type=TaskType.CAUSAL_LM,
                                                inference_mode=False,
                                                r=self.hparams.lora_r,
                                                lora_alpha=self.hparams.lora_alpha,
                                                lora_dropout=self.hparams.lora_dropout,
                                                target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
                                                num_moe=self.hparams.num_moe,
                                                gating=self.hparams.gating)
                self.peft_config = peft_config
                self.llama_model = get_peft_model(self.llama_model, peft_config)

            # for name, param in self.llama_model.named_parameters():
            #     if "gating" not in name:
            #         param.requires_grad = False
            self.llama_model.print_trainable_parameters()
        else:
            raise NotImplementedError()

        print('Loading LLAMA Done')
    def load_projector(self):
        # Import the projector class from a sibling module, resolving
        # 'model_name' (snake_case) to its CamelCase class name.
        name = self.hparams.model_name
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            Model = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError):
            raise ValueError(
                f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!')
        self.projector = self.instancialize(Model, rec_size=self.hparams.rec_size, llm_size=self.llama_model.config.hidden_size)

    def instancialize(self, Model, **other_args):
        # Instantiate Model with whichever of its constructor arguments are
        # present in self.hparams; explicit other_args take precedence.
        class_args = inspect.getfullargspec(Model.__init__).args[1:]
        inkeys = self.hparams.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = getattr(self.hparams, arg)
        args1.update(other_args)

        return Model(**args1)
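    # Illustrative example of the hparams filtering above: given a
    # hypothetical projector class
    #   class MlpProjector(nn.Module):
    #       def __init__(self, rec_size, llm_size, dropout=0.1): ...
    # instancialize(MlpProjector, rec_size=64, llm_size=4096) fills `dropout`
    # from self.hparams if present, and otherwise leaves the default.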
    def load_rec_model(self, rec_model_path):
        print('Loading Rec Model')
        # The checkpoint stores the full pickled recommender module.
        self.rec_model = torch.load(rec_model_path, map_location="cpu")
        self.rec_model.eval()
        for name, param in self.rec_model.named_parameters():
            param.requires_grad = False
        print('Loading Rec Model Done')
    def encode_items(self, seq):
        if self.hparams.rec_embed == "SASRec":
            item_rec_embs = self.rec_model.cacu_x(seq)
        elif self.hparams.rec_embed in ['Caser', 'GRU']:
            item_rec_embs = self.rec_model.item_embeddings(seq)
        # Project rec-space item embeddings into the LLM's embedding space.
        item_txt_embs = self.projector(item_rec_embs)
        return item_txt_embs

    def encode_users(self, seq, len_seq):
        if self.hparams.rec_embed == "SASRec":
            user_rec_embs = self.rec_model.cacul_h(seq, len_seq)
        elif self.hparams.rec_embed in ['Caser', 'GRU']:
            user_rec_embs = self.rec_model.item_embeddings(seq)
        # Note: unlike encode_items, this returns the raw rec-space embedding;
        # the projected user_txt_embs is currently unused.
        user_txt_embs = self.projector(user_rec_embs)
        return user_rec_embs
    def embed_tokens(self, token_ids):
        embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds
    def wrap_emb(self, batch):
        input_embeds = self.llama_model.get_input_embeddings()(batch["tokens"].input_ids)

        # Each placeholder was added as a single special token, so .item()
        # yields its one token id.
        his_token_id = self.llama_tokenizer("[HistoryEmb]", return_tensors="pt", add_special_tokens=False).input_ids.item()
        cans_token_id = self.llama_tokenizer("[CansEmb]", return_tensors="pt", add_special_tokens=False).input_ids.item()
        item_token_id = self.llama_tokenizer("[ItemEmb]", return_tensors="pt", add_special_tokens=False).input_ids.item()

        his_item_embeds = self.encode_items(batch["seq"])
        cans_item_embeds = self.encode_items(batch["cans"])
        item_embeds = self.encode_items(batch["item_id"])
        user_embeds = self.encode_users(batch["seq"], batch["len_seq"])

        # Overwrite each placeholder position with the corresponding projected
        # item embedding (history, candidates, and target item).
        for i in range(len(batch["len_seq"])):
            if (batch["tokens"].input_ids[i] == his_token_id).nonzero().shape[0] > 0:
                idx_tensor = (batch["tokens"].input_ids[i] == his_token_id).nonzero().view(-1)
                for idx, item_emb in zip(idx_tensor, his_item_embeds[i, :batch["len_seq"][i].item()]):
                    input_embeds[i, idx] = item_emb
            if (batch["tokens"].input_ids[i] == cans_token_id).nonzero().shape[0] > 0:
                idx_tensor = (batch["tokens"].input_ids[i] == cans_token_id).nonzero().view(-1)
                for idx, item_emb in zip(idx_tensor, cans_item_embeds[i, :batch["len_cans"][i].item()]):
                    input_embeds[i, idx] = item_emb
            if (batch["tokens"].input_ids[i] == item_token_id).nonzero().shape[0] > 0:
                idx = (batch["tokens"].input_ids[i] == item_token_id).nonzero().item()
                input_embeds[i, idx] = item_embeds[i]

        return input_embeds, user_embeds
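    # Toy illustration of the substitution above: for a prompt such as
    #   "The user watched [HistoryEmb] [HistoryEmb]. Recommend from [CansEmb] ..."
    # every [HistoryEmb] position in input_embeds is overwritten with one
    # projected history-item embedding, in sequence order, and likewise for
    # [CansEmb] and [ItemEmb].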
    def calculate_hr1(self, eval_content):
        # A prediction is 'valid' iff exactly one candidate name appears in
        # the generated text; HR@1 is computed over valid predictions only.
        correct_num = 0
        valid_num = 0
        total_num = 0
        for i, generate in enumerate(eval_content["generate"]):
            real = eval_content["real"][i]
            cans = eval_content["cans"][i]
            total_num += 1
            generate = generate.strip().lower()
            real = real.strip().lower()
            cans = [item.strip().lower() for item in cans]
            gen_cans_list = []
            for cans_item in cans:
                if cans_item in generate:
                    gen_cans_list.append(cans_item)
            if len(gen_cans_list) == 1:
                valid_num += 1
                if real == gen_cans_list[0]:
                    correct_num += 1
        valid_ratio = valid_num / total_num
        if valid_num > 0:
            hr1 = correct_num / valid_num
        else:
            hr1 = 0
        return valid_ratio, hr1
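    # Worked example (toy data): with
    #   eval_content = {"generate": ["i recommend titanic", "avatar or heat"],
    #                   "real":     ["Titanic", "Heat"],
    #                   "cans":     [["Titanic", "Avatar"], ["Avatar", "Heat"]]}
    # the first row matches exactly one candidate ("titanic") and is correct;
    # the second matches two candidates and counts as invalid, so
    # calculate_hr1 returns valid_ratio = 1/2 and hr1 = 1/1 = 1.0.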