# iLoRA/data/data_interface.py
# Fork and bug fixes from https://github.com/AkaliKong/iLoRA
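"""Lightning data module and collater for iLoRA.

TrainCollater assembles prompt/target text for each batch and tokenizes it
with the injected LLM tokenizer; DInterface loads the dataset class and the
prompt templates, and exposes train/val/test dataloaders for PyTorch
Lightning.
"""
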
import importlib
import inspect
import os
import random

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader


class TrainCollater:
    def __init__(self,
                 prompt_list=None,
                 llm_tokenizer=None,
                 train=False,
                 terminator="\n",
                 max_step=1):
        self.prompt_list = prompt_list
        self.llm_tokenizer = llm_tokenizer
        self.train = train
        self.terminator = terminator
        # cur_step/max_step drives the placeholder-vs-embedding curriculum in
        # __call__; evaluation collaters leave max_step at its default of 1.
        self.max_step = max_step
        self.cur_step = 1
def __call__(self, batch):
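        """Build prompts for a batch, then tokenize prompt/target pairs.

        Prompt templates carry [HistoryHere], [CansHere] and [TargetHere]
        slots. With probability cur_step/max_step (growing over training)
        the slots are filled with item titles plus embedding markers
        ([HistoryEmb]/[CansEmb]/[ItemEmb]); otherwise plain [PH] placeholders
        are used. `flag` records which branch was taken; evaluation always
        uses the embedding markers.
        """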
        if isinstance(self.prompt_list, list):
            instruction = random.choice(self.prompt_list)
            inputs_text = instruction if isinstance(instruction, list) else [instruction] * len(batch)
        else:
            # Bug fix: `sample` was undefined on this branch upstream.
            # Fall back to each sample's own instruction text.
            inputs_text = [sample.get("instruction_input") for sample in batch]

        threshold = self.cur_step / self.max_step
        p = random.random()
        if p < threshold or not self.train:
            # Embedding branch: titles are followed by special marker tokens
            # that the model later replaces with projected item embeddings.
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [HistoryEmb]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [CansEmb]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)
                if '[TargetHere]' in input_text:
                    # Bug fix: upstream had a duplicated assignment and a
                    # needless join over a single-element list here.
                    insert_prompt = sample['correct_answer'] + ' [ItemEmb]'
                    input_text = input_text.replace('[TargetHere]', insert_prompt)
                inputs_text[i] = input_text
            flag = False
        else:
            # Placeholder branch: [PH] stands in for the embedding markers,
            # so the model trains on titles alone; flag=True records this.
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [PH]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [PH]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)
                inputs_text[i] = input_text
            flag = True
        self.cur_step += 1
        targets_text = [sample['correct_answer'] for sample in batch]

        if self.train:
            # Append the terminator so the model learns where an answer ends.
            targets_text = [target_text + self.terminator for target_text in targets_text]
            inputs_pair = [[prompt, target] for prompt, target in zip(inputs_text, targets_text)]
            batch_tokens = self.llm_tokenizer(
                inputs_pair,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_attention_mask=True,
                return_token_type_ids=True)
            new_batch = {
                "tokens": batch_tokens,
                "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                "flag": flag,
            }
        else:
            batch_tokens = self.llm_tokenizer(
                inputs_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_token_type_ids=True)
            cans_name = [sample['cans_name'] for sample in batch]
            new_batch = {
                "tokens": batch_tokens,
                "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                "correct_answer": targets_text,
                "cans_name": cans_name,
            }
        return new_batch
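

# Each dataset sample consumed by TrainCollater is assumed (from the stacking
# above) to be a dict with 'seq', 'cans', 'len_seq', 'len_cans', 'item_id',
# 'seq_name', 'cans_name' and 'correct_answer'. Illustrative call, i.e. what
# DataLoader does under the hood:
#
#     collater = TrainCollater(prompt_list=["Pick one of [CansHere] given "
#                                           "the history [HistoryHere]."],
#                              llm_tokenizer=tokenizer, train=True,
#                              max_step=1000)
#     batch = collater([dataset[0], dataset[1]])
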
class DInterface(pl.LightningDataModule):
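    """LightningDataModule wiring the iLoRA dataset, prompts and loaders.

    Expects `batch_size`, `max_epochs` and `prompt_path` in **kwargs; any
    remaining kwargs matching the dataset class's __init__ are forwarded by
    `instancialize`.
    """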

    def __init__(self,
                 llm_tokenizer=None,
                 num_workers=8,
                 dataset='',
                 **kwargs):
        super().__init__()
        self.num_workers = num_workers
        self.llm_tokenizer = llm_tokenizer
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs['batch_size']
        self.max_epochs = kwargs['max_epochs']
        self.load_data_module()
        self.load_prompt(kwargs['prompt_path'])

        self.trainset = self.instancialize(stage='train')
        self.valset = self.instancialize(stage='val')
        self.testset = self.instancialize(stage='test')
        # Each DataLoader worker holds its own TrainCollater copy, so the
        # per-worker step budget is the total number of training steps
        # divided by num_workers.
        self.max_steps = self.max_epochs * (len(self.trainset) // self.batch_size) // self.num_workers

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=True,
                                                   max_step=self.max_steps))

    def val_dataloader(self):
        return DataLoader(self.valset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))

    def test_dataloader(self):
        # Note: drop_last=True discards the final partial test batch.
        return DataLoader(self.testset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))
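
    # Naming convention for load_data_module below: a snake_case dataset name
    # maps to a CamelCase class in a sibling module; e.g. a hypothetical
    # dataset='movielens_data' resolves to data.movielens_data.MovielensData.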
    def load_data_module(self):
        name = self.dataset
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError):
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}')

    def instancialize(self, **other_args):
        # Pass along only the kwargs that the dataset class's __init__
        # accepts. Bug fix: inspect.getargspec was removed in Python 3.11;
        # use getfullargspec instead.
        class_args = inspect.getfullargspec(self.data_module.__init__).args[1:]
        inkeys = self.kwargs.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.kwargs[arg]
        args1.update(other_args)
        return self.data_module(**args1)

    def load_prompt(self, prompt_path):
        if os.path.isfile(prompt_path):
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            self.prompt_list = [p.strip() for p in raw_prompts]
            print('Loaded {} training prompts'.format(len(self.prompt_list)))
            print('Prompt example:\n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []
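

# Minimal usage sketch (assumptions, not part of the upstream repo): a real
# run needs a tokenizer and a prompt file; names below are illustrative only.
#
#     from transformers import LlamaTokenizer
#     tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")  # hypothetical path
#     dm = DInterface(llm_tokenizer=tokenizer,
#                     num_workers=8,
#                     dataset='movielens_data',      # hypothetical module name
#                     batch_size=8,
#                     max_epochs=5,
#                     prompt_path='./prompt.txt')    # hypothetical file
#     loader = dm.train_dataloader()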