import importlib
import inspect
import os
import random

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader


class TrainCollater:
    """Collate raw samples into tokenized LLM batches.

    During training the collater anneals from plain-text placeholder
    prompts ([PH]) to embedding-token prompts ([HistoryEmb], [CansEmb],
    [ItemEmb]) as cur_step approaches max_step.
    """

    def __init__(self,
                 prompt_list=None,
                 llm_tokenizer=None,
                 train=False,
                 terminator="\n",
                 max_step=1):
        self.prompt_list = prompt_list
        self.llm_tokenizer = llm_tokenizer
        self.train = train
        self.terminator = terminator
        self.max_step = max_step
        self.cur_step = 1
    def __call__(self, batch):
        # Choose one instruction for the batch from the prompt list when it
        # is available; otherwise fall back to each sample's own instruction.
        if isinstance(self.prompt_list, list):
            instruction = random.choice(self.prompt_list)
            inputs_text = instruction if isinstance(instruction, list) else [instruction] * len(batch)
        else:
            inputs_text = [sample.get("instruction_input") for sample in batch]

        # Curriculum schedule: the probability of emitting embedding tokens
        # instead of [PH] placeholders grows linearly with cur_step.
        threshold = self.cur_step / self.max_step
        p = random.random()
        if p < threshold or not self.train:
            # Embedding branch: tag history/candidate/target titles with the
            # special tokens the model later swaps for item embeddings.
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [HistoryEmb]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [CansEmb]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)
                if '[TargetHere]' in input_text:
                    insert_prompt = sample['correct_answer'] + ' [ItemEmb]'
                    input_text = input_text.replace('[TargetHere]', insert_prompt)
                inputs_text[i] = input_text
            flag = False
        else:
            # Placeholder branch: every item title is tagged with [PH].
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [PH]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [PH]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)
                inputs_text[i] = input_text
            flag = True
        self.cur_step += 1
        targets_text = [sample['correct_answer'] for sample in batch]

        if self.train:
            # Tokenize prompt/target pairs into one sequence; the target
            # segment can be identified later via token_type_ids.
            targets_text = [target_text + self.terminator for target_text in targets_text]
            inputs_pair = [[p, t] for p, t in zip(inputs_text, targets_text)]
            batch_tokens = self.llm_tokenizer(
                inputs_pair,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_attention_mask=True,
                return_token_type_ids=True)
            new_batch = {
                "tokens": batch_tokens,
                "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                "flag": flag,
            }
        else:
            # Validation/test: tokenize the prompt only and pass the answers
            # and candidate names through for metric computation.
            batch_tokens = self.llm_tokenizer(
                inputs_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_token_type_ids=True)
            cans_name = [sample['cans_name'] for sample in batch]
            new_batch = {
                "tokens": batch_tokens,
                "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                "correct_answer": targets_text,
                "cans_name": cans_name,
            }

        return new_batch
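
# For reference, the dataset modules are expected to yield samples with at
# least the fields read above; the concrete values below are illustrative
# only (hypothetical titles and ids, not taken from any real dataset):
#
# sample = {
#     'seq': [3, 7, 12], 'len_seq': 3,                        # history item ids
#     'seq_name': ['Heat', 'Ronin', 'Collateral'],            # history titles
#     'cans': [5, 12, 40], 'len_cans': 3,                     # candidate ids
#     'cans_name': ['Casino', 'Collateral', 'Blade Runner'],  # candidate titles
#     'item_id': 12,                                          # target item id
#     'correct_answer': 'Collateral',                         # target title
#     'instruction_input': '...',  # optional per-sample prompt (fallback path)
# }
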
class DInterface(pl.LightningDataModule):
    """LightningDataModule wiring a dataset module to the LLM collater."""

    def __init__(self,
                 llm_tokenizer=None,
                 num_workers=8,
                 dataset='',
                 **kwargs):
        super().__init__()
        self.num_workers = num_workers
        self.llm_tokenizer = llm_tokenizer
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs['batch_size']
        self.max_epochs = kwargs['max_epochs']
        self.load_data_module()
        self.load_prompt(kwargs['prompt_path'])

        self.trainset = self.instancialize(stage='train')
        self.valset = self.instancialize(stage='val')
        self.testset = self.instancialize(stage='test')
        # Rough count of collate calls per dataloader worker over the whole
        # run; it drives the collater's annealing schedule.
        self.max_steps = self.max_epochs * (len(self.trainset) // self.batch_size) // self.num_workers
    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=True,
                                                   max_step=self.max_steps))

    def val_dataloader(self):
        return DataLoader(self.valset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))

    def test_dataloader(self):
        return DataLoader(self.testset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))
    def load_data_module(self):
        # Resolve a snake_case dataset name to its CamelCase class, e.g.
        # 'movielens_data' -> MovielensData in this package.
        name = self.dataset
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError) as err:
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err
    def instancialize(self, **other_args):
        # Instantiate the dataset class using only the constructor arguments
        # it declares, filled from self.kwargs and overridden by other_args.
        # getfullargspec replaces the removed inspect.getargspec.
        class_args = inspect.getfullargspec(self.data_module.__init__).args[1:]
        inkeys = self.kwargs.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.kwargs[arg]
        args1.update(other_args)
        return self.data_module(**args1)
    def load_prompt(self, prompt_path):
        if os.path.isfile(prompt_path):
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            self.prompt_list = [p.strip() for p in raw_prompts]
            print('Loaded {} training prompts'.format(len(self.prompt_list)))
            print('Prompt example:\n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []
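
# A minimal smoke-test sketch, not part of the pipeline. The checkpoint
# path, dataset name, and prompt file below are hypothetical placeholders;
# substitute ones that exist in your setup before running.
if __name__ == "__main__":
    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint")
    tokenizer.pad_token = tokenizer.eos_token  # LLaMA defines no pad token
    dm = DInterface(llm_tokenizer=tokenizer,
                    num_workers=1,
                    dataset='movielens_data',         # expects MovielensData in this package
                    batch_size=2,
                    max_epochs=1,
                    prompt_path='prompts/movie.txt')  # one prompt template per line
    batch = next(iter(dm.train_dataloader()))
    print(batch["tokens"].input_ids.shape)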