Spaces:
Running
on
Zero
Running
on
Zero
# -------------------------------------------------------- | |
# InternVL | |
# Copyright (c) 2024 OpenGVLab | |
# Licensed under The MIT License [see LICENSE for details] | |
# -------------------------------------------------------- | |
import datasets | |
import torch | |
import transformers | |
from torch.utils.data import DataLoader | |
from transformers.trainer import is_datasets_available, seed_worker | |
def get_train_dataloader(self) -> DataLoader: | |
""" | |
Returns the training [`~torch.utils.data.DataLoader`]. | |
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed | |
training if necessary) otherwise. | |
Subclass and override this method if you want to inject some custom behavior. | |
""" | |
if self.train_dataset is None: | |
raise ValueError('Trainer: training requires a train_dataset.') | |
train_dataset = self.train_dataset | |
data_collator = self.data_collator | |
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): | |
train_dataset = self._remove_unused_columns(train_dataset, description='training') | |
else: | |
data_collator = self._get_collator_with_removed_columns(data_collator, description='training') | |
dataloader_params = { | |
'batch_size': self._train_batch_size, | |
'collate_fn': data_collator, | |
'num_workers': self.args.dataloader_num_workers, | |
'pin_memory': self.args.dataloader_pin_memory, | |
'persistent_workers': self.args.dataloader_persistent_workers, | |
} | |
if not isinstance(train_dataset, torch.utils.data.IterableDataset): | |
dataloader_params['sampler'] = self._get_train_sampler() | |
dataloader_params['drop_last'] = self.args.dataloader_drop_last | |
dataloader_params['worker_init_fn'] = seed_worker | |
if self.args.use_packed_ds: | |
return DataLoader(train_dataset, **dataloader_params) | |
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) | |
def replace_train_dataloader(): | |
transformers.Trainer.get_train_dataloader = get_train_dataloader | |
# print('Replace train dataloader!!') | |