import datetime
import gc
import logging
import os
import time
from contextlib import suppress

import numpy as np
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(message)s',
datefmt='%m/%d %I:%M:%S',
)
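# Map a precision flag to the dtype inputs are cast to when moved to the GPU;
# None (fp32 or the amp_* modes) leaves tensors in their original dtype.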
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == "bf16":
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
return cast_dtype
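# Return a context-manager factory for the forward pass: CUDA autocast for the
# amp_* precisions, otherwise contextlib.suppress, which acts as a no-op context.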
def get_autocast(precision):
if precision == "amp_fp16":
return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp_fp16 for CLIP training
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
else:
return suppress
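# Skip gradient synchronization (all-reduce) on accumulation-only steps via
# model.no_sync(); on update steps return a no-op context. Unused in this file
# (see the commented-out do_sync call in train_one_epoch).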
def get_sync(model, flag):
if flag:
return suppress
else:
return lambda: model.no_sync()
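# One training epoch: every step pairs a grounded-caption batch (LAION) with a
# text-only batch (the Pile); every pile_freq steps the pile batch is appended
# to the LAION batch so both share a single forward/backward pass.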
def train_one_epoch(
args,
model,
laion_loader,
pile_loader,
tokenizer,
optimizer,
lr_scheduler,
device_id,
writer: SummaryWriter,
optim_groups,
scaler,
total_laion_token: int,
total_pile_token: int,
total_laion_sample: int,
total_step: int,
):
world_size = torch.distributed.get_world_size()
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
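    # resolve ids of the special grounding tokens; add_special_tokens=False avoids
    # a prepended BOS, and [-1] guards against tokenizers that still split the
    # literal into several pieces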
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
if args.add_box:
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
if args.use_format_v2:
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
if args.rank == 0:
logging.info(f"train from: {total_step} step")
model.train()
    # loop through the paired LAION / Pile dataloaders
last_logging_step = total_step
last_save_step = total_step
for num_steps, (batch_laion, batch_pile) in tqdm(
enumerate(zip(laion_loader, pile_loader)),
disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
total=args.num_steps * args.gradient_accumulation_steps,
initial=total_step * args.gradient_accumulation_steps,
):
#### LAION FORWARD PASS ####
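        # unsqueeze twice to add singleton dims, presumably the (B, T_img, F, C, H, W)
        # layout a Flamingo-style vision encoder expects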
images = (
batch_laion[0]
.to(device_id, dtype=cast_dtype, non_blocking=True)
.unsqueeze(1)
.unsqueeze(1)
)
image_nums = batch_laion[1]
image_start_index_list = batch_laion[2]
        # TODO: for the OPT model, input_ids does not start with </s> while input_ids2 does?
input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
        added_bbox_list = [x.to(device_id) for x in batch_laion[5]]  # list of per-sample bbox tensors
total_laion_token += int(attention_mask.sum().long()) * world_size
total_laion_sample += sum(image_nums) * world_size
labels = input_ids.clone()
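        # build labels by masking positions with -100 (the ignore_index of
        # CrossEntropyLoss): special grounding tokens, padding, media delimiters,
        # and the first position contribute no loss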
if args.add_box:
labels[input_ids == visual_token_id] = -100
labels[input_ids == box_token_id] = -100
labels[input_ids == endofattr_token_id] = -100
if args.use_format_v2:
labels[input_ids == previsual_token_id] = -100
labels[input_ids == prebox_token_id] = -100
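                # the rolled masks also drop the token immediately after each
                # <|#prebox#|>/<|#box#|>; torch.roll without dims rolls the flattened
                # tensor, so only column 0 can leak across samples, and that column
                # is re-masked by labels[:, 0] = -100 below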
labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
labels[torch.roll(input_ids == box_token_id, 1)] = -100
labels[:, 0] = -100
labels[input_ids == tokenizer.pad_token_id] = -100
labels[input_ids == media_token_id] = -100
labels[input_ids == endofmedia_token_id] = -100
        labels = labels.to(device_id)
current_laion_num = input_ids.shape[0]
#### PILE FORWARD PASS ####
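        # text-only pile batches preserve the language model's pure-text ability;
        # they are merged into the LAION batch every pile_freq steps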
if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
            input_length = input_ids.shape[-1]
            pad_length = input_length - input_ids2.shape[1]
            # right-pad the pile batch to the LAION sequence length so the two batches can be concatenated below
            input_ids2 = torch.cat([input_ids2, torch.full((input_ids2.shape[0], pad_length), tokenizer.pad_token_id, device=input_ids2.device, dtype=input_ids2.dtype)], dim=-1)
            attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], pad_length), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
labels2 = input_ids2.clone()
labels2[labels2 == tokenizer.pad_token_id] = -100
labels2[:, 0] = -100
            labels2 = labels2.to(device_id)
if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
image_nums = image_nums + [0] * len(input_ids2)
image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
input_ids = torch.cat([input_ids, input_ids2], dim=0)
attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
labels = torch.cat([labels, labels2], dim=0)
total_pile_token += int(attention_mask2.sum().long()) * world_size
else:
del input_ids2
del attention_mask2
del labels2
        if args.instruct:
            # supervise only the answers: mask everything up to and including the
            # " Answer" token plus the token after it; add_special_tokens=False keeps
            # a possible BOS from being matched as the answer token
            answer_token_id = tokenizer(" Answer", add_special_tokens=False)["input_ids"][0]
            answer_token_loc = (input_ids == answer_token_id).nonzero()
            for batch_idx, idx in answer_token_loc:
                labels[batch_idx][:idx + 2] = -100
if args.relation and not args.instruct:
relations = batch_laion[6]
else:
relations = None
if len(added_bbox_list) == 0:
added_bbox_list = None
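        # step the optimizer only on the last micro-batch of each
        # gradient-accumulation window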
update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
# do_sync = get_sync(model, update_flag)
with autocast():
            # NOTE: this relies on patching the transformers modeling files below so
            # the model returns a per-token loss, i.e. CrossEntropyLoss(reduction="none"):
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
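            # A minimal sketch of that patch (variable names follow the HF
            # implementations; illustrative, not verbatim):
            #   loss_fct = CrossEntropyLoss(reduction="none")
            #   loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
            # so that outputs.loss carries one value per (shifted) token position.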
outputs = model(
vision_x=images,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels,
image_nums=image_nums,
image_start_index_list=image_start_index_list,
added_bbox_list=added_bbox_list,
add_box=args.add_box,
relations=relations,
)
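            # with reduction="none", outputs.loss is flat per-token; reshape to
            # (batch, seq) and average each sample over its non-zero (unmasked) entries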
loss_total = outputs.loss.reshape(labels.shape[0], -1)
loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
loss_sample_for_laion = loss_sample[:current_laion_num]
nan_mask = torch.isnan(loss_sample_for_laion)
if nan_mask.sum() > 0:
logging.warning(f"caption NaN: {nan_mask}")
            if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
                logging.warning("skipping caption loss for this batch: every sample is NaN or the model flagged the batch as invalid")
loss_laion = torch.tensor(0.0).cuda()
else:
loss_laion = loss_sample_for_laion[~nan_mask].mean()
loss_caption = loss_laion
divided_loss_laion = loss_laion / args.gradient_accumulation_steps
if current_laion_num != loss_sample.shape[0]:
loss_pile = loss_sample[current_laion_num:].mean()
else:
loss_pile = torch.tensor(0.0).cuda()
divided_loss_pile = loss_pile / args.gradient_accumulation_steps
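            # auxiliary grounding losses (iou/obj/cls terms from the detection head
            # and an optional relation loss), defaulting to zero when absent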
if "detection_losses" in outputs:
loss_det = outputs["detection_losses"]["loss"]
loss_iou = outputs["detection_losses"]["loss_iou"]
loss_obj = outputs["detection_losses"]["loss_obj"]
loss_cls = outputs["detection_losses"]["loss_cls"]
else:
loss_det = torch.tensor(0.0).cuda()
loss_iou = torch.tensor(0.0).cuda()
loss_obj = torch.tensor(0.0).cuda()
loss_cls = torch.tensor(0.0).cuda()
if "loss_dict" in outputs:
visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
else:
visual_loss_iou = torch.tensor(0.0).cuda()
previsual_loss_iou = torch.tensor(0.0).cuda()
visual_loss_obj = torch.tensor(0.0).cuda()
previsual_loss_obj = torch.tensor(0.0).cuda()
divided_loss_det = loss_det / args.gradient_accumulation_steps
loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
divided_loss_rel = loss_rel / args.gradient_accumulation_steps
loss = (
divided_loss_laion * args.loss_multiplier_laion +
divided_loss_pile * args.loss_multiplier_pile +
divided_loss_det * args.loss_multiplier_det +
divided_loss_rel * args.loss_multiplier_rel
)
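        # backward on the scaled micro-batch loss; gradients accumulate until
        # update_flag triggers the optimizer step below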
scaler.scale(loss).backward()
# for logging only
loss = (
loss_laion * args.loss_multiplier_laion
+ loss_pile * args.loss_multiplier_pile
+ loss_det * args.loss_multiplier_det
+ loss_rel * args.loss_multiplier_rel
).detach()
# step optimizer and log
if update_flag:
#### MASK GRADIENTS FOR EMBEDDINGS ####
            # Note (anas): do not apply weight decay to embeddings, as it would break
            # the masking function below (kept for reference; currently disabled).
# if args.ddp:
# def mask_embedding(m):
# if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
# zero_mask = torch.zeros_like(m.weight.grad)
# zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
# zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
# m.weight.grad = m.weight.grad * zero_mask
# model.apply(mask_embedding)
total_step += 1
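            # unscale first so gradient clipping operates on true (unscaled) gradients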
scaler.unscale_(optimizer)
if args.ddp:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
else:
model.clip_grad_norm_(1.0)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
optimizer.zero_grad()
# https://github.com/facebookresearch/fairscale/issues/627
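            # FSDP/fairscale may keep flattened parameters the optimizer does not see,
            # so also zero gradients on the module itself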
model.zero_grad(set_to_none=True)
if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
last_logging_step = total_step
global_step = total_step
lr = optimizer.param_groups[0]["lr"]
writer.add_scalar("lr", lr, global_step)
writer.add_scalar("scale", scaler.get_scale(), global_step)
writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
writer.add_scalar("loss_laion", loss_caption.item(), global_step)
writer.add_scalar("loss_pile", loss_pile.item(), global_step)
writer.add_scalar("loss", loss.item(), global_step)
writer.add_scalar("loss_det", loss_det.item(), global_step)
writer.add_scalar("loss_iou", loss_iou.item(), global_step)
writer.add_scalar("loss_obj", loss_obj.item(), global_step)
writer.add_scalar("loss_cls", loss_cls.item(), global_step)
if loss_rel.item() != 0:
writer.add_scalar("loss_rel", loss_rel.item(), global_step)
if args.use_format_v2:
writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)
global_sample_num = total_laion_sample
writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
if loss_rel.item() != 0:
writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
if loss_rel.item() != 0:
writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)
total_token = total_laion_token + total_pile_token
writer.add_scalar("sample_num", global_sample_num, global_step)
writer.add_scalar("total_laion_token", total_laion_token, global_step)
writer.add_scalar("total_pile_token", total_pile_token, global_step)
writer.add_scalar("total_token", total_token, global_step)
logging.info(
f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
)
if total_step % args.save_interval == 0 and total_step != last_save_step:
last_save_step = total_step
torch.distributed.barrier()
if args.ddp:
cpu_state = model.state_dict()
# if args.rank == 0:
# optimizer_state = optimizer.state_dict()
else:
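                # consolidate the sharded FSDP weights into a full state dict,
                # offloaded to CPU and materialized on rank 0 only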
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
torch.distributed.barrier()
# https://pytorch.org/docs/1.12/fsdp.html
# need to pass optim_groups as optim_input
# optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
if args.rank == 0:
checkpoint_dict = {
"model_state_dict": cpu_state,
# "optimizer_state_dict": optimizer_state,
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"total_pile_token": total_pile_token,
"total_laion_token": total_laion_token,
"total_laion_sample": total_laion_sample,
"total_step": total_step,
}
logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
del checkpoint_dict
                if (
                    args.delete_previous_checkpoint
                    and total_step - args.save_interval > 0
                    and (total_step - args.save_interval) % args.skip_delete_pattern != 0
                ):
                    try:
                        os.remove(f"{args.run_name}/checkpoint_{total_step - args.save_interval}.pt")
                    except OSError:
                        pass
torch.distributed.barrier()
class AverageMeter:
    """Computes and stores the current value and running average."""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
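# Illustrative usage (not wired into the training loop above):
#   meter = AverageMeter()
#   meter.update(loss.item(), n=batch_size)
#   print(meter.avg)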