# Author: chendl
# Commit: 0b7b08a ("Add application file")
""" Main training script """
import argparse
import copy
import glob
import os
import random
import functools
import numpy as np
import torch
# torch.multiprocessing.set_sharing_strategy('file_system')
import wandb
from data2 import get_data
from distributed import init_distributed_device, world_info_from_env
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
CPUOffload,
StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from train_utils import train_one_epoch
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from open_flamingo import create_model_and_transforms
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from torch.distributed.optim import ZeroRedundancyOptimizer
import logging
import warnings

# Silence every Python warning for the whole process (deliberately broad).
warnings.filterwarnings("ignore")

# Root-logger setup: INFO and above, each line prefixed with a timestamp.
# Note that %I is a 12-hour clock and the format carries no AM/PM marker.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
    level=logging.INFO,
)
class FakeDataloader:
    """Placeholder loader that yields ``None`` forever.

    ``main`` substitutes it for the pile dataloader when no ``--pile_shards``
    are given, so the training loop can always draw a (``None``) pile batch.
    """

    def __iter__(self):
        # The loader is its own iterator.
        return self

    def __next__(self):
        # Never raises StopIteration: the stream of ``None`` values is endless.
        return None
def random_seed(seed=42, rank=0):
    """Seed torch, NumPy and Python's ``random`` with the single value ``seed + rank``.

    Args:
        seed: base seed shared by all processes.
        rank: per-process offset; pass the distributed rank to give each
            process a distinct random stream, or leave at 0 for identical streams.
    """
    effective_seed = seed + rank
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(effective_seed)
def get_grouped_params(model, args):
    """Split ``model``'s parameters into a weight-decay group and a no-decay group.

    A parameter is excluded from weight decay when its lower-cased name contains
    any of "norm", "bn", "bias", "embed", "wte" or "flat_param". The
    "flat_param" keyword also excludes FSDP-flattened parameters. Because
    ``requires_grad`` is not checked, frozen parameters are grouped as well.
    On rank 0 (or when no process group is initialised) every parameter name is
    logged together with its group.

    Args:
        model: the (possibly DDP/FSDP-wrapped) ``torch.nn.Module``.
        args: namespace providing ``weight_decay``.

    Returns:
        Two optimizer param-group dicts: the decayed parameters with
        ``weight_decay=args.weight_decay``, then the rest with ``weight_decay=0.0``.
    """
    no_decay_keywords = ("norm", "bn", "bias", "embed", "wte", "flat_param")

    def apply_decay(name):
        lowered = name.lower()
        return not any(keyword in lowered for keyword in no_decay_keywords)

    # Query the rank once instead of once per parameter. Treat a run without an
    # initialised process group as rank 0 so the helper also works single-process
    # (get_rank() would raise there).
    dist = torch.distributed
    distributed = dist.is_available() and dist.is_initialized()
    log_names = (not distributed) or dist.get_rank() == 0

    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if apply_decay(name):
            if log_names:
                logging.info(f"with wd: {name}")
            params_with_wd.append(param)
        else:
            if log_names:
                logging.info(f"without wd: {name}")
            params_without_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
def lambda_policy_fn(module):
    """Return True for a leaf module (no children) that owns a trainable ``weight``.

    ``main`` combines it with the transformer-layer policy when ``--lora`` is set,
    so that such leaf modules become their own FSDP units.
    """
    is_leaf = next(iter(module.named_children()), None) is None
    if not is_leaf:
        return False
    weight = getattr(module, "weight", None)
    if weight is None:
        return False
    return bool(weight.requires_grad)
def lambda_auto_wrap_policy(
    module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
) -> bool:
    """Auto-wrap policy that delegates the wrap decision to an arbitrary predicate.

    While the auto-wrapper is still descending the module tree (``recurse`` is
    True) the policy always answers True, so every submodule is visited. When
    asked whether ``module`` itself should become a wrapped unit (``recurse`` is
    False), it returns ``lambda_fn(module)``.

    Args:
        module (nn.Module): module currently being considered.
        recurse (bool): True while recursing through the tree; False when a
            wrap/no-wrap decision for ``module`` is requested.
        nonwrapped_numel (int): parameter count not yet wrapped (unused here).
        lambda_fn (Callable[[nn.Module], bool]): predicate; ``True`` means
            ``module`` should be wrapped.

    Returns:
        bool: whether to keep recursing / wrap ``module``.
    """
    return True if recurse else lambda_fn(module)
