Spaces:
Runtime error
Runtime error
# Copyright 2024 EPFL and Apple Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import copy | |
import torch | |
import yaml | |
import fourm.utils as utils | |
from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter, | |
PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset, | |
build_huggingface_pretraining_dataloader, | |
build_wds_fm_pretraining_dataloader) | |
from fourm.data.modality_transforms import CaptionTransform | |
from fourm.data.modality_info import MODALITY_TRANSFORMS | |
def _parse_alphas(spec, domains, key):
    """Expand a Dirichlet-alpha spec into one float per domain.

    Args:
        spec: A '-'-separated string of alphas (e.g. '0.5-1.0'), a single
            number (YAML parses a lone value such as ``1.0`` as a float), or a
            list/tuple of numbers.
        domains: Sorted list of domain names the alphas are assigned to.
        key: Config key name, used only in the error message.

    Returns:
        List of floats with one entry per element of ``domains``.

    Raises:
        ValueError: If more than one alpha is given and the count does not
            match the number of domains.
    """
    if isinstance(spec, (int, float)):
        parts = [spec]
    elif isinstance(spec, (list, tuple)):
        parts = list(spec)
    else:
        parts = str(spec).split('-')

    if len(parts) == 1:
        # A single alpha is broadcast to every domain.
        return [float(parts[0])] * len(domains)
    if len(parts) != len(domains):
        # zip() would silently drop extras / leave domains at the default.
        raise ValueError(
            f"{key} specifies {len(parts)} alphas but there are {len(domains)} "
            f"domains ({'-'.join(domains)}); give one alpha, or one per domain "
            f"in sorted domain order."
        )
    return [float(alpha) for alpha in parts]


def setup_sampling_mod_info(dataset_config, modality_info):
    """Build the per-dataset subset of ``modality_info`` with sampling alphas.

    Only the modalities named in the dataset's ``in_domains`` / ``out_domains``
    ('-'-separated strings) are kept. They are deep-copied, so the caller's
    ``modality_info`` is not mutated. Each kept modality gets ``input_alphas``
    and ``target_alphas`` (Dirichlet concentration parameters) from one of two
    sources:

    * ``dataset_config['alphas_config']``: path to a YAML file. If the file
      has a ``sampling_weights`` key, the alphas are read from its
      ``alphas_mixture`` entry and the weights are returned. Otherwise the
      file itself maps modality -> alphas. Modalities whose type is 'seq',
      'seq_emb' or 'seq_token' also get ``keep`` copied from the file.
    * Otherwise, ``input_alphas`` / ``target_alphas`` in ``dataset_config``.
      Each is either a single value (broadcast to all in/out domains) or one
      value per domain in *sorted* domain order. Alphas that are not
      specified default to ``[0.]``.

    Args:
        dataset_config: Dict-like config for one dataset.
        modality_info: Mapping from modality name to its info dict.

    Returns:
        Tuple ``(mod_info, sampling_weights)``. ``sampling_weights`` is None
        unless it is provided by the alphas config file.

    Raises:
        ValueError: If an alphas spec has several values but not one per domain.
    """
    in_domains = sorted(dataset_config['in_domains'].split('-'))
    out_domains = sorted(dataset_config['out_domains'].split('-'))
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Single deepcopy of the selected entries so callers' dicts stay untouched.
    mod_info = copy.deepcopy({mod: modality_info[mod] for mod in all_domains})

    alphas_config_path = dataset_config.get('alphas_config', None)
    if alphas_config_path is None:
        for mod in mod_info:
            mod_info[mod]['input_alphas'] = [0.]
            mod_info[mod]['target_alphas'] = [0.]

        for key, domains in (('input_alphas', in_domains), ('target_alphas', out_domains)):
            spec = dataset_config.get(key)
            if spec is None:
                continue
            for mod, alpha in zip(domains, _parse_alphas(spec, domains, key)):
                mod_info[mod][key] = [alpha]

        sampling_weights = None
    else:
        print(f"Loading alphas config from: {alphas_config_path}")
        with open(alphas_config_path, "r", encoding="utf-8") as f:
            alphas_config = yaml.safe_load(f)

        if 'sampling_weights' in alphas_config:
            # Mixture format: sampling weights plus a nested alphas mapping.
            sampling_weights = alphas_config['sampling_weights']
            alphas_config = alphas_config['alphas_mixture']
        else:
            sampling_weights = None

        for mod in mod_info:
            mod_alphas = alphas_config[mod]
            mod_info[mod]['input_alphas'] = mod_alphas['input_alphas']
            mod_info[mod]['target_alphas'] = mod_alphas['target_alphas']
            if modality_info[mod]['type'] in ('seq', 'seq_emb', 'seq_token'):
                mod_info[mod]['keep'] = mod_alphas['keep']

    return mod_info, sampling_weights
