# coding=utf-8
import json
import os

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer


class AbstractCollator:
    """
    Collator for the summarization task: tokenizes raw text/summary pairs and
    builds padded batches for a seq2seq model.
    """

    def __init__(self, tokenizer, max_enc_length, max_dec_length, prompt):
        self.tokenizer = tokenizer
        self.max_enc_length = max_enc_length
        self.max_dec_length = max_dec_length
        self.prompt = prompt

    def __call__(self, samples):
        labels = []
        attn_mask = []
        source_inputs = []
        for sample in samples:
            # Encode the prompted source text and the target summary separately,
            # each padded/truncated to its own maximum length.
            encode_dict = self.tokenizer.encode_plus(
                self.prompt + sample['text'],
                max_length=self.max_enc_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            decode_dict = self.tokenizer.encode_plus(
                sample['summary'],
                max_length=self.max_dec_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            source_inputs.append(encode_dict['input_ids'].squeeze())
            labels.append(decode_dict['input_ids'].squeeze())
            attn_mask.append(encode_dict['attention_mask'].squeeze())
        source_inputs = torch.stack(source_inputs)
        labels = torch.stack(labels)
        attn_mask = torch.stack(attn_mask)
        # Mask everything after the first eos token in each label row with -100
        # so the padding positions do not contribute to the loss. Rows without
        # an eos token (fully truncated summaries) are left untouched.
        for idx in range(labels.size(0)):
            eos_positions = torch.where(labels[idx] == self.tokenizer.eos_token_id)[0]
            if eos_positions.numel() > 0:
                labels[idx][eos_positions[0] + 1:] = -100
        return {
            "input_ids": source_inputs,
            "attention_mask": attn_mask,
            "labels": labels,
            "text": [sample['text'] for sample in samples],
            "summary": [sample['summary'] for sample in samples]
        }
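
# Usage sketch (an illustration, not part of the original file): the collator is
# meant to be passed to a DataLoader as ``collate_fn`` over a list of raw
# {'text': ..., 'summary': ...} samples; the checkpoint path is a placeholder.
#
#   tokenizer = AutoTokenizer.from_pretrained('/path/to/checkpoint', use_fast=False)
#   collator = AbstractCollator(tokenizer, max_enc_length=128, max_dec_length=30,
#                               prompt='summarize:')
#   loader = DataLoader(raw_samples, batch_size=8, collate_fn=collator)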


class LCSTSDataset(Dataset):
    '''
    Dataset for the LCSTS summarization task.
    '''

    def __init__(self, data_path, args):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path, use_fast=False)
        self.data = self.load_data(data_path)
        self.prompt = args.prompt
        self.max_enc_length = args.max_enc_length
        self.max_dec_length = args.max_dec_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        # Each line of the file is a json object with 'text' and 'summary' fields.
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        for line in tqdm(lines):
            obj = json.loads(line)
            source = obj['text']
            target = obj['summary']
            samples.append({
                "text": source,
                "summary": target
            })
        return samples

    def cal_data(self, data_path):
        '''
        Compute token-length statistics over a jsonl file; useful for choosing
        max_enc_length / max_dec_length.
        '''
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        enc_sizes = []
        dec_sizes = []
        for line in tqdm(lines):
            obj = json.loads(line.strip())
            source = obj['text']
            target = obj['summary']
            enc_input_ids = self.tokenizer.encode(source)
            target = self.tokenizer.encode(target)
            enc_sizes.append(len(enc_input_ids))
            dec_sizes.append(len(target) - 1)
            samples.append({
                "enc_input_ids": enc_input_ids,
                "dec_input_ids": target[:-1],
                "label_ids": target[1:]
            })
        max_enc_len = max(enc_sizes)
        max_dec_len = max(dec_sizes)
        # Statistics previously observed on LCSTS:
        # mean of len(enc_input_ids): 74.68041911345998
        # mean of len(dec_input_ids): 14.02265483791283
        # max of len(enc_input_ids): 132
        # max of len(dec_input_ids): 31
        print('mean of len(enc_input_ids):', np.mean(enc_sizes),
              'mean of len(dec_input_ids):', np.mean(dec_sizes),
              'max of len(enc_input_ids):', max_enc_len,
              'max of len(dec_input_ids):', max_dec_len)
        return samples

    def encode(self, item):
        encode_dict = self.tokenizer.encode_plus(
            self.prompt + item['text'],
            max_length=self.max_enc_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt')
        decode_dict = self.tokenizer.encode_plus(
            item['summary'],
            max_length=self.max_dec_length,
            padding='max_length',
            truncation=True)
        # Replace padding positions in the labels with -100 so they are ignored
        # by the loss. Compare against the tensor, not the raw python list.
        labels = torch.tensor(decode_dict['input_ids'])
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": encode_dict['input_ids'].squeeze(),
            "attention_mask": encode_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "text": item['text'],
            "summary": item['summary']
        }


class LCSTSDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('LCSTSDataModel')
        parser.add_argument(
            '--data_dir', default='/cognitive_comp/ganruyi/data_datasets_LCSTS_LCSTS/', type=str)
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_data', default='train.jsonl', type=str)
        parser.add_argument('--valid_data', default='valid.jsonl', type=str)
        parser.add_argument('--test_data', default='test_public.jsonl', type=str)
        parser.add_argument('--train_batchsize', default=128, type=int)
        parser.add_argument('--valid_batchsize', default=128, type=int)
        parser.add_argument('--max_enc_length', default=128, type=int)
        parser.add_argument('--max_dec_length', default=30, type=int)
        parser.add_argument('--prompt', default='summarize:', type=str)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        # Train/valid splits are only needed for training; the test split is
        # always built so that predict_dataloader works in eval-only mode.
        if not args.do_eval_only:
            self.train_data = LCSTSDataset(os.path.join(
                args.data_dir, args.train_data), args)
            self.valid_data = LCSTSDataset(os.path.join(
                args.data_dir, args.valid_data), args)
        self.test_data = LCSTSDataset(os.path.join(
            args.data_dir, args.test_data), args)

    def train_dataloader(self):
        return DataLoader(self.train_data,
                          shuffle=True,
                          batch_size=self.train_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)
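

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline. It assumes
    # the LCSTS jsonl files exist under --data_dir and that --pretrained_model_path
    # points to a local seq2seq checkpoint. Both --pretrained_model_path and
    # --do_eval_only are normally defined by the outer training script; they are
    # added here only to make the sketch self-contained.
    import argparse

    total_parser = argparse.ArgumentParser()
    total_parser.add_argument('--pretrained_model_path', default='/path/to/checkpoint', type=str)
    total_parser.add_argument('--do_eval_only', action='store_true', default=False)
    total_parser = LCSTSDataModel.add_data_specific_args(total_parser)
    args = total_parser.parse_args()

    data_model = LCSTSDataModel(args)
    batch = next(iter(data_model.train_dataloader()))
    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)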