def _build_arg_parser():
    """Return the ``argparse`` parser describing every command-line option of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="facebook/opt-1.3b",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="openflamingo3B",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_steps", type=int, default=300000)
    parser.add_argument(
        "--logging_steps", type=int, default=10, help="log loss every n steps"
    )
    # Sum of gradient optimization batch size
    parser.add_argument("--batch_size_mmc4", type=int, default=128)
    parser.add_argument("--batch_size_laion", type=int, default=128)
    parser.add_argument("--batch_size_pile", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete previous checkpoint when saving new checkpoint",
    )
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--pile_shards",
        type=str,
        default=None,
        help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    for loss_name in ("mmc4", "laion", "pile", "det", "rel", "attn"):
        parser.add_argument(f"--loss_multiplier_{loss_name}", type=float, default=1.0)
    parser.add_argument("--warmup_steps", default=5000, type=int)
    # Under FSDP, weight decay is only applied to the YOLOX (detection) head.
    # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
    parser.add_argument("--weight_decay", default=0.05, type=float)
    # NOTE(review): not read in this file; FSDP mixed precision is hard-coded to
    # fp16 below. Presumably consumed by train_one_epoch — confirm.
    parser.add_argument(
        "--precision",
        choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="fp32",
        help="Floating point precision.",
    )
    # data args
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--dataset_resampled", action="store_true")
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    parser.add_argument("--checkpoint_activations", default=False, action="store_true")
    parser.add_argument("--freeze_vision_encoder", default=False, action="store_true")
    parser.add_argument(
        "--mmc4_textsim_threshold",
        default=30,
        type=float,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )
    parser.add_argument("--location_token_num", default=1000, type=int)
    # Overwritten in main() by the value returned from create_model_and_transforms.
    parser.add_argument("--vis_embed_size", type=int, required=False)
    parser.add_argument("--save_interval", default=1000, type=int, required=False)
    parser.add_argument("--skip_delete_pattern", default=1500, type=int, required=False)
    parser.add_argument("--ddp", default=False, action="store_true")
    parser.add_argument("--pile_freq", default=1, type=int, required=False)
    parser.add_argument("--restart", default=False, action="store_true")
    parser.add_argument("--lora", default=False, action="store_true")
    parser.add_argument("--lora_r", default=16, type=int, required=False)
    parser.add_argument("--single", default=False, action="store_true")
    # Finetune
    parser.add_argument("--instruct", default=False, action="store_true")
    parser.add_argument("--fix-ffn", default=False, action="store_true")
    parser.add_argument("--prob_ground", default=1.0, type=float, required=False)
    parser.add_argument("--optimizer", default="adamw", type=str, required=False)
    parser.add_argument("--add_visual_token", default=False, action="store_true")
    parser.add_argument("--use_format_v2", default=False, action="store_true")
    parser.add_argument("--use_sam", default=None, type=str, required=False)
    parser.add_argument("--max-length", default=608, type=int, required=False)
    parser.add_argument("--image-size", default=256, type=int, required=False)
    parser.add_argument("--reset_llm", default=False, action="store_true")
    parser.add_argument("--add_box", default=False, action="store_true")
    parser.add_argument("--add_pe", default=False, action="store_true")
    parser.add_argument("--only_grounded_sample", default=False, action="store_true")
    parser.add_argument("--expand", default=False, action="store_true")
    parser.add_argument("--delete_contained", default=False, action="store_true")
    parser.add_argument("--relation", default=False, action="store_true")
    parser.add_argument("--attn_reg", default="l1", type=str, required=False)
    parser.add_argument("--enhance_data", default=False, action="store_true")
    parser.add_argument("--no_visual", default=False, action="store_true")
    parser.add_argument("--no_previsual", default=False, action="store_true")
    parser.add_argument("--roi_align", default=False, action="store_true")
    parser.add_argument("--roi_output_size", default=4, type=int, required=False)
    parser.add_argument("--apply_mask", default=False, action="store_true")
    parser.add_argument("--longer_previsual", default=False, action="store_true")
    return parser


def _wrap_model(model, args, device_id):
    """Wrap ``model`` in DDP (``--ddp``) or FSDP and return it on ``device_id``."""
    if args.ddp:
        print("use ddp mode")
        return DDP(model.to(device_id))

    # fp16 parameters and fp16 gradient reduction; buffer_dtype is left unset.
    fp_sixteen = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
    )
    from open_clip.transformer import ResidualAttentionBlock
    from open_flamingo.src.flamingo_lm import FlamingoLayer
    from transformers.models.opt.modeling_opt import OPTAttention
    from segment_anything.modeling.image_encoder import Block

    # Each instance of these layer classes becomes its own FSDP unit.
    transformer_layer_cls = [FlamingoLayer, ResidualAttentionBlock, Block]
    if args.fix_ffn:
        transformer_layer_cls.append(OPTAttention)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_layer_cls,
    )
    if args.lora:
        from torch.distributed.fsdp.wrap import _or_policy

        # Additionally wrap every leaf module that owns a trainable weight.
        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
        auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
        ignored_modules = [model.vision_encoder]
    else:
        ignored_modules = [model.detection_head]
    if args.add_pe:
        ignored_modules.append(model.pos_enc)
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=fp_sixteen,
        device_id=torch.cuda.current_device(),
        ignored_modules=ignored_modules,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    )
    return model.to(device_id)


def _build_optimizer(model, optim_groups, args):
    """Create the optimizer; in DDP mode AdamW is always used and ``--optimizer`` is ignored.

    Raises:
        NotImplementedError: if ``--optimizer`` is neither "adamw" nor "sgd" (FSDP mode).
    """
    if args.ddp:
        return torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    if args.optimizer == "adamw":
        print("use adamw")
        return torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    if args.optimizer == "sgd":
        print("use sgd...")
        # SGD ignores the weight-decay groups and uses all parameters directly.
        return torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    raise NotImplementedError(f"unsupported optimizer: {args.optimizer}")


def _build_lr_scheduler(optimizer, args, total_training_steps):
    """Create the warmup LR schedule; any value other than "linear"/"cosine" means constant."""
    if args.lr_scheduler == "linear":
        return get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    if args.lr_scheduler == "cosine":
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    return get_constant_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps
    )


def _find_latest_checkpoint(run_name):
    """Return the ``run_name/checkpoint_<N>.pt`` path with the largest N, or None if there is none."""
    checkpoint_list = glob.glob(f"{run_name}/checkpoint_*.pt")
    if not checkpoint_list:
        return None
    return max(checkpoint_list, key=lambda path: int(path.split("_")[-1].split(".")[0]))


def _load_checkpoint(model, lr_scheduler, args, llm_state_dict):
    """Restore state from ``args.resume_from_checkpoint``.

    Model weights are always loaded non-strictly. Under FSDP, checkpoint tensors
    whose shape disagrees with the model are replaced by the model's current
    values. With ``--reset_llm``, ``lang_encoder.*`` entries are replaced by
    ``llm_state_dict``. Unless ``args.restart`` is set, the LR-scheduler state
    and training counters are restored too. Optimizer and grad-scaler state are
    not restored.

    Returns:
        ``(total_laion_token, total_pile_token, total_laion_sample, total_step)``
        from the checkpoint (missing entries read as 0), or all zeros on restart.
    """
    if args.rank == 0:
        logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
    # Every rank reads the whole checkpoint into CPU memory.
    # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
    checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
    torch.distributed.barrier()
    state_dict = checkpoint["model_state_dict"]
    if args.ddp:
        model.module.load_state_dict(state_dict, strict=False)
    else:
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            if args.reset_llm:
                for key in state_dict:
                    if key.startswith("lang_encoder"):
                        if args.rank == 0:
                            logging.info(f"reset {key}")
                        state_dict[key] = llm_state_dict[key.replace("lang_encoder.", "")]
            model_state_dict = model.state_dict()
            for key, value in state_dict.items():
                if key not in model_state_dict:
                    # Without this check an unexpected key raised KeyError here even
                    # though the load below is non-strict; let that load ignore it.
                    if args.rank == 0:
                        logging.info(f"{key}: not found in model, skipped")
                    continue
                if model_state_dict[key].shape != value.shape:
                    if args.rank == 0:
                        logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {value.shape}')
                    state_dict[key] = model_state_dict[key].clone()
            del model_state_dict
            model.load_state_dict(state_dict, False)

    if args.restart:
        if args.rank == 0:
            logging.info("restart training / finetuning. only load model weight...")
        return 0, 0, 0, 0

    lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    counters = tuple(
        checkpoint.get(name, 0)
        for name in ("total_laion_token", "total_pile_token", "total_laion_sample", "total_step")
    )
    if args.rank == 0:
        logging.info("load training statistics...")
    return counters


def main():
    """Training entry point.

    In order:
      1. Parse and validate arguments.
      2. Initialise distributed state and build the model via
         ``create_model_and_transforms``, then wrap it in DDP or FSDP.
      3. Build the data (``get_data``), the optimizer, the LR scheduler and the
         grad scaler.
      4. Resume from the newest ``checkpoint_*.pt`` in ``args.run_name`` or from
         ``--resume_from_checkpoint``, if either exists.
      5. Call ``train_one_epoch`` once.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Reject unsupported combinations up front. parser.error prints usage and exits
    # with status 2, and unlike ``assert`` it is not stripped by ``python -O``.
    if args.use_media_placement_augmentation:
        parser.error("Do not enable use_media_placement_augmentation")
    if args.no_previsual and not args.no_visual:
        parser.error("no_previsual MUST come with no_visual")
    if args.enhance_data:
        parser.error("dont enable enhance_data")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
    # The returned device is not used: device_id is recomputed from the global rank below.
    init_distributed_device(args)
    # Same seed on every rank (offset 0) while the model is built; re-seeded per rank below.
    random_seed(args.seed)
    model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_local_files=args.offline,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        checkpoint_activations=args.checkpoint_activations,
        freeze_vision_encoder=args.freeze_vision_encoder,
        location_token_num=args.location_token_num,
        lora=args.lora,
        lora_r=args.lora_r,
        fix_ffn=args.fix_ffn,
        add_visual_token=args.add_visual_token,
        add_box=args.add_box,
        add_pe=args.add_pe,
        add_relation=args.relation,
        use_format_v2=args.use_format_v2,
        use_sam=args.use_sam,
        enhance_data=args.enhance_data,
        roi_align=args.roi_align,
        roi_output_size=args.roi_output_size,
        apply_mask=args.apply_mask,
    )
    # Snapshot the freshly built language-model weights so --reset_llm can put
    # them back over the checkpoint's lang_encoder.* entries.
    llm_state_dict = model.lang_encoder.state_dict() if args.reset_llm else None

    if args.rank == 0:
        print(args)
        print(image_processor)
    random_seed(args.seed, args.rank)
    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    device_id = args.rank % torch.cuda.device_count()
    model = _wrap_model(model, args, device_id)

    pile_dataset = None
    if args.instruct:
        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
    else:
        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
    if args.pile_shards is not None:
        pile_dataset = get_data(args, image_processor, tokenizer, "pile")

    optim_groups = get_grouped_params(model, args)
    optimizer = _build_optimizer(model, optim_groups, args)
    total_training_steps = args.num_steps
    if args.rank == 0:
        logging.info(f"Total training steps: {total_training_steps}")
    lr_scheduler = _build_lr_scheduler(optimizer, args, total_training_steps)
    scaler = GradScaler() if args.ddp else ShardedGradScaler()

    # Checkpoints already present in the run directory take precedence over
    # --resume_from_checkpoint and cancel --restart.
    if os.path.exists(args.run_name):
        latest_checkpoint = _find_latest_checkpoint(args.run_name)
        if latest_checkpoint is None:
            if args.rank == 0:
                logging.info(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = latest_checkpoint
            if args.rank == 0:
                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
            args.restart = False
            if args.rank == 0:
                logging.info("do not restart because an existed checkpoint is found")

    total_laion_token = total_pile_token = total_laion_sample = total_step = 0
    if args.resume_from_checkpoint is not None:
        (
            total_laion_token,
            total_pile_token,
            total_laion_sample,
            total_step,
        ) = _load_checkpoint(model, lr_scheduler, args, llm_state_dict)
    # Release the LM snapshot whether or not a checkpoint was loaded.
    llm_state_dict = None
    torch.cuda.empty_cache()
    torch.distributed.barrier()

    model.train()
    if args.rank == 0:
        os.makedirs(args.run_name, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
    else:
        writer = None
    laion_dataset.set_epoch(total_step)
    laion_loader = laion_dataset.dataloader
    if pile_dataset is not None:
        pile_dataset.set_epoch(total_step)
        pile_loader = pile_dataset.dataloader
    else:
        pile_loader = FakeDataloader()
    train_one_epoch(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        laion_loader=laion_loader,
        pile_loader=pile_loader,
        device_id=device_id,
        writer=writer,
        scaler=scaler,
        optim_groups=optim_groups,
        total_laion_token=total_laion_token,
        total_pile_token=total_pile_token,
        total_laion_sample=total_laion_sample,
        total_step=total_step,
    )
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()