def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size,
                         num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                         num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
    """Build the training loader for a single dataset.

    Args:
        dataset_config: Config dict for this dataset. ``type`` selects the
            pipeline ('multimodal' or 'huggingface'). Per-dataset token counts
            (``num_input_tokens``, ``min_input_tokens``, ...) override the
            corresponding function arguments.
        modality_info: Per-modality info dicts for this dataset.
        sampling_weights: Forwarded only to the webdataset loader.
        text_tokenizer: Tokenizer passed to the dataset builders.
        input_size: Target size for the image augmenter.
        num_input_tokens, num_target_tokens: Upper bounds of the token ranges.
        min_input_tokens, min_target_tokens: Lower bounds (multimodal only).
            ``None`` means "equal to the corresponding upper bound".
        num_tasks: Number of distributed processes.
        num_workers: Worker count for the webdataset / huggingface loaders.
        dataset_batch_size, epoch_size: Forwarded to the webdataset /
            huggingface loaders.

    Returns:
        The loader. For non-webdataset multimodal data this is a
        ``torch.utils.data.DataLoader`` with batch size 1 whose collate
        function unwraps the single sample.

    Raises:
        NotImplementedError: For a huggingface dataset with ``use_wds`` set,
            or an unknown dataset ``type``.
    """
    in_domains = sorted(dataset_config['in_domains'].split('-'))
    out_domains = sorted(dataset_config['out_domains'].split('-'))
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Shallow-copy the module-level registry. Assigning into
    # MODALITY_TRANSFORMS itself would overwrite the caption transform held by
    # every other loader (train or val) that shares the same dict object.
    modality_transforms = copy.copy(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=dataset_config.get('aligned_captions', True)
        )

    dataset_type = dataset_config['type']
    if dataset_type == 'multimodal':
        is_pretokenized = any(info.get('pretokenized', False) for info in modality_info.values())
        if is_pretokenized:
            # Multi-modal training data augmentation (uses pre-tokenized data augmentation)
            image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size,
                no_aug=(not dataset_config.get('tok_train_aug', True)),
                main_domain=dataset_config['main_augment_domain']
            )
        else:
            image_augmenter = RandomCropImageAugmenter(
                target_size=input_size,
                hflip=dataset_config.get('hflip'),
                crop_scale=tuple(dataset_config.get('crop_scale')),
                crop_ratio=tuple(dataset_config.get('crop_ratio')),
            )

        # Input and target token ranges; a missing minimum collapses the range.
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
        min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
        min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
        min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
        min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens

        if dataset_config.get('use_wds', False):
            # Using webdataset
            loader = build_wds_fm_pretraining_dataloader(
                data_path=dataset_config['data_path'], all_domains=all_domains,
                modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens),
                target_tokens_range=(min_target_tokens, num_target_tokens),
                num_gpus=num_tasks, num_workers=num_workers,
                batch_size=dataset_batch_size, epoch_size=epoch_size,
                modality_name_map=dataset_config.get('modality_name_map', None),
                shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
                shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
                n_repeats=dataset_config.get('wds_n_repeats', 1),
                sampling_weights=sampling_weights,
            )
        else:
            dataset_train = build_fm_pretraining_dataset(
                data_path=dataset_config['data_path'],
                all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens),
                target_tokens_range=(min_target_tokens, num_target_tokens)
            )
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
            )
            # DataLoader has batch size 1 as it then gets collated through the Mixture dataloader
            loader = torch.utils.data.DataLoader(
                dataset_train, sampler=sampler_train,
                batch_size=1, num_workers=0,
                pin_memory=False, drop_last=True,
                collate_fn=lambda x: x[0],
            )

    elif dataset_type == 'huggingface':
        # Input and target token ranges (fixed counts; no min_* support here)
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)

        if dataset_config.get('use_wds', False):
            raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
        loader = build_huggingface_pretraining_dataloader(
            data_path=dataset_config['data_path'], all_domains=all_domains,
            modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=(num_input_tokens, num_input_tokens),
            target_tokens_range=(num_target_tokens, num_target_tokens),
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=dataset_batch_size, epoch_size=epoch_size,
            split='train', streaming=True, rename_text_to_caption=True,
            shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
            shuffle_seed=0,
        )
    else:
        raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')

    return loader
