|
import os

import torch
import pytorch_lightning as pl

from argparse import ArgumentParser

from pytorch_lightning import Trainer

import pytorch_lightning.callbacks as plc

from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

from model.model_interface import MInterface

from data.data_interface import DInterface

from recommender.A_SASRec_final_bce_llm import SASRec, Caser, GRU

from SASRecModules_ori import *

from transformers import LlamaForCausalLM, LlamaTokenizer
|
|
|
def load_callbacks(args):
    """Assemble the Lightning callbacks handed to the Trainer.

    Early stopping and checkpointing both track the logged ``metric`` value,
    treating larger as better. A per-step learning-rate monitor is appended
    when ``args.lr_scheduler`` is truthy.

    Args:
        args: parsed command-line namespace; reads ``ckpt_dir`` and
            ``lr_scheduler``.

    Returns:
        list of callbacks in registration order: early stopping,
        checkpointing, then (optionally) the learning-rate monitor.
    """
    early_stop = plc.EarlyStopping(
        monitor='metric',
        mode='max',
        patience=10,
        min_delta=0.001,
    )
    # save_top_k=-1 keeps a checkpoint for every epoch rather than only the
    # best k; save_last additionally keeps the most recent one.
    checkpoint = plc.ModelCheckpoint(
        monitor='metric',
        dirpath=args.ckpt_dir,
        filename='{epoch:02d}-{metric:.3f}',
        save_top_k=-1,
        mode='max',
        save_last=True,
        every_n_epochs=1,
    )
    registered = [early_stop, checkpoint]
    if args.lr_scheduler:
        registered.append(plc.LearningRateMonitor(logging_interval='step'))
    return registered
