# Spaces: Running on Zero
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import bisect
import copy
import logging
from collections import defaultdict
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
from transformers.trainer_pt_utils import LabelSmoother

from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN
# Label value that marks positions carrying no supervision (HF LabelSmoother convention).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Module logger, pinned to INFO independently of the root logger's level.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in and a process group is initialized."""
    # `and` short-circuits, so is_initialized() is never queried on builds without distributed.
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's rank in the default group, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
class PackedDataset(IterableDataset):
    """Iterable dataset that samples from several datasets and packs samples together.

    Samples are drawn according to ``dataset_weight`` and greedily concatenated into
    buffers holding at most ``num_images_expected`` images and ``max_packed_tokens``
    tokens. A ``data_index`` tensor records, for every token, which packed sub-sample it
    belongs to, and every yielded buffer carries this worker's resume state (see
    ``postprocess_buffer``).

    Each wrapped dataset must be iterable and expose ``ds_name``, ``dataset_type``,
    ``max_num_images``, ``max_tokens`` and ``load_state_dict``; its ``worker_id``,
    ``worker_state_key`` and ``num_workers`` attributes are assigned in ``__iter__``.
    Samples are dicts; ``split_buffer`` only accepts the keys ``input_ids``, ``labels``,
    ``attention_mask``, ``position_ids``, ``data_index``, ``type_ids``, ``pixel_values``
    and ``image_flags`` (``meta_info`` is popped in ``next_data``).
    """

    def __init__(
        self,
        tokenizer,
        data_rank,
        data_world_size,
        datasets: List,
        dataset_weight: Optional[List[float]] = None,
        num_images_expected: int = 6,
        max_packed_tokens: int = 32768,
        max_buffer_size: int = 100,
        log_freq: int = 1000000,
        strict_mode: bool = False,
        debug_mode: bool = False,
        replacement: bool = True,
        allow_overflow: bool = True,
        allow_empty_data: bool = False,
        allow_deduplicated_ds_name: bool = False,
    ):
        """Configure packing.

        Args:
            tokenizer: must know ``IMG_START_TOKEN``, ``IMG_CONTEXT_TOKEN`` and
                ``IMG_END_TOKEN`` (none of them may map to ``unk_token_id``).
            data_rank / data_world_size: position of this process among data-parallel
                ranks; combined with the DataLoader worker id to seed sampling.
            datasets: datasets to sample from.
            dataset_weight: relative sampling weights (normalized here); uniform if None.
            num_images_expected: target number of images per packed buffer; datasets
                with a larger ``max_num_images`` are clamped to it.
            max_packed_tokens: token budget per packed buffer; datasets with a larger
                ``max_tokens`` are clamped to it.
            max_buffer_size: maximum number of partially filled buffers kept pending.
            log_freq: log progress every ``log_freq`` iterations.
            strict_mode: pad under-filled buffers with all-zero images.
            debug_mode: repeat the first full buffer forever.
            replacement: restart exhausted datasets instead of dropping them.
            allow_overflow: once the pending list is at least half full, merge a sample
                into a buffer even if the token budget is exceeded (it is split later).
            allow_empty_data: only triggers a warning here.
            allow_deduplicated_ds_name: skip the check that ``ds_name`` values are unique.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data_rank = data_rank
        self.data_world_size = data_world_size
        self.datasets = datasets
        self.num_images_expected = num_images_expected
        self.max_buffer_size = max_buffer_size
        self.log_freq = log_freq
        self.strict_mode = strict_mode
        self.debug_mode = debug_mode
        self.replacement = replacement
        self.allow_overflow = allow_overflow
        self.allow_empty_data = allow_empty_data
        self.max_packed_tokens = max_packed_tokens

        self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
        # The image special tokens must be real vocabulary entries.
        assert self.img_start_token_id != self.tokenizer.unk_token_id
        assert self.img_token_id != self.tokenizer.unk_token_id
        assert self.img_end_token_id != self.tokenizer.unk_token_id

        if dataset_weight is None:
            dataset_weight = [1] * len(datasets)
        self.dataset_type = [d.dataset_type for d in self.datasets]

        # The ``*_orig`` lists keep the full configuration: ``next_data`` removes exhausted
        # datasets from the working copies, and ``__iter__`` restores them from here.
        self.datasets_orig = datasets
        total_weight = sum(dataset_weight)
        self.dataset_weight_orig = [w / total_weight for w in dataset_weight]
        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)

        # lazy init (per worker, in ``__iter__``)
        self.worker_id = None
        self.worker_state_key = None
        self.dataset_iter_list = None
        self._state_dict = {
            'sample_info': {d.ds_name: 0 for d in self.datasets},
        }
        self.worker_custom_infos = None

        ds_name_list = [d.ds_name for d in self.datasets]
        if not allow_deduplicated_ds_name:
            assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}'

        for ds in self.datasets:
            if ds.max_num_images > self.num_images_expected:
                logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
                ds.max_num_images = num_images_expected
            if ds.max_tokens > self.max_packed_tokens:
                logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
                ds.max_tokens = self.max_packed_tokens
            self._state_dict[ds.ds_name] = {}

        if get_rank() == 0:
            logger.info(
                f'Loaded dataset to pack: {ds_name_list}, '
                f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
                f'{self.replacement=}, {self.allow_overflow=}',
            )
            temp = '\n'.join(
                f'{ds.ds_name:<25}: {ds_w*100:.2f}%'
                for ds, ds_w in zip(self.datasets, self.dataset_weight)
            )
            logger.info(
                f'Sampling prob for each dataset:\n{temp}'
            )

        if self.allow_empty_data:
            logger.warning('allow_empty_data is enabled, note that empty data may be generated!')

    def load_state_dict(self, state_dict, custom_infos=None):
        """Restore dataset progress (and optionally per-worker buffers) from a checkpoint.

        Args:
            state_dict: state merged (shallowly) into ``self._state_dict`` — presumably a
                ``worker_state_dict`` attached to a previously yielded buffer.
            custom_infos: optional mapping from worker-state key to that worker's
                ``custom_infos``; consumed once by the next ``__iter__``.
        """
        self.worker_custom_infos = custom_infos

        self._state_dict.update(state_dict)
        for ds in self.datasets:
            if ds.ds_name in self._state_dict:
                ds.load_state_dict(self._state_dict[ds.ds_name])
                logger.info(f'{ds.ds_name=} is resumed.')
            else:
                logger.warning(f'{ds.ds_name=} is not resumed.')

    def _should_log(self):
        """Return True only for DataLoader worker 0 of global rank 0."""
        worker_info = get_worker_info()
        local_worker_id = 0 if worker_info is None else worker_info.id
        num_local_workers = 1 if worker_info is None else worker_info.num_workers
        return num_local_workers * get_rank() + local_worker_id == 0

    def _drop_dataset(self, dataset_idx):
        """Remove a dataset, its iterator and its weight from the active sampling pool.

        Raises:
            StopIteration: if no dataset is left afterwards.
        """
        self.datasets.pop(dataset_idx)
        self.dataset_iter_list.pop(dataset_idx)
        self.dataset_weight.pop(dataset_idx)
        if len(self.datasets) == 0:
            raise StopIteration

    def next_data(self, current_dataset_idx):
        """Return the next sample, starting from dataset ``current_dataset_idx``.

        An exhausted dataset is restarted when ``replacement`` is set. If it is not set, or
        the restart yields nothing, the dataset is removed from the pool and another one is
        picked uniformly at random. Any other error while fetching is logged and a random
        dataset is tried instead. The returned sample gets a ``type_ids`` tensor (the
        dataset's current position in the pool); its ``meta_info`` is popped and merged
        into this worker's progress entry for that dataset.

        Raises:
            StopIteration: once every dataset has been removed.
        """
        while True:
            try:
                current_sample = next(self.dataset_iter_list[current_dataset_idx])
                break  # Exit loop if successful
            except StopIteration:
                if self.replacement:
                    try:
                        self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
                        current_sample = next(self.dataset_iter_list[current_dataset_idx])
                        break
                    except Exception:
                        # The restarted dataset yields nothing (or fails right away):
                        # fall through and drop it, as in the non-replacement case.
                        pass
                self._drop_dataset(current_dataset_idx)
                current_dataset_idx = np.random.choice(len(self.datasets))
            except Exception:
                # Keep the traceback so corrupted samples / loader bugs can be diagnosed.
                logger.exception('Unexpected error!')
                if len(self.datasets) == 0:
                    raise StopIteration
                current_dataset_idx = np.random.choice(len(self.datasets))

        current_ds_name = self.datasets[current_dataset_idx].ds_name
        current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx

        worker_state = self._state_dict[current_ds_name].setdefault(self.worker_state_key, {})
        worker_state.update(current_sample.pop('meta_info', {}))
        self._state_dict['sample_info'][current_ds_name] += 1
        return current_sample

    def find_buffer(self, buffer_list, new_sample):
        """Pop and return the buffer ``new_sample`` should be merged into, or None.

        The first buffer that can absorb the sample within both the image and the token
        budget wins. Otherwise, if ``allow_overflow`` is set and at least half of
        ``max_buffer_size`` buffers are pending, the last buffer that fits the image
        budget (but not the token budget) is returned; ``split_buffer`` later cuts the
        merged result back to size.
        """
        # NOTE: use `bisect` to search might be faster
        num_images_current = new_sample['pixel_values'].size(0)
        num_tokens_current = new_sample['input_ids'].size(0)
        can_overflow = self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2

        find_idx = None
        for buffer_idx, buffer in enumerate(buffer_list):
            if buffer['pixel_values'].size(0) + num_images_current > self.num_images_expected:
                continue
            if buffer['input_ids'].size(0) + num_tokens_current <= self.max_packed_tokens:
                find_idx = buffer_idx
                break
            if can_overflow:
                # Remember it, but keep looking for a buffer that fits the token budget.
                find_idx = buffer_idx

        return None if find_idx is None else buffer_list.pop(find_idx)

    def update_buffer(self, buffer, new_sample):
        """Append ``new_sample`` to ``buffer``, or start a new buffer when ``buffer`` is None.

        The appended tokens get ``data_index`` = (last ``data_index`` of ``buffer``) + 1, so
        the value is constant within a sub-sample and increases by one per sub-sample.
        Mutates ``new_sample`` (adds ``data_index``) and ``buffer``.
        """
        if buffer is None:
            new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
            return new_sample

        new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()

        assert buffer.keys() == new_sample.keys()
        for k in buffer:
            buffer[k] = torch.cat([buffer[k], new_sample[k]])
        return buffer

    @staticmethod
    def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
        """Return (as a 0-d bool tensor) whether the share of supervised labels exceeds the ratio."""
        num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
        num_tokens = sample_to_check['labels'].numel()
        return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio

    @staticmethod
    def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
        """Split an over-long packed buffer into pieces of at most ``max_tokens`` tokens.

        Cuts never fall inside an image span (start token .. end token). If the cut position
        lands in one, the cut moves back to that image's start token and the image begins
        the next piece. If it lands in text, the text between the cut and the next image
        start is discarded. Pieces without images or failing ``check_valid`` are dropped.

        Returns:
            A list of buffers; ``[buffer]`` unchanged if it already fits.
        """
        if buffer['input_ids'].size(0) <= max_tokens:
            return [buffer]

        def _image_is_splitted(input_ids, cut_idx):
            # True if the token at ``cut_idx`` belongs to an image span.
            return input_ids[cut_idx].item() in (img_start_token_id, img_token_id, img_end_token_id)

        def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
            # Token-aligned tensors are cut at token positions, image-aligned ones at image
            # indices; no right part is produced when ``right_idx`` is None.
            assert (right_idx is None) == (right_img_idx is None)

            left_sample = {}
            right_sample = {} if right_idx is not None else None
            for k in sample_to_split:
                if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
                    left_sample[k] = sample_to_split[k][:left_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_idx:]
                elif k in ['pixel_values', 'image_flags']:
                    left_sample[k] = sample_to_split[k][:left_img_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_img_idx:]
                else:
                    raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
            return left_sample, right_sample

        splitted_buffer = []
        while buffer['input_ids'].size(0) > max_tokens:
            img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
            img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
            assert len(img_start_idx_list) == len(img_end_idx_list)

            if _image_is_splitted(buffer['input_ids'], max_tokens):
                # Move the cut back to the start of the image straddling ``max_tokens``.
                cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
                if buffer['input_ids'][max_tokens] == img_start_token_id:
                    assert max_tokens == img_start_idx_list[cut_idx]
                    cut_left_idx = img_start_idx_list[cut_idx]
                    cut_left_img_idx = cut_idx
                else:
                    cut_left_idx = img_start_idx_list[cut_idx - 1]
                    cut_left_img_idx = cut_idx - 1
                cut_right_idx = cut_left_idx
                cut_right_img_idx = cut_left_img_idx

                if cut_left_idx == 0:
                    # The leading image alone is longer than ``max_tokens``: the right part
                    # would equal ``buffer`` and the loop would never end. Drop the rest.
                    logger.warning(
                        f'an image starting at token 0 does not fit in max_tokens={max_tokens}, '
                        f'drop the remaining {buffer["input_ids"].size(0)} tokens'
                    )
                    break
            else:
                # Cut the text at ``max_tokens``; the right part restarts at the next image.
                cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
                if cut_img_idx < len(img_start_idx_list):
                    cut_right_idx = img_start_idx_list[cut_img_idx]
                    cut_right_img_idx = cut_img_idx
                else:
                    cut_right_idx = None
                    cut_right_img_idx = None

                cut_left_idx = max_tokens
                cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)

            left, right = _split(
                sample_to_split=buffer,
                left_idx=cut_left_idx,
                left_img_idx=cut_left_img_idx,
                right_idx=cut_right_idx,
                right_img_idx=cut_right_img_idx,
            )

            # Every piece must contain only complete images.
            assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
            if right is not None:
                assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)

            if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
                splitted_buffer.append(left)

            if right is None or right['pixel_values'].size(0) == 0:
                break

            buffer = right
            if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
                splitted_buffer.append(buffer)
                break

        logger.debug(
            f'split a sample into {len(splitted_buffer)} samples, '
            f'current max_tokens={max_tokens}'
        )
        return splitted_buffer

    def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
        """Split ``buffer`` to the token budget and file the pieces (in place).

        Pieces that reach ``max_packed_tokens`` go to ``buffer_max_len_list``; the others
        are inserted into ``buffer_list``, which is kept sorted by image count (descending).
        Pieces with more than ``num_images_expected`` images are logged and dropped.

        Returns:
            The same ``(buffer_list, buffer_max_len_list)`` objects.
        """
        # NOTE: in-place operation
        splitted_buffer = PackedDataset.split_buffer(
            buffer=buffer,
            max_tokens=self.max_packed_tokens,
            img_start_token_id=self.img_start_token_id,
            img_token_id=self.img_token_id,
            img_end_token_id=self.img_end_token_id,
        )

        for each_buffer in splitted_buffer:
            num_images = each_buffer['pixel_values'].size(0)
            if num_images > self.num_images_expected:
                logger.error(
                    f"Find a sample with {num_images} images, "
                    f'which exceeds {self.num_images_expected}'
                )
                continue

            if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
                assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
                buffer_max_len_list.append(each_buffer)
                continue

            # Insert before the first buffer holding fewer images (after equal counts).
            find_idx = next(
                (idx for idx, other in enumerate(buffer_list) if other['pixel_values'].size(0) < num_images),
                len(buffer_list),
            )
            buffer_list.insert(find_idx, each_buffer)

        for i in range(1, len(buffer_list)):
            assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)

        return buffer_list, buffer_max_len_list

    def pad_buffer(self, buffer):
        """Append all-zero images (``image_flags`` 0) until ``num_images_expected`` is reached.

        Mutates and returns ``buffer``.
        """
        pixel_values = buffer['pixel_values']
        num_pad_images = self.num_images_expected - pixel_values.size(0)
        if num_pad_images == 0:
            return buffer

        # new_zeros keeps dtype/device and also works when the buffer holds no image yet.
        pad_images = pixel_values.new_zeros((num_pad_images, *pixel_values.shape[1:]))
        pad_image_flags = torch.zeros(num_pad_images, dtype=torch.long)
        buffer['pixel_values'] = torch.cat([pixel_values, pad_images])
        buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])
        return buffer

    def postprocess_buffer(self, buffer, custom_infos=None):
        """Attach this worker's resume state to ``buffer``.

        ``worker_state_dict`` references the live ``self._state_dict`` (not a copy);
        ``custom_infos``, when given, is deep-copied under this worker's state key.
        """
        buffer['worker_state_key'] = self.worker_state_key
        buffer['worker_state_dict'] = self._state_dict
        if custom_infos is not None:
            buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
        return buffer

    def _finalize_buffer(self, buffer, custom_infos=None):
        """Pad ``buffer`` in strict mode, then attach the worker resume state."""
        if self.strict_mode:
            buffer = self.pad_buffer(buffer)
        return self.postprocess_buffer(buffer, custom_infos)

    def print_log(self, iter_idx, buffer_list):
        """Log progress every ``log_freq`` iterations (first worker of rank 0 only)."""
        if iter_idx % self.log_freq != 0:
            return

        if self._should_log():
            logger.info(
                f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
            )

    def __iter__(self):
        """Yield packed buffers until every dataset is exhausted.

        The global worker id (from ``data_rank``/``data_world_size`` and the DataLoader
        worker) seeds the sampling RNG. The working dataset list is reset, and any buffer
        list passed to ``load_state_dict`` is restored. Buffers are yielded when they hit
        the token budget, when they hold exactly ``num_images_expected`` images, or when
        more than ``max_buffer_size`` buffers are pending. Once all datasets are exhausted,
        the pending buffers are flushed.
        """
        iter_idx = 0
        buffer_list = []
        buffer_max_len_list = []

        if self._should_log():
            logger.info(f'Begin to iter, {len(buffer_list)=}')

        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers

        # Global worker id / count across all data ranks.
        worker_id = num_workers * self.data_rank + worker_id
        num_workers = num_workers * self.data_world_size

        rng = np.random.default_rng(seed=worker_id)

        # reset states of each dataset
        self.worker_id = worker_id
        self.worker_state_key = f'work_state_{self.worker_id}'
        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)
        self.dataset_iter_list = [iter(d) for d in self.datasets]

        for ds in self.datasets:
            ds.worker_id = worker_id
            ds.worker_state_key = f'work_state_{self.worker_id}'
            ds.num_workers = num_workers
            if self._should_log() and worker_id == 0:
                logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')

        if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
            custom_infos = self.worker_custom_infos[self.worker_state_key]
            # buffer list
            if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
                buffer_list = custom_infos['buffer_list']
                if self._should_log() and worker_id == 0:
                    logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')

        # Resume information is consumed exactly once.
        self.worker_custom_infos = None

        logger.debug(
            f'{self.__class__.__name__} Rank {self.data_rank} '
            f'Worker {worker_id} begin to load data'
        )

        while True:
            # Renormalize: next_data may have removed datasets (and their weights).
            total_weight = sum(self.dataset_weight)
            self.dataset_weight = [w / total_weight for w in self.dataset_weight]
            current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)

            try:
                current_sample = self.next_data(current_dataset_idx)
            except StopIteration:
                logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
                while buffer_list:
                    yield self._finalize_buffer(buffer_list.pop(0))
                logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
                return

            buffer = self.find_buffer(buffer_list, current_sample)
            buffer = self.update_buffer(buffer, current_sample)
            buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)

            # 1) Buffers that hit the token budget are emitted first.
            while buffer_max_len_list:
                num_images = buffer_max_len_list[0]['pixel_values'].size(0)
                if num_images != self.num_images_expected:
                    logger.debug(
                        f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
                        f"yield a sample with {num_images} images"
                    )
                yield self._finalize_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})

            # 2) Drop buffers above the image budget (buffer_list is sorted by image count, descending).
            while buffer_list and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
                    f'is larger than num_images_expected({self.num_images_expected})'
                )
                buffer_list.pop(0)

            # 3) Emit buffers that reached exactly the expected number of images.
            while buffer_list and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
                if self.debug_mode:
                    # Debug mode repeats the first full buffer forever.
                    debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
                    while True:
                        yield debug_data.copy()

                yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            # 4) Too many pending buffers: emit those holding the most images.
            while len(buffer_list) > self.max_buffer_size:
                logger.debug(
                    f'Failed to pack data to exactly {self.num_images_expected} images, '
                    f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
                )
                yield self._finalize_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
            iter_idx += 1

    @staticmethod
    def get_cu_seqlens_and_indexes(
        data_index: torch.LongTensor,  # (seq_len,)
        input_ids: torch.LongTensor,  # (seq_len,)
        labels: torch.LongTensor,  # (seq_len,)
        len2weight: Callable[[int], float],
    ):
        """Derive sub-sample boundaries, in-sub-sample positions and per-token loss weights.

        ``data_index`` must hold contiguous runs of consecutive values (as built by
        ``update_buffer``); the assertions below enforce it.

        Args:
            data_index: sub-sample id of every token.
            input_ids: unused; kept for interface compatibility.
            labels: target ids; ``IGNORE_TOKEN_ID`` marks unsupervised tokens.
            len2weight: maps a sub-sample's number of supervised tokens to the loss
                weight of each of its tokens.

        Returns:
            ``(cu_seqlens, indexes, loss_weight)``: cumulative boundaries starting at 0,
            the position of every token inside its sub-sample, and a float32 tensor of
            per-token weights.
        """
        indexes = []
        cu_seqlens = [0]
        loss_weight = []

        start = data_index.min().item()
        end = data_index.max().item() + 1
        for i in range(start, end):
            offset = cu_seqlens[-1]
            num_tokens = (data_index == i).sum().item()
            indexes.extend(range(num_tokens))
            cu_seqlens.append(offset + num_tokens)
            assert num_tokens > 0

            curr_data_index = data_index[offset:offset + num_tokens]
            assert (curr_data_index == i).all(), data_index

            curr_labels = labels[offset:offset + num_tokens]
            num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
            # A sub-sample without supervised tokens has all labels == IGNORE_TOKEN_ID, so its
            # weights are irrelevant (packed_collate_fn zeroes them); skipping len2weight here
            # avoids a ZeroDivisionError for reductions such as 1/x or 1/sqrt(x).
            weight = len2weight(num_effective_tokens) if num_effective_tokens > 0 else 0.0
            loss_weight.extend([weight] * num_tokens)

        assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'

        loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
        return cu_seqlens, indexes, loss_weight
# Per-warning counters used to rate-limit repeated warnings in packed_collate_fn.
WARNING_CNT = defaultdict(int)


def packed_collate_fn(
    features,
    data_collator,
    len2weight: Callable[[int], float],
    max_item_length: int,
    micro_num: int = 1,
    loss_reduction_all_gather: bool = False,
    pad_id: int = 0,
):
    """Collate buffers yielded by ``PackedDataset`` into one micro-batch.

    Args:
        features: one packed buffer, or a list of at most ``micro_num`` buffers. The
            buffer dicts are modified in place (bookkeeping keys are popped and
            ``loss_weight`` is added); the caller's list itself is not extended.
        data_collator: called as ``data_collator(features=..., max_item_length=...,
            pad_id=...)``; must return a dict with ``labels``, ``loss_weight``,
            ``image_flags`` and ``type_ids`` entries (all read below).
        len2weight: maps a packed sub-sample's number of supervised tokens to its
            per-token loss weight.
        max_item_length: padded sequence length; a falsy value means the longest feature.
        micro_num: required number of features; missing ones are filled with copies of
            the first feature whose labels are all ``IGNORE_TOKEN_ID``.
        loss_reduction_all_gather: stored in the batch unchanged.
        pad_id: forwarded to ``data_collator``.

    Returns:
        The collator's batch with ``loss_weight`` as a Python list (0 at ignored labels),
        ``attention_mask`` replaced by the stacked int32 ``cu_seqlens``,
        ``loss_reduction_all_gather``, and ``statistics`` = [real sub-samples, padding
        tokens, ``image_flags`` entries minus their sum]; ``type_ids`` is removed.

    Raises:
        NotImplementedError: if more than ``micro_num`` features are given.
    """
    # Copy the list so micro_num padding never appends to the caller's list.
    features = list(features) if isinstance(features, list) else [features]

    if len(features) > micro_num:
        raise NotImplementedError(f'{len(features)=} > {micro_num=}')

    if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
        logger.warning(
            f'{len(features)=} < {micro_num=}, '
            f'the features will be padded to satisfy micro_num requirement'
        )
        WARNING_CNT['micro_num_warning'] += 1

    # ensure that the len(features) is equal to the required micro_num
    num_features = len(features)
    while len(features) < micro_num:
        pad_feature = copy.deepcopy(features[0])
        # Padding features must not contribute to the loss.
        pad_feature['labels'] = torch.full_like(pad_feature['labels'], IGNORE_TOKEN_ID)
        features.append(pad_feature)

    cu_seqlens = []
    max_item_length = max_item_length or max(feat['input_ids'].size(0) for feat in features)

    num_samples = 0
    num_padding_tokens = 0
    for feat_idx, feat in enumerate(features):
        data_index = feat.pop('data_index')
        curr_cu_seqlens, _, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
            data_index=data_index,
            input_ids=feat['input_ids'],
            labels=feat['labels'],
            len2weight=len2weight,
        )

        feat['loss_weight'] = curr_loss_weight

        if feat_idx < num_features:
            # Only real (non micro_num padding) features count as samples.
            num_samples += len(curr_cu_seqlens) - 1

        if curr_cu_seqlens[-1] < max_item_length:
            # Close the trailing padding as its own pseudo-sequence.
            curr_cu_seqlens.append(max_item_length)

        cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))

        # Worker bookkeeping is not a model input; strip it before collation.
        feat.pop('worker_state_key')
        feat.pop('worker_state_dict')
        feat.pop('custom_infos', None)

        num_padding_tokens += max_item_length - feat['input_ids'].size(0)

    batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)

    # convert it to list in case it is converted into bf16
    batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
    # NOTE(review): torch.stack requires every feature to have the same number of
    # cu_seqlens entries; confirm this holds when micro_num > 1.
    batch['attention_mask'] = torch.stack(cu_seqlens)
    batch['loss_reduction_all_gather'] = loss_reduction_all_gather
    num_padding_images = batch['image_flags'].numel() - batch['image_flags'].sum().item()
    batch['statistics'] = torch.tensor(
        [num_samples, num_padding_tokens, num_padding_images],
        dtype=torch.long,
    )

    batch.pop('type_ids')
    return batch