from typing import Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler


def get_consume_samples(data_model: LightningDataModule) -> int:
    if hasattr(data_model.trainer.lightning_module, 'consumed_samples'):
        consumed_samples = data_model.trainer.lightning_module.consumed_samples
        print('get consumed samples from model: {}'.format(consumed_samples))
    else:
        world_size = data_model.trainer.world_size
        consumed_samples = max(0, data_model.trainer.global_step - 1) * \
            data_model.hparams.train_batchsize * world_size * data_model.trainer.accumulate_grad_batches
        print('calculate consumed samples: {}'.format(consumed_samples))
    return consumed_samples
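
# Worked example for the fallback branch above (illustrative numbers, not taken from the
# repo): resuming at global_step=101 with --train_batchsize 16, world_size=8 and
# accumulate_grad_batches=2 gives
#     consumed_samples = (101 - 1) * 16 * 8 * 2 = 25600
# i.e. how many samples have already been consumed, which the samplers below use to skip
# ahead instead of replaying data.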


class UniversalDataModule(LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('Universal DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--dataloader_workers', default=2, type=int)
        parser.add_argument('--train_batchsize', default=16, type=int)
        parser.add_argument('--val_batchsize', default=16, type=int)
        parser.add_argument('--test_batchsize', default=16, type=int)
        parser.add_argument('--datasets_name', type=str, default=None)
        parser.add_argument('--train_datasets_field', type=str, default='train')
        parser.add_argument('--val_datasets_field', type=str, default='validation')
        parser.add_argument('--test_datasets_field', type=str, default='test')
        parser.add_argument('--train_file', type=str, default=None)
        parser.add_argument('--val_file', type=str, default=None)
        parser.add_argument('--test_file', type=str, default=None)
        parser.add_argument('--raw_file_type', type=str, default='json')
        parser.add_argument('--sampler_type', type=str,
                            choices=['single', 'random'],
                            default='random')
        return parent_args

    def __init__(
        self,
        tokenizer,
        collate_fn,
        args,
        datasets=None,
        **kwargs,
    ):
        super().__init__()
        # If no dataset name is passed in, the internal `datasets` can be replaced from
        # outside the object with whatever the model needs.
        if datasets is not None:
            self.datasets = datasets
        elif args.datasets_name is not None:
            from fengshen.data.fs_datasets import load_dataset
            print('---------begin to load datasets {}'.format(args.datasets_name))
            self.datasets = load_dataset(
                args.datasets_name, num_proc=args.num_workers)
            print('---------finished loading datasets {}'.format(args.datasets_name))
        else:
            print('---------begin to load datasets from local file')
            from datasets import load_dataset
            self.datasets = load_dataset(args.raw_file_type,
                                         data_files={
                                             args.train_datasets_field: args.train_file,
                                             args.val_datasets_field: args.val_file,
                                             args.test_datasets_field: args.test_file})
            print('---------finished loading datasets from local file')
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.save_hyperparameters(args)
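
    # Illustrative note (an inference from the loading code above, not repo documentation):
    # whichever branch ran, self.datasets is expected to behave like a mapping from split
    # name to dataset, e.g. {'train': ..., 'validation': ..., 'test': ...}, keyed by the
    # --*_datasets_field arguments, so the dataloaders below index it by split.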

    def get_custom_sampler(self, ds):
        from .universal_sampler import PretrainingRandomSampler
        from .universal_sampler import PretrainingSampler
        world_size = self.trainer.world_size
        consumed_samples = get_consume_samples(self)
        # use the sampler type selected by the user (default: 'random')
        if self.hparams.sampler_type == 'random':
            return PretrainingRandomSampler(
                total_samples=len(ds),
                # consumed_samples is calculated from global steps
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
                epoch=self.trainer.current_epoch,
            )
        elif self.hparams.sampler_type == 'single':
            return PretrainingSampler(
                total_samples=len(ds),
                # consumed_samples is calculated from global steps
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
            )
        else:
            raise Exception('Unknown sampler type: {}'.format(self.hparams.sampler_type))
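
    # Note on get_custom_sampler (a reading of the code, not repo documentation): both
    # samplers receive micro_batch_size plus the data-parallel rank/size, and
    # train_dataloader passes the result to DataLoader as `batch_sampler`, so they are
    # expected to yield whole per-rank batches of indices rather than single indices.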

    def setup(self, stage: Optional[str] = None) -> None:
        return

    def train_dataloader(self):
        ds = self.datasets[self.hparams.train_datasets_field]
        collate_fn = self.collate_fn
        if hasattr(ds, 'collate_fn'):
            collate_fn = ds.collate_fn
        if self.hparams.replace_sampler_ddp is False:
            return DataLoader(
                ds,
                batch_sampler=self.get_custom_sampler(ds),
                num_workers=self.hparams.dataloader_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )
        return DataLoader(
            ds,
            batch_size=self.hparams.train_batchsize,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def val_dataloader(self):
        ds = self.datasets[self.hparams.val_datasets_field]
        collate_fn = self.collate_fn
        if hasattr(ds, 'collate_fn'):
            collate_fn = ds.collate_fn
        return DataLoader(
            ds,
            batch_size=self.hparams.val_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(ds, shuffle=False),
            pin_memory=True,
        )
        # return DataLoader(
        #     ds, shuffle=False, batch_size=self.hparams.val_batchsize, pin_memory=False, collate_fn=collate_fn,
        # )

    def test_dataloader(self):
        ds = self.datasets[self.hparams.test_datasets_field]
        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater
        return DataLoader(
            ds,
            batch_size=self.hparams.test_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(ds, shuffle=False),
            pin_memory=True,
        )
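

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original training scripts): it only exercises
    # the argparse wiring above and prints the parsed hyperparameters. Building the module
    # itself needs a real tokenizer, collate_fn and data, so that part stays commented out;
    # `my_tokenizer`, `my_collate_fn` and `model` below are hypothetical placeholders.
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser = UniversalDataModule.add_data_specific_args(arg_parser)
    example_args = arg_parser.parse_args(['--train_batchsize', '32'])
    print(vars(example_args))

    # dm = UniversalDataModule(tokenizer=my_tokenizer, collate_fn=my_collate_fn,
    #                          args=example_args)
    # trainer = pytorch_lightning.Trainer(...)
    # trainer.fit(model, datamodule=dm)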