aroraaman's picture
Add all of `fourm`
3424266
raw
history blame
14.9 kB
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
import yaml
import fourm.utils as utils
from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter,
PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset,
build_huggingface_pretraining_dataloader,
build_wds_fm_pretraining_dataloader)
from fourm.data.modality_transforms import CaptionTransform
from fourm.data.modality_info import MODALITY_TRANSFORMS
def setup_sampling_mod_info(dataset_config, modality_info):
# Subset of modality info for each dataset
# Input and output modalities for one dataset
in_domains = sorted(dataset_config['in_domains'].split('-'))
out_domains = sorted(dataset_config['out_domains'].split('-'))
all_domains = sorted(list(set(in_domains) | set(out_domains)))
mod_info = copy.deepcopy(modality_info)
mod_info = {mod: mod_info[mod] for mod in all_domains}
# Dirichlet concentration parameter (Alpha)
if dataset_config.get('alphas_config', None) is None:
for mod in mod_info:
mod_info[mod]["input_alphas"] = [0.]
mod_info[mod]["target_alphas"] = [0.]
if 'input_alphas' in dataset_config:
input_alphas = dataset_config['input_alphas'].split('-')
if len(input_alphas) == 1:
input_alphas = [float(input_alphas[0])] * len(in_domains)
else:
input_alphas = [float(alpha) for alpha in input_alphas]
for mod, alpha in zip(in_domains, input_alphas):
mod_info[mod]['input_alphas'] = [alpha]
if 'target_alphas' in dataset_config:
target_alphas = dataset_config['target_alphas'].split('-')
if len(target_alphas) == 1:
target_alphas = [float(target_alphas[0])] * len(out_domains)
else:
target_alphas = [float(alpha) for alpha in target_alphas]
for mod, alpha in zip(out_domains, target_alphas):
mod_info[mod]["target_alphas"] = [alpha]
sampling_weights = None
else:
print(f"Loading alphas config from: {dataset_config['alphas_config']}")
with open(dataset_config['alphas_config'], "r") as f:
alphas_config = yaml.safe_load(f)
if 'sampling_weights' in alphas_config:
sampling_weights = alphas_config['sampling_weights']
alphas_config = alphas_config['alphas_mixture']
else:
sampling_weights = None
for mod in mod_info:
mod_info[mod]["input_alphas"] = alphas_config[mod]["input_alphas"]
mod_info[mod]["target_alphas"] = alphas_config[mod]["target_alphas"]
if modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
mod_info[mod]['keep'] = alphas_config[mod]['keep']
return mod_info, sampling_weights
def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size,
num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
in_domains = sorted(list(dataset_config['in_domains'].split('-')))
out_domains = sorted(list(dataset_config['out_domains'].split('-')))
all_domains = sorted(list(set(in_domains) | set(out_domains)))
modality_transforms = MODALITY_TRANSFORMS
if 'caption' in modality_transforms:
modality_transforms['caption'] = CaptionTransform(
aligned_captions=dataset_config.get('aligned_captions', True)
)
if dataset_config['type'] == 'multimodal':
is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
if is_pretokenized:
# Multi-modal training data augmentation (uses pre-tokenized data augmentation)
image_augmenter = PreTokenizedImageAugmenter(
target_size=input_size,
no_aug=(not dataset_config.get('tok_train_aug', True)),
main_domain=dataset_config['main_augment_domain']
)
else:
image_augmenter = RandomCropImageAugmenter(
target_size=input_size,
hflip=dataset_config.get('hflip'),
crop_scale=tuple(dataset_config.get('crop_scale')),
crop_ratio=tuple(dataset_config.get('crop_ratio')),
)
# Input and target token ranges
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
if dataset_config['use_wds']:
# Using webdataset
loader = build_wds_fm_pretraining_dataloader(
data_path=dataset_config['data_path'], all_domains=all_domains,
modality_info=modality_info, modality_transforms=modality_transforms,
image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
input_tokens_range=(min_input_tokens, num_input_tokens),
target_tokens_range=(min_target_tokens, num_target_tokens),
num_gpus=num_tasks, num_workers=num_workers,
batch_size=dataset_batch_size, epoch_size=epoch_size,
modality_name_map=dataset_config.get('modality_name_map', None),
shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
n_repeats=dataset_config.get('wds_n_repeats', 1),
sampling_weights=sampling_weights,
)
else:
dataset_train = build_fm_pretraining_dataset(
data_path=dataset_config['data_path'],
all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
input_tokens_range=(min_input_tokens, num_input_tokens),
target_tokens_range=(min_target_tokens, num_target_tokens)
)
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
)
# DataLoader has batch size 1 as it then gets collated through the Mixture dataloader
loader = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=1, num_workers=0,
pin_memory=False, drop_last=True,
collate_fn=lambda x: x[0],
)
elif dataset_config['type'] == 'huggingface':
# Input and target token ranges
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
if dataset_config.get('use_wds', False):
raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
else:
loader = build_huggingface_pretraining_dataloader(
data_path=dataset_config['data_path'], all_domains=all_domains,
modality_info=modality_info, modality_transforms=modality_transforms,
image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
input_tokens_range=(num_input_tokens, num_input_tokens),
target_tokens_range=(num_target_tokens, num_target_tokens),
num_gpus=num_tasks, num_workers=num_workers,
batch_size=dataset_batch_size, epoch_size=epoch_size,
split='train', streaming=True, rename_text_to_caption=True,
shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
shuffle_seed=0,
)
else:
raise NotImplementedError(f'Dataset type {dataset_config["type"]} not implemented.')
return loader
def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
""" Try to retrieve a key from the validation set config.
If it does not exist, default to retrieving it from the train set config
with the same dataset name.
"""
return val_config.get(key, train_configs[dataset_name].get(key, default))
def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer,
input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens,
dist_eval, num_tasks, num_workers, batch_size, pin_mem):
in_domains = sorted(list(cfgs_get('in_domains', dataset_config, dataset_name, train_configs).split('-')))
out_domains = sorted(list(cfgs_get('out_domains', dataset_config, dataset_name, train_configs).split('-')))
all_domains = sorted(list(set(in_domains) | set(out_domains)))
modality_transforms = MODALITY_TRANSFORMS
if 'caption' in modality_transforms:
modality_transforms['caption'] = CaptionTransform(
aligned_captions=cfgs_get('aligned_captions', dataset_config, dataset_name, train_configs, True)
)
dataset_type = cfgs_get('type', dataset_config, dataset_name, train_configs)
if dataset_type == 'multimodal':
main_augment_domain = cfgs_get('main_augment_domain', dataset_config, dataset_name, train_configs)
is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
if is_pretokenized:
eval_image_augmenter = PreTokenizedImageAugmenter(
target_size=input_size, no_aug=True, main_domain=main_augment_domain
)
else:
eval_image_augmenter = CenterCropImageAugmenter(
target_size=input_size, main_domain=main_augment_domain
)
if fixed_eval:
input_tokens_range=(fixed_eval_input_tokens, fixed_eval_input_tokens)
target_tokens_range=(fixed_eval_target_tokens, fixed_eval_target_tokens)
else:
# Input and target token ranges
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
input_tokens_range = (min_input_tokens, num_input_tokens)
target_tokens_range = (min_target_tokens, num_target_tokens)
dataset_val = build_fm_pretraining_dataset(
data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
)
print("Warning: Eval stats may vary slightly as the masking applied on images is random.")
if dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
loader = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_mem,
drop_last=False,
)
elif dataset_type == 'huggingface':
if fixed_eval:
input_tokens_range=(fixed_eval_input_tokens, fixed_eval_input_tokens)
target_tokens_range=(fixed_eval_target_tokens, fixed_eval_target_tokens)
else:
# Input and target token ranges
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
input_tokens_range = (num_input_tokens, num_input_tokens)
target_tokens_range = (num_target_tokens, num_target_tokens)
loader = build_huggingface_pretraining_dataloader(
data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
num_gpus=num_tasks, num_workers=num_workers,
batch_size=batch_size, epoch_size=None,
split='validation', streaming=True, rename_text_to_caption=True,
shuffle_buffer_load=cfgs_get('shuffle_buffer_load', dataset_config, dataset_name, train_configs, 1_000),
shuffle_seed=0,
)
else:
raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')
return loader