import os from starvector.util import ( set_env_vars, flatten_dict, get_exp_id, instantiate_from_config, generate_id_name_eval, get_last_checkpoint, model_summary_table, copy_code, ) # set_env_vars() from starvector.train.util import ( save_checkpoint, get_optimizer, init_distributed_mode, setup_train_env_variables, load_fsdp_plugin, apply_gradient_checkpointing, ) import logging import math from torch.utils.data import DataLoader from transformers import get_scheduler from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration from tqdm.auto import tqdm from omegaconf import OmegaConf import os import time from starvector.metrics.util import AverageMeter from util import save_checkpoint, get_optimizer from starvector.util import get_output_dir from starvector.model.builder import model_builder from safetensors.torch import load_file as load_safetensors from starvector.util import get_config import torch from starvector.train.util import load_checkpoint, is_deepspeed, consolidate_deepspeed_checkpoint logger = get_logger(__name__, log_level="INFO") def validate(model, dataloader, accelerator): loss_meter = AverageMeter() model.eval() pbar = tqdm(total=len(dataloader), ncols=100, desc="Processing", disable=not accelerator.is_local_main_process) with torch.no_grad(): for i, batch in enumerate(dataloader): batch_size = len(batch["image"]) loss = model(batch) loss_meter.update(loss.detach().item(), batch_size) pbar.update(1) val_loss = ( accelerator.gather(torch.tensor(loss_meter.avg).to(accelerator.device)) .float() .mean() .item() ) accelerator.wait_for_everyone() pbar.close() return val_loss def main(config=None): print(f"Experiment config: {config}") set_env_vars() exp_id = get_exp_id(config) output_dir = get_output_dir() logging_dir = os.path.join(output_dir, config.data.train.params.dataset_name, exp_id) if os.path.exists(logging_dir) and not config.training.resume_from_checkpoint: config.training.resume_from_checkpoint = get_last_checkpoint(logging_dir) config.training.continue_training = True # Flatten config dict for logging it log_config = flatten_dict(OmegaConf.to_container(config, resolve=True)) log_config['logging_dir'] = logging_dir # Add logging dir to config if config.fsdp.enable: init_distributed_mode(config) setup_train_env_variables(config) # --------------- Datasets --------------- train_dataset = instantiate_from_config(config.data.train) test_dataset = instantiate_from_config(config.data.test) train_dataloader = DataLoader(train_dataset, batch_size=config.data.train.batch_size, shuffle=True, num_workers=config.data.num_workers, pin_memory=True) test_dataloader = DataLoader(test_dataset, batch_size=config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers, pin_memory=True) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) max_train_steps = config.training.n_epochs * num_update_steps_per_epoch global_step = 0 first_epoch = 0 model = model_builder(config) # Instantiate the model, fsdp and accelerator if config.training.resume_from_checkpoint: if not config.fsdp.enable: if is_deepspeed(config.training.resume_from_checkpoint): if accelerator.is_main_process: consolidate_deepspeed_checkpoint(config.training.resume_from_checkpoint) accelerator.wait_for_everyone() model = load_checkpoint(model, config.training.resume_from_checkpoint) else: model.load_state_dict(torch.load(os.path.join(config.training.resume_from_checkpoint, "pytorch_model_fsdp.bin")), strict=False) if config.training.continue_training: global_step = int(os.path.basename(config.training.resume_from_checkpoint).split("-")[1]) resume_global_step = global_step * config.training.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch resume_step = resume_global_step % (num_update_steps_per_epoch * config.training.gradient_accumulation_steps) else: global_step = 0 first_epoch = 0 resume_step = 0 print("Loaded checkpoint but not updating global step") if config.fsdp.enable: fsdp_plugin = load_fsdp_plugin(config, model) else: fsdp_plugin = None # Define accelerator kwargs_handler = None accelerator = Accelerator( gradient_accumulation_steps=config.training.gradient_accumulation_steps, mixed_precision=config.training.model_precision, log_with="wandb" if config.project.use_wandb else None, project_dir=logging_dir, project_config=ProjectConfiguration(logging_dir=logging_dir), step_scheduler_with_optimizer=False, fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handler ) # --------------- Logging --------------- if accelerator.is_main_process: if config.project.use_wandb: import wandb wandb.init(name=exp_id, project=config.project.project, entity=config.project.entity, config=log_config) accelerator.init_trackers( project_name=config.project.project, ) config.project.wandb_run_id = wandb.run.id else: run = os.path.split(__file__)[-1].split(".")[0] accelerator.init_trackers(run) if logging_dir is not None: os.makedirs(logging_dir, exist_ok=True) # Copy code and dependency versions if config.project.copy_code: out_dir = os.path.join(logging_dir, "code") copy_code(os.path.join(os.path.dirname(__file__), "..", ".."), out_dir) logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=True) total_batch_size = config.data.train.batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps if accelerator.is_main_process and config.project.use_wandb: wandb.log({"total_batch_size": total_batch_size}) wandb.log({"num_update_steps_per_epoch": num_update_steps_per_epoch}) wandb.log({"max_train_steps": max_train_steps}) # accelerate prepare model model = accelerator.prepare(model) # activation/gradient checkpointing if config.training.use_gradient_checkpointing: print("apply gradient checkpointing") model = apply_gradient_checkpointing(model) optimizer = get_optimizer(config, model) if accelerator.is_main_process: print("Train dataset length: ", len(train_dataset)) print("Test dataset length: ", len(test_dataset)) # --------------- Training config --------------- lr_scheduler = get_scheduler( config.training.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.training.lr_warmup_steps * config.training.gradient_accumulation_steps, num_training_steps= (len(train_dataloader) * config.training.n_epochs), ) optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare( optimizer, train_dataloader, test_dataloader, lr_scheduler ) loss_meter = AverageMeter() if accelerator.is_main_process: model_summary_table(model) if not os.path.exists(os.path.join(logging_dir, 'config.yaml')): with open(os.path.join(logging_dir, 'config.yaml'), 'w') as f: OmegaConf.save(config, f) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {config.training.n_epochs}") logger.info(f" Instantaneous batch size per device = {config.data.train.batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_train_steps}") # --------------- Generation/Validation arguments --------------- generation_args = config.generation # Need to set some experiment specific arguments generation_args.project_name = config.project.project generation_args.use_wandb = config.project.use_wandb generation_args.id = generate_id_name_eval(generation_args) generation_args.out_path = os.path.join(logging_dir, generation_args.id) generation_args.start_generation_at_step = config.generation.start_generation_at_step generation_args.metrics = config.metrics os.makedirs(generation_args.out_path, exist_ok=True) # --------------- Training loop --------------- total_steps = num_update_steps_per_epoch * config.training.n_epochs progress_bar = tqdm(total=total_steps, disable=not accelerator.is_local_main_process) progress_bar.set_description(f"Training Progress") for epoch in range(config.training.n_epochs): model.train() for step, batch in enumerate(train_dataloader): s_time = time.time() if config.training.resume_from_checkpoint and epoch == first_epoch and step < resume_step: if step % config.training.gradient_accumulation_steps == 0: progress_bar.update(1) continue with accelerator.accumulate(model): loss = model(batch) accelerator.backward(loss) loss_meter.update(loss.detach().item(), batch['image'].shape[0]) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() lr_scheduler.step() optimizer.zero_grad() if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if global_step % config.training.checkpointing_steps == 0: accelerator.wait_for_everyone() val_loss = validate(model, test_dataloader, accelerator) accelerator.log({"val_loss": val_loss}, step=global_step) save_checkpoint(accelerator, model, global_step, logging_dir, config.training.checkpoints_total_limit) model.train() logs = { "loss": loss_meter.val, "last_lr": lr_scheduler.get_last_lr()[0], "step": global_step, "step_time": time.time() - s_time, "epoch": epoch} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) accelerator.end_training() if __name__ == "__main__": main(config=get_config())