# coding=utf8
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
import json
import numpy as np
import torch
import pytorch_lightning as pl
import os


class AbstractCollator:
    """
    Collator for the summarization task: batches samples into padded
    encoder inputs, attention masks, and loss-masked labels.
    """

    def __init__(self, tokenizer, max_enc_length, max_dec_length, prompt):
        self.tokenizer = tokenizer
        self.max_enc_length = max_enc_length
        self.max_dec_length = max_dec_length
        self.prompt = prompt

    def __call__(self, samples):
        labels = []
        attn_mask = []
        source_inputs = []
        for sample in samples:
            encode_dict = self.tokenizer.encode_plus(
                self.prompt + sample['text'],
                max_length=self.max_enc_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            decode_dict = self.tokenizer.encode_plus(
                sample['summary'],
                max_length=self.max_dec_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            source_inputs.append(encode_dict['input_ids'].squeeze())
            labels.append(decode_dict['input_ids'].squeeze())
            attn_mask.append(encode_dict['attention_mask'].squeeze())
        source_inputs = torch.stack(source_inputs)
        labels = torch.stack(labels)
        attn_mask = torch.stack(attn_mask)
        # Mask everything after the eos token with -100 so padding does not
        # contribute to the loss. This assumes each summary contains exactly
        # one eos token (i.e. the pad token is distinct from eos).
        end_token_index = torch.where(labels == self.tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100
        return {
            "input_ids": source_inputs,
            "attention_mask": attn_mask,
            "labels": labels,
            "text": [sample['text'] for sample in samples],
            "summary": [sample['summary'] for sample in samples]
        }
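
# Usage sketch for AbstractCollator (an assumption-laden illustration, not
# part of the original training script): it is intended to be passed as
# `collate_fn` to a DataLoader over samples shaped like
# {'text': ..., 'summary': ...}. The batch size and prompt below are
# placeholders.
#
#   tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
#   collator = AbstractCollator(tokenizer, max_enc_length=128,
#                               max_dec_length=30, prompt='summarize:')
#   loader = DataLoader(dataset, batch_size=8, collate_fn=collator)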


class LCSTSDataset(Dataset):
    '''
    Dataset used for the LCSTS summarization task.
    '''

    def __init__(self, data_path, args):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path, use_fast=False)
        self.data = self.load_data(data_path)
        self.prompt = args.prompt
        self.max_enc_length = args.max_enc_length
        self.max_dec_length = args.max_dec_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        for line in tqdm(lines):
            obj = json.loads(line)
            samples.append({
                "text": obj['text'],
                "summary": obj['summary']
            })
        return samples

    def cal_data(self, data_path):
        """Compute encoder/decoder token-length statistics over a jsonl file."""
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        enc_sizes = []
        dec_sizes = []
        for line in tqdm(lines):
            obj = json.loads(line.strip())
            source = obj['text']
            target = obj['summary']
            enc_input_ids = self.tokenizer.encode(source)
            target = self.tokenizer.encode(target)
            enc_sizes.append(len(enc_input_ids))
            dec_sizes.append(len(target) - 1)
            samples.append({
                "enc_input_ids": enc_input_ids,
                "dec_input_ids": target[:-1],
                "label_ids": target[1:]
            })
        max_enc_len = max(enc_sizes)
        max_dec_len = max(dec_sizes)
        # Measured on LCSTS:
        # mean of len(enc_input_ids): 74.68041911345998
        # mean of len(dec_input_ids): 14.02265483791283
        # max of len(enc_input_ids): 132
        # max of len(dec_input_ids): 31
        print('mean of len(enc_input_ids):', np.mean(enc_sizes),
              'mean of len(dec_input_ids):', np.mean(dec_sizes),
              'max of len(enc_input_ids):', max_enc_len,
              'max of len(dec_input_ids):', max_dec_len)
        return samples

    def encode(self, item):
        encode_dict = self.tokenizer.encode_plus(
            self.prompt + item['text'],
            max_length=self.max_enc_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt')
        decode_dict = self.tokenizer.encode_plus(
            item['summary'],
            max_length=self.max_dec_length,
            padding='max_length',
            truncation=True)
        # decode_dict['input_ids'] is a plain Python list here (no
        # return_tensors), so convert to a tensor before masking pad
        # positions; comparing the list against an int would silently
        # leave the padding unmasked.
        labels = torch.tensor(decode_dict['input_ids'])
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": encode_dict['input_ids'].squeeze(),
            "attention_mask": encode_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "text": item['text'],
            "summary": item['summary']
        }


class LCSTSDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('LCSTSDataModel')
        parser.add_argument(
            '--data_dir',
            default='/cognitive_comp/ganruyi/data_datasets_LCSTS_LCSTS/',
            type=str)
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_data', default='train.jsonl', type=str)
        parser.add_argument('--valid_data', default='valid.jsonl', type=str)
        parser.add_argument('--test_data', default='test_public.jsonl', type=str)
        parser.add_argument('--train_batchsize', default=128, type=int)
        parser.add_argument('--valid_batchsize', default=128, type=int)
        parser.add_argument('--max_enc_length', default=128, type=int)
        parser.add_argument('--max_dec_length', default=30, type=int)
        parser.add_argument('--prompt', default='summarize:', type=str)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        if not args.do_eval_only:
            self.train_data = LCSTSDataset(os.path.join(
                args.data_dir, args.train_data), args)
            self.valid_data = LCSTSDataset(os.path.join(
                args.data_dir, args.valid_data), args)
        self.test_data = LCSTSDataset(os.path.join(
            args.data_dir, args.test_data), args)

    def train_dataloader(self):
        return DataLoader(self.train_data,
                          shuffle=True,
                          batch_size=self.train_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)
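

# A minimal wiring sketch, not part of the original module: it shows how the
# data module is constructed from argparse flags. `--pretrained_model_path`
# and `--do_eval_only` are consumed above but registered by the wider
# training script, so they are added here only to make the sketch
# self-contained; the default model name is a placeholder, and the jsonl
# files are assumed to exist under --data_dir.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained_model_path',
                        default='google/mt5-small', type=str)  # placeholder
    parser.add_argument('--do_eval_only', action='store_true')
    parser = LCSTSDataModel.add_data_specific_args(parser)
    args = parser.parse_args()

    data_module = LCSTSDataModel(args)
    # Pull one batch to sanity-check shapes: (batch, max_enc_length) inputs
    # and (batch, max_dec_length) labels with pad positions set to -100.
    batch = next(iter(data_module.predict_dataloader()))
    print(batch['input_ids'].shape, batch['labels'].shape)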