|
|
|
def main(args):
    """Build the model and data module, configure a Lightning Trainer, then train or test.

    Args:
        args: parsed command-line namespace.
            - Its full ``vars()`` are forwarded to ``MInterface`` and
              ``DInterface``.
            - This function additionally reads ``seed``, ``ckpt_path``,
              ``max_epochs``, ``accumulate_grad_batches``, ``batch_size``,
              ``log_dir``, ``ckpt_dir``, ``auto_lr_find`` and ``mode``.
            - It writes ``max_steps``, ``callbacks`` and ``logger`` back onto
              ``args`` so that ``Trainer.from_argparse_args`` picks them up.

    ``mode == 'train'`` runs ``trainer.fit``; any other mode runs ``trainer.test``.
    """
    pl.seed_everything(args.seed)

    model = MInterface(**vars(args))
    if args.ckpt_path:
        ckpt = torch.load(args.ckpt_path, map_location='cpu')
        # strict=False tolerates keys missing from, or unexpected in, the checkpoint.
        model.load_state_dict(ckpt['state_dict'], strict=False)
        print("load checkpoints from {}".format(args.ckpt_path))

    data_module = DInterface(llm_tokenizer=model.llama_tokenizer, **vars(args))

    # Each optimizer step consumes batch_size * accumulate_grad_batches samples.
    # Clamp to at least 1: for a tiny train set the floor division yields 0,
    # and max_steps=0 would tell the Trainer to stop after zero steps.
    samples_per_step = args.accumulate_grad_batches * args.batch_size
    args.max_steps = max(1, len(data_module.trainset) * args.max_epochs // samples_per_step)

    logger = TensorBoardLogger(save_dir='./log/', name=args.log_dir)
    args.callbacks = load_callbacks(args)
    args.logger = logger
    # exist_ok avoids the check-then-create race of an os.path.exists guard.
    os.makedirs(args.ckpt_dir, exist_ok=True)

    trainer = Trainer.from_argparse_args(args)

    if args.auto_lr_find:
        lr_finder = trainer.tuner.lr_find(model=model, datamodule=data_module, min_lr=1e-10, max_lr=1e-3, num_training=100)
        # lr_find is documented to return an optional result; skip plotting/applying when absent.
        if lr_finder is not None:
            fig = lr_finder.plot(suggest=True)
            fig_path = "lr_finder.png"
            fig.savefig(fig_path)
            print("Saving to {}".format(fig_path))
            suggested_lr = lr_finder.suggestion()
            # suggestion() returns None when it cannot pick a value from the
            # loss curve; keep the configured lr instead of overwriting it with None.
            if suggested_lr is not None:
                model.hparams.lr = suggested_lr

    if args.mode == 'train':
        trainer.fit(model=model, datamodule=data_module)
    else:
        trainer.test(model=model, datamodule=data_module)
|
|
|
|
|
if __name__ == '__main__': |
|
torch.multiprocessing.set_start_method('spawn') |
|
parser = ArgumentParser() |
|
|
|
parser.add_argument('--accelerator', default='gpu', type=str) |
|
parser.add_argument('--devices', default=-1, type=list) |
|
parser.add_argument('--precision', default=16, type=int) |
|
parser.add_argument('--amp_backend', default="native", type=str) |
|
|
|
parser.add_argument('--batch_size', default=4, type=int) |
|
parser.add_argument('--num_workers', default=8, type=int) |
|
parser.add_argument('--seed', default=1234, type=int) |
|
parser.add_argument('--lr', default=1e-4, type=float) |
|
parser.add_argument('--accumulate_grad_batches', default=32, type=int) |
|
parser.add_argument('--check_val_every_n_epoch', default=1, type=int) |
|
|
|
parser.add_argument('--lr_scheduler', default='cosine', choices=['cosine'], type=str) |
|
parser.add_argument('--lr_decay_min_lr', default=1e-6, type=float) |
|
parser.add_argument('--lr_warmup_start_lr', default=1e-6, type=float) |
|
|
|
parser.add_argument('--load_best', action='store_true') |
|
parser.add_argument('--load_dir', default=None, type=str) |
|
parser.add_argument('--load_ver', default=None, type=str) |
|
parser.add_argument('--load_v_num', default=None, type=int) |
|
|
|
parser.add_argument('--dataset', default='steam_data', type=str) |
|
parser.add_argument('--data_dir', default='LLaRA_MOE/data/ref/steam', type=str) |
|
parser.add_argument('--model_name', default='mlp_projector', type=str) |
|
parser.add_argument('--loss', default='lm', type=str) |
|
parser.add_argument('--weight_decay', default=1e-5, type=float) |
|
parser.add_argument('--no_augment', action='store_true') |
|
parser.add_argument('--ckpt_dir', default='LLaRA_MOE/checkpoints/steam/', type=str) |
|
parser.add_argument('--log_dir', default='steam_logs', type=str) |
|
|
|
parser.add_argument('--rec_size', default=64, type=int) |
|
parser.add_argument('--padding_item_id', default=3581, type=int) |
|
parser.add_argument('--llm_path', default='meta-llama/Llama-2-7b-hf', type=str) |
|
parser.add_argument('--rec_model_path', default='LLaRA_MOE/rec_model/SASRec_steam.pt', type=str) |
|
parser.add_argument('--prompt_path', default='LLaRA_MOE/prompt/game.txt', type=str) |
|
parser.add_argument('--output_dir', default='LLaRA_MOE/output/steam_moe/', type=str) |
|
parser.add_argument('--ckpt_path', type=str) |
|
parser.add_argument('--rec_embed', default="SASRec", choices=['SASRec', 'Caser','GRU'], type=str) |
|
|
|
parser.add_argument('--aug_prob', default=0.5, type=float) |
|
parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str) |
|
parser.add_argument('--auto_lr_find', default=False, action='store_true') |
|
parser.add_argument('--metric', default='hr', choices=['hr'], type=str) |
|
parser.add_argument('--max_epochs', default=5, type=int) |
|
parser.add_argument('--save', default='part', choices=['part', 'all'], type=str) |
|
parser.add_argument('--cans_num', default=20, type=int) |
|
|
|
|
|
parser.add_argument('--llm_tuning', default='moelora', choices=['lora', 'freeze','freeze_lora', 'moelora'], type=str) |
|
parser.add_argument('--peft_dir', default=None, type=str) |
|
parser.add_argument('--peft_config', default=None, type=str) |
|
parser.add_argument('--lora_r', default=8, type=float) |
|
parser.add_argument('--lora_alpha', default=32, type=float) |
|
parser.add_argument('--lora_dropout', default=0.1, type=float) |
|
parser.add_argument('--num_moe', default=4, type=int) |
|
parser.add_argument('--gating', default='Dense', type=str) |
|
|
|
parser.add_argument('--local_rank', default=3, type=int) |
|
|
|
parser.add_argument('--if_rand', default=False, type=bool) |
|
|
|
parser.add_argument('--router', default='unshare', choices=['share', 'unshare'], type=str) |
|
|
|
args = parser.parse_args() |
|
|
|
if 'movielens' in args.data_dir: |
|
args.padding_item_id = 1682 |
|
elif 'steam' in args.data_dir: |
|
args.padding_item_id = 3581 |
|
elif 'lastfm' in args.data_dir: |
|
args.padding_item_id = 4606 |
|
|
|
main(args) |
|
|