import inspect
import importlib
import os
import random

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader



class TrainCollater:
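    """Collate function that renders prompt templates into LLM-ready batches.

    During training a linear curriculum is applied: with probability
    cur_step / max_step the item titles are tagged with embedding markers
    ([HistoryEmb] / [CansEmb] / [ItemEmb]); otherwise plain [PH] placeholders
    are used. Evaluation always uses the embedding markers.
    """
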
    def __init__(self,
                 prompt_list=None,
                 llm_tokenizer=None,
                 train=False,
                 terminator="\n",
                 max_step=1):
        self.prompt_list = prompt_list
        self.llm_tokenizer = llm_tokenizer
        self.train = train
        self.terminator = terminator
        self.max_step = max_step
        self.cur_step = 1

    def __call__(self, batch):
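        """Render each sample's prompt, tokenize, and stack the tensor fields.

        Fills the [HistoryHere]/[CansHere]/[TargetHere] slots with item
        titles; `flag` records whether the batch used [PH] placeholders
        (True) or embedding markers (False).
        """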
        if isinstance(self.prompt_list, list):
            # A single instruction template is sampled and shared by the batch.
            instruction = random.choice(self.prompt_list)
            inputs_text = instruction if isinstance(instruction, list) else [instruction] * len(batch)
        else:
            # Otherwise fall back to per-sample instructions carried in the data.
            inputs_text = [sample.get("instruction_input") for sample in batch]

        # Linear curriculum: the chance of splicing in real item-embedding
        # markers grows with training progress; evaluation always uses them.
        threshold = self.cur_step / self.max_step
        p = random.random()
        if p < threshold or not self.train:
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [HistoryEmb]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [CansEmb]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)
                if '[TargetHere]' in input_text:
                    insert_prompt = sample['correct_answer'] + ' [ItemEmb]'
                    input_text = input_text.replace('[TargetHere]', insert_prompt)

                inputs_text[i] = input_text
            flag = False
        else:
            # Placeholder branch: [PH] tokens stand in for the item embeddings.
            for i, sample in enumerate(batch):
                input_text = inputs_text[i]
                if '[HistoryHere]' in input_text:
                    insert_prompt = ", ".join([seq_title + ' [PH]' for seq_title in sample['seq_name']])
                    input_text = input_text.replace('[HistoryHere]', insert_prompt)
                if '[CansHere]' in input_text:
                    insert_prompt = ", ".join([can_title + ' [PH]' for can_title in sample['cans_name']])
                    input_text = input_text.replace('[CansHere]', insert_prompt)

                inputs_text[i] = input_text
            flag = True
        self.cur_step += 1

        targets_text = [sample['correct_answer'] for sample in batch]

        if self.train:
            # Append the terminator so the model learns where an answer ends.
            targets_text = [target_text + self.terminator for target_text in targets_text]
            inputs_pair = [[p, t] for p, t in zip(inputs_text, targets_text)]

            batch_tokens = self.llm_tokenizer(
                inputs_pair,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_attention_mask=True,
                return_token_type_ids=True)
            new_batch = {"tokens": batch_tokens,
                         "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                         "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                         "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                         "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                         "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                         "flag": flag,
                         }
        else:
            batch_tokens = self.llm_tokenizer(
                inputs_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_token_type_ids=True)
            cans_name = [sample['cans_name'] for sample in batch]
            new_batch = {"tokens": batch_tokens,
                         "seq": torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0),
                         "cans": torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0),
                         "len_seq": torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0),
                         "len_cans": torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0),
                         "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0),
                         "correct_answer": targets_text,
                         "cans_name": cans_name,
                         }
            
        return new_batch

class DInterface(pl.LightningDataModule):
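    """LightningDataModule that resolves the dataset class by name and
    builds the train/val/test dataloaders with matching collaters."""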

    def __init__(self, 
                 llm_tokenizer=None,
                 num_workers=8,
                 dataset='',
                 **kwargs):
        super().__init__()
        self.num_workers = num_workers
        self.llm_tokenizer = llm_tokenizer
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs['batch_size']
        self.max_epochs = kwargs['max_epochs']
        self.load_data_module()
        self.load_prompt(kwargs['prompt_path'])

        self.trainset = self.instancialize(stage='train')
        self.valset = self.instancialize(stage='val')
        self.testset = self.instancialize(stage='test')
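        # Each DataLoader worker keeps its own TrainCollater copy and advances
        # cur_step once per batch it builds, so the curriculum horizon is the
        # total step count (batches per epoch * epochs) split across workers.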
        self.max_steps = self.max_epochs*(len(self.trainset)//self.batch_size)//self.num_workers

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size, 
                          num_workers=self.num_workers, 
                          shuffle=True,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list, llm_tokenizer=self.llm_tokenizer, train=True, max_step=self.max_steps))

    def val_dataloader(self):
        return DataLoader(self.valset, 
                          batch_size=self.batch_size, 
                          num_workers=self.num_workers, 
                          shuffle=False,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list, llm_tokenizer=self.llm_tokenizer, train=False))

    def test_dataloader(self):
        return DataLoader(self.testset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list, llm_tokenizer=self.llm_tokenizer, train=False))

    def load_data_module(self):
        # e.g. dataset='movielens_data' resolves to class MovielensData in the
        # sibling module data/movielens_data.py.
        name = self.dataset
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from e

    def instancialize(self, **other_args):
        """Instantiate the resolved dataset class, forwarding only the stored
        kwargs that its __init__ signature actually accepts."""
        class_args = inspect.getfullargspec(self.data_module.__init__).args[1:]
        inkeys = self.kwargs.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.kwargs[arg]
        args1.update(other_args)
        return self.data_module(**args1)
    
    def load_prompt(self, prompt_path):
        if os.path.isfile(prompt_path):
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            self.prompt_list = [p.strip() for p in raw_prompts]
            print('Loaded {} training prompts'.format(len(self.prompt_list)))
            print('Prompt example:\n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []
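

# Minimal usage sketch (an assumption, not part of this module's API): how
# DInterface might be wired into a PyTorch Lightning run. The dataset name,
# prompt path, tokenizer checkpoint, and `model` are illustrative placeholders.
#
#   from transformers import LlamaTokenizer
#   import pytorch_lightning as pl
#
#   tokenizer = LlamaTokenizer.from_pretrained('path/to/llama')
#   tokenizer.pad_token = tokenizer.unk_token  # padding='longest' needs a pad token
#   dm = DInterface(llm_tokenizer=tokenizer,
#                   num_workers=8,
#                   dataset='movielens_data',   # resolves to class MovielensData
#                   batch_size=16,
#                   max_epochs=5,
#                   prompt_path='prompts/train.txt')
#   trainer = pl.Trainer(max_epochs=5)
#   trainer.fit(model, datamodule=dm)  # `model` is a LightningModule (assumed)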