def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
    """ Try to retrieve a key from the validation set config.
    If it does not exist, default to retrieving it from the train set config
    with the same dataset name, and finally to ``default``.

    The train config is consulted only when the key is missing from
    ``val_config``. A validation-only dataset (no train config with that
    name) therefore works: it falls back to ``default`` instead of raising
    ``KeyError``.
    """
    if key in val_config:
        return val_config[key]
    train_config = train_configs.get(dataset_name) or {}
    return train_config.get(key, default)
def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer,
                       input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                       fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens,
                       dist_eval, num_tasks, num_workers, batch_size, pin_mem):
    """Build the validation loader for a single dataset.

    Structural keys (domains, ``type``, ``data_path``, ``main_augment_domain``,
    ``aligned_captions``, ``shuffle_buffer_load``) missing from
    ``dataset_config`` are looked up in ``train_configs[dataset_name]`` via
    ``cfgs_get``. In the non-fixed path, token counts are read only from
    ``dataset_config`` and otherwise fall back to the function arguments.

    Args:
        dataset_config: Validation config dict for this dataset.
        dataset_name: Name used to find the matching train config.
        train_configs: Mapping from dataset name to train config.
        modality_info: Per-modality info dicts for this dataset.
        sampling_weights: Unused in this function.
        text_tokenizer: Tokenizer passed to the dataset builders.
        input_size: Target size for the eval image augmenter.
        num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens:
            Token-range defaults used when ``fixed_eval`` is False (``min_*``
            apply to multimodal data only; ``None`` means equal to the max).
        fixed_eval: If True, use exactly ``fixed_eval_input_tokens`` /
            ``fixed_eval_target_tokens`` tokens.
        dist_eval: Multimodal only. If True, shard with a non-shuffling
            ``DistributedSampler``; otherwise use a ``SequentialSampler``.
        num_tasks: Number of distributed processes.
        num_workers, batch_size: Loader parameters.
        pin_mem: ``pin_memory`` for the multimodal ``DataLoader``.

    Returns:
        The validation loader.

    Raises:
        NotImplementedError: For an unknown dataset ``type``.
    """
    in_domains = sorted(cfgs_get('in_domains', dataset_config, dataset_name, train_configs).split('-'))
    out_domains = sorted(cfgs_get('out_domains', dataset_config, dataset_name, train_configs).split('-'))
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Shallow-copy the module-level registry. Assigning into
    # MODALITY_TRANSFORMS itself would overwrite the caption transform held by
    # the already-built train loaders and other val loaders.
    modality_transforms = copy.copy(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=cfgs_get('aligned_captions', dataset_config, dataset_name, train_configs, True)
        )

    dataset_type = cfgs_get('type', dataset_config, dataset_name, train_configs)
    if dataset_type == 'multimodal':
        main_augment_domain = cfgs_get('main_augment_domain', dataset_config, dataset_name, train_configs)
        is_pretokenized = any(info.get('pretokenized', False) for info in modality_info.values())
        if is_pretokenized:
            eval_image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size, no_aug=True, main_domain=main_augment_domain
            )
        else:
            eval_image_augmenter = CenterCropImageAugmenter(
                target_size=input_size, main_domain=main_augment_domain
            )

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Input and target token ranges; a missing minimum collapses the range.
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
            min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
            min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
            min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
            input_tokens_range = (min_input_tokens, num_input_tokens)
            target_tokens_range = (min_target_tokens, num_target_tokens)

        dataset_val = build_fm_pretraining_dataset(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
        )

        print("Warning: Eval stats may vary slightly as the masking applied on images is random.")
        if dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

        loader = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_mem,
            drop_last=False,
        )

    elif dataset_type == 'huggingface':
        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Input and target token ranges (fixed counts; no min_* support here)
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            input_tokens_range = (num_input_tokens, num_input_tokens)
            target_tokens_range = (num_target_tokens, num_target_tokens)

        loader = build_huggingface_pretraining_dataloader(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=batch_size, epoch_size=None,
            split='validation', streaming=True, rename_text_to_caption=True,
            shuffle_buffer_load=cfgs_get('shuffle_buffer_load', dataset_config, dataset_name, train_configs, 1_000),
            shuffle_seed=0,
        )
    else:
        raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')

    return loader