import os
import sys
import wandb
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from zoneinfo import ZoneInfo
from time import gmtime, strftime
from collections import OrderedDict
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision.transforms import CenterCrop
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as torch_transforms
from torchvision.utils import make_grid

from src.losses import (
    ContextualLoss,
    ContextualLoss_forward,
    Perceptual_loss,
    consistent_loss_fn,
    discriminator_loss_fn,
    generator_loss_fn,
    l1_loss_fn,
    smoothness_loss_fn,
)
from src.models.CNN.GAN_models import Discriminator_x64
from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import WeightedAverage_color, NonlocalWeightedAverage, WarpNet, WarpNet_new
from src.models.vit.embed import EmbedModel
from src.models.vit.config import load_config
from src.data import transforms
from src.data.dataloader import VideosDataset, VideosDataset_ImageNet
from src.utils import CenterPad_threshold
from src.utils import (
    TimeHandler,
    RGB2Lab,
    ToTensor,
    Normalize,
    LossHandler,
    WarpingLayer,
    uncenter_l,
    tensor_lab2rgb,
    print_num_params,
)
from src.scheduler import PolynomialLR

parser = argparse.ArgumentParser()
parser.add_argument("--video_data_root_list", type=str, default="dataset")
parser.add_argument("--flow_data_root_list", type=str, default="flow")
parser.add_argument("--mask_data_root_list", type=str, default="mask")
parser.add_argument("--data_root_imagenet", default="imagenet", type=str)
parser.add_argument("--annotation_file_path", default="dataset/annotation.csv", type=str)
parser.add_argument("--imagenet_pairs_file", default="imagenet_pairs.txt", type=str)
parser.add_argument("--gpu_ids", type=str, default="0,1,2,3", help="separate by comma")
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--image_size", type=int, default=[384, 384])
parser.add_argument("--ic", type=int, default=7)
parser.add_argument("--epoch", type=int, default=40)
parser.add_argument("--resume_epoch", type=int, default=0)
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--load_pretrained_model", type=bool, default=False)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--beta1", type=float, default=0.5)
parser.add_argument("--lr_step", type=int, default=1)
parser.add_argument("--lr_gamma", type=float, default=0.9)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--checkpoint_step", type=int, default=500)
parser.add_argument("--real_reference_probability", type=float, default=0.7)
parser.add_argument("--nonzero_placeholder_probability", type=float, default=0.0)
parser.add_argument("--domain_invariant", type=bool, default=False)
parser.add_argument("--weigth_l1", type=float, default=2.0)
parser.add_argument("--weight_contextual", type=float, default=0.5)
parser.add_argument("--weight_perceptual", type=float, default=0.02)
parser.add_argument("--weight_smoothness", type=float, default=5.0)
parser.add_argument("--weight_gan", type=float, default=0.5)
parser.add_argument("--weight_nonlocal_smoothness", type=float, default=0.0)
parser.add_argument("--weight_nonlocal_consistent", type=float, default=0.0)
type=float, default="0.0") parser.add_argument("--weight_consistent", type=float, default="0.05") parser.add_argument("--luminance_noise", type=float, default="2.0") parser.add_argument("--permute_data", type=bool, default=True) parser.add_argument("--contextual_loss_direction", type=str, default="forward", help="forward or backward matching") parser.add_argument("--batch_accum_size", type=int, default=10) parser.add_argument("--epoch_train_discriminator", type=int, default=3) parser.add_argument("--vit_version", type=str, default="vit_tiny_patch16_384") parser.add_argument("--use_dummy", type=bool, default=False) parser.add_argument("--use_wandb", type=bool, default=False) parser.add_argument("--use_feature_transform", type=bool, default=False) parser.add_argument("--head_out_idx", type=str, default="8,9,10,11") parser.add_argument("--wandb_token", type=str, default="") parser.add_argument("--wandb_name", type=str, default="") def load_data(): transforms_video = [ CenterCrop(opt.image_size), RGB2Lab(), ToTensor(), Normalize(), ] train_dataset_videos = [ VideosDataset( video_data_root=video_data_root, flow_data_root=flow_data_root, mask_data_root=mask_data_root, imagenet_folder=opt.data_root_imagenet, annotation_file_path=opt.annotation_file_path, image_size=opt.image_size, image_transform=transforms.Compose(transforms_video), real_reference_probability=opt.real_reference_probability, nonzero_placeholder_probability=opt.nonzero_placeholder_probability, ) for video_data_root, flow_data_root, mask_data_root in zip( opt.video_data_root_list, opt.flow_data_root_list, opt.mask_data_root_list ) ] transforms_imagenet = [CenterPad_threshold(opt.image_size), RGB2Lab(), ToTensor(), Normalize()] extra_reference_transform = [ torch_transforms.RandomHorizontalFlip(0.5), torch_transforms.RandomResizedCrop(480, (0.98, 1.0), ratio=(0.8, 1.2)), ] train_dataset_imagenet = VideosDataset_ImageNet( imagenet_data_root=opt.data_root_imagenet, pairs_file=opt.imagenet_pairs_file, image_size=opt.image_size, transforms_imagenet=transforms_imagenet, distortion_level=4, brightnessjitter=5, nonzero_placeholder_probability=opt.nonzero_placeholder_probability, extra_reference_transform=extra_reference_transform, real_reference_probability=opt.real_reference_probability, ) # video_training_length = sum([len(dataset) for dataset in train_dataset_videos]) # imagenet_training_length = len(train_dataset_imagenet) # dataset_training_length = sum([dataset.real_len for dataset in train_dataset_videos]) + +train_dataset_imagenet.real_len dataset_combined = ConcatDataset(train_dataset_videos + [train_dataset_imagenet]) # sampler=[] # seed_sampler=int.from_bytes(os.urandom(4),"big") # random.seed(seed_sampler) # for idx in range(opt.epoch): # sampler = sampler + random.sample(range(dataset_training_length),dataset_training_length) # wandb.log({"Sampler_Seed":seed_sampler}) # sampler = sampler+WeightedRandomSampler([1] * video_training_length + [1] * imagenet_training_length, dataset_training_length*opt.epoch) # video_training_length = sum([len(dataset) for dataset in train_dataset_videos]) # dataset_training_length = sum([dataset.real_len for dataset in train_dataset_videos]) # dataset_combined = ConcatDataset(train_dataset_videos) # sampler = WeightedRandomSampler([1] * video_training_length, dataset_training_length * opt.epoch) data_loader = DataLoader(dataset_combined, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers) return data_loader def training_logger(): if (total_iter % opt.checkpoint_step == 0) or 
    if (total_iter % opt.checkpoint_step == 0) or (total_iter == len(data_loader)):
        train_loss_dict = {
            "train/" + str(k): v / loss_handler.count_sample for k, v in loss_handler.loss_dict.items()
        }
        train_loss_dict["train/opt_g_lr_1"] = step_optim_scheduler_g.get_last_lr()[0]
        train_loss_dict["train/opt_g_lr_2"] = step_optim_scheduler_g.get_last_lr()[1]
        train_loss_dict["train/opt_d_lr"] = step_optim_scheduler_d.get_last_lr()[0]

        alert_text = (
            f"l1_loss: {l1_loss.item()}\n"
            f"percep_loss: {perceptual_loss.item()}\n"
            f"ctx_loss: {contextual_loss_total.item()}\n"
            f"cst_loss: {consistent_loss.item()}\n"
            f"sm_loss: {smoothness_loss.item()}\n"
            f"total: {total_loss.item()}"
        )

        if opt.use_wandb:
            wandb.log(train_loss_dict)
            wandb.alert(title=f"Progress training #{total_iter}", text=alert_text)
            for idx in range(I_predict_rgb.shape[0]):
                concated_I = make_grid(
                    [(I_predict_rgb[idx] * 255), (I_reference_rgb[idx] * 255), (I_current_rgb[idx] * 255)], nrow=3
                )
                wandb_concated_I = wandb.Image(
                    concated_I,
                    caption="[LEFT] Predict, [CENTER] Reference, [RIGHT] Ground truth\n[REF] {}, [FRAME] {}".format(
                        ref_path[idx], curr_frame_path[idx]
                    ),
                )
                wandb.log({f"example_{idx}": wandb_concated_I})

        torch.save(
            nonlocal_net.state_dict(),
            os.path.join(opt.checkpoint_dir, "nonlocal_net_iter.pth"),
        )
        torch.save(
            colornet.state_dict(),
            os.path.join(opt.checkpoint_dir, "colornet_iter.pth"),
        )
        torch.save(
            discriminator.state_dict(),
            os.path.join(opt.checkpoint_dir, "discriminator_iter.pth"),
        )
        torch.save(embed_net.state_dict(), os.path.join(opt.checkpoint_dir, "embed_net_iter.pth"))

        loss_handler.reset()


def load_params(ckpt_file):
    """Load a checkpoint file and return its parameters as an OrderedDict."""
    params = torch.load(ckpt_file)
    new_params = []
    for key, value in params.items():
        new_params.append((key, value))
    return OrderedDict(new_params)


def parse(parser, save=True):
    """Parse command-line options, print them, and optionally dump them to opt.txt."""
    opt = parser.parse_args()
    args = vars(opt)
    print("------------------------------ Options -------------------------------")
    for k, v in sorted(args.items()):
        print("%s: %s" % (str(k), str(v)))
    print("-------------------------------- End ---------------------------------")

    if save:
        file_name = "opt.txt"
        with open(file_name, "wt") as opt_file:
            opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
            opt_file.write("------------------------------ Options -------------------------------\n")
            for k, v in sorted(args.items()):
                opt_file.write("%s: %s\n" % (str(k), str(v)))
            opt_file.write("-------------------------------- End ---------------------------------\n")
    return opt


def gpu_setup():
    """Select the first GPU from opt.gpu_ids and enable cudnn benchmarking."""
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    cudnn.benchmark = True
    torch.cuda.set_device(opt.gpu_ids[0])
    device = torch.device("cuda")
    print("running on GPU", opt.gpu_ids)
    return device


if __name__ == "__main__":
    ############################################## SETUP ###############################################
    torch.multiprocessing.set_start_method("spawn", force=True)

    # =============== GET PARSER OPTION ================
    opt = parse(parser)
    opt.video_data_root_list = opt.video_data_root_list.split(",")
    opt.flow_data_root_list = opt.flow_data_root_list.split(",")
    opt.mask_data_root_list = opt.mask_data_root_list.split(",")
    opt.gpu_ids = list(map(int, opt.gpu_ids.split(",")))
    opt.head_out_idx = list(map(int, opt.head_out_idx.split(",")))
    n_dim_output = 3 if opt.use_feature_transform else 4
    assert len(opt.head_out_idx) == 4, "Size of head_out_idx must be 4"
    os.makedirs(opt.checkpoint_dir, exist_ok=True)

    # =================== INIT WANDB ===================
    if opt.use_wandb:
        print("Save images to Wandb")
        if opt.wandb_token != "":
            try:
                wandb.login(key=opt.wandb_token)
            except:
                pass
        wandb.init(
            project="video-colorization",
            name=f"{opt.wandb_name} {datetime.now(tz=ZoneInfo('Asia/Ho_Chi_Minh')).strftime('%Y/%m/%d_%H-%M-%S')}",
        )

    # ================== SETUP DEVICE ==================
    # torch.multiprocessing.set_start_method("spawn", force=True)
    # device = gpu_setup()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # =================== VIT CONFIG ===================
    cfg = load_config()
    model_cfg = cfg["model"][opt.vit_version]
    model_cfg["image_size"] = (384, 384)
    model_cfg["backbone"] = opt.vit_version
    model_cfg["dropout"] = 0.0
    model_cfg["drop_path_rate"] = 0.1
    model_cfg["n_cls"] = 10

    ############################################ LOAD DATA #############################################
    if opt.use_dummy:
        H, W = 384, 384
        I_last_lab = torch.rand(opt.batch_size, 3, H, W)
        I_current_lab = torch.rand(opt.batch_size, 3, H, W)
        I_reference_lab = torch.rand(opt.batch_size, 3, H, W)
        flow_forward = torch.rand(opt.batch_size, 2, H, W)
        mask = torch.rand(opt.batch_size, 1, H, W)
        placeholder_lab = torch.rand(opt.batch_size, 3, H, W)
        self_ref_flag = torch.rand(opt.batch_size, 3, H, W)
        data_loader = [
            [
                I_last_lab,
                I_current_lab,
                I_reference_lab,
                flow_forward,
                mask,
                placeholder_lab,
                self_ref_flag,
                None,
                None,
                None,
            ]
            for _ in range(10)
        ]
    else:
        data_loader = load_data()

    ########################################## DEFINE NETWORK ##########################################
    print("-" * 59)
    print("| TYPE | Model name | Num params |")
    print("-" * 59)
    colornet = ColorVidNet(opt.ic).to(device)
    colornet_params = print_num_params(colornet)
    if opt.use_feature_transform:
        nonlocal_net = WarpNet().to(device)
    else:
        nonlocal_net = WarpNet_new(model_cfg["d_model"]).to(device)
    nonlocal_net_params = print_num_params(nonlocal_net)
    discriminator = Discriminator_x64(ndf=64).to(device)
    discriminator_params = print_num_params(discriminator)
    weighted_layer_color = WeightedAverage_color().to(device)
    weighted_layer_color_params = print_num_params(weighted_layer_color)
    nonlocal_weighted_layer = NonlocalWeightedAverage().to(device)
    nonlocal_weighted_layer_params = print_num_params(nonlocal_weighted_layer)
    warping_layer = WarpingLayer(device=device).to(device)
    warping_layer_params = print_num_params(warping_layer)
    embed_net = EmbedModel(model_cfg, head_out_idx=opt.head_out_idx, n_dim_output=n_dim_output, device=device)
    embed_net_params = print_num_params(embed_net)
    print("-" * 59)
    total_params = (
        colornet_params
        + nonlocal_net_params
        + discriminator_params
        + weighted_layer_color_params
        + nonlocal_weighted_layer_params
        + warping_layer_params
        + embed_net_params
    )
    print(f"| TOTAL | | {('{:,}'.format(total_params)).rjust(10)} |")
    print("-" * 59)

    if opt.use_wandb:
        wandb.watch(discriminator, log="all", log_freq=opt.checkpoint_step, idx=0)
        wandb.watch(embed_net, log="all", log_freq=opt.checkpoint_step, idx=1)
        wandb.watch(colornet, log="all", log_freq=opt.checkpoint_step, idx=2)
        wandb.watch(nonlocal_net, log="all", log_freq=opt.checkpoint_step, idx=3)

    # ============= USE PRETRAINED OR NOT ==============
    if opt.load_pretrained_model:
        # pretrained_path = "/workspace/video_colorization/ckpt_folder_ver_1_vit_small_patch16_384"
        nonlocal_net.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "nonlocal_net_iter.pth")))
        colornet.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "colornet_iter.pth")))
        discriminator.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "discriminator_iter.pth")))
        embed_net_params = load_params(os.path.join(opt.checkpoint_dir, "embed_net_iter.pth"))
        embed_net_params.pop("vit.heads_out")
        embed_net.load_state_dict(embed_net_params)

    ###################################### DEFINE LOSS FUNCTIONS #######################################
    perceptual_loss_fn = Perceptual_loss(opt.domain_invariant, opt.weight_perceptual)
    contextual_loss = ContextualLoss().to(device)
    contextual_forward_loss = ContextualLoss_forward().to(device)

    ######################################## DEFINE OPTIMIZERS #########################################
    optimizer_g = optim.AdamW(
        [
            {"params": nonlocal_net.parameters(), "lr": opt.lr},
            {"params": colornet.parameters(), "lr": 2 * opt.lr},
            {"params": embed_net.parameters(), "lr": opt.lr},
        ],
        betas=(0.5, 0.999),
        eps=1e-5,
        amsgrad=True,
    )
    optimizer_d = optim.AdamW(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        lr=opt.lr,
        betas=(0.5, 0.999),
        amsgrad=True,
    )
    step_optim_scheduler_g = PolynomialLR(
        optimizer_g,
        step_size=opt.lr_step,
        iter_warmup=0,
        iter_max=len(data_loader) * opt.epoch,
        power=0.9,
        min_lr=1e-8,
    )
    step_optim_scheduler_d = PolynomialLR(
        optimizer_d,
        step_size=opt.lr_step,
        iter_warmup=0,
        iter_max=len(data_loader) * opt.epoch,
        power=0.9,
        min_lr=1e-8,
    )

    ########################################## DEFINE OTHERS ###########################################
    downsampling_by2 = nn.AvgPool2d(kernel_size=2).to(device)
    timer_handler = TimeHandler()
    loss_handler = LossHandler()  # Handle loss values

    ############################################## TRAIN ###############################################
    total_iter = 0
    for epoch_num in range(1, opt.epoch + 1):
        # if opt.use_wandb:
        #     wandb.log({"Current_trainning_epoch": epoch_num})
        with tqdm(total=len(data_loader), position=0, leave=True) as pbar:
            for iter, sample in enumerate(data_loader):
                timer_handler.compute_time("load_sample")
                total_iter += 1

                # =============== LOAD DATA SAMPLE ================
                (
                    I_last_lab,       # (3, H, W)
                    I_current_lab,    # (3, H, W)
                    I_reference_lab,  # (3, H, W)
                    flow_forward,     # (2, H, W)
                    mask,             # (1, H, W)
                    placeholder_lab,  # (3, H, W)
                    self_ref_flag,    # (3, H, W)
                    prev_frame_path,
                    curr_frame_path,
                    ref_path,
                ) = sample
                I_last_lab = I_last_lab.to(device)
                I_current_lab = I_current_lab.to(device)
                I_reference_lab = I_reference_lab.to(device)
                flow_forward = flow_forward.to(device)
                mask = mask.to(device)
                placeholder_lab = placeholder_lab.to(device)
                self_ref_flag = self_ref_flag.to(device)

                I_last_l = I_last_lab[:, 0:1, :, :]
                I_last_ab = I_last_lab[:, 1:3, :, :]
                I_current_l = I_current_lab[:, 0:1, :, :]
                I_current_ab = I_current_lab[:, 1:3, :, :]
                I_reference_l = I_reference_lab[:, 0:1, :, :]
                I_reference_ab = I_reference_lab[:, 1:3, :, :]
                I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1))

                _load_sample_time = timer_handler.compute_time("load_sample")
                timer_handler.compute_time("forward_model")

                features_B = embed_net(I_reference_rgb)
                _, B_feat_1, B_feat_2, B_feat_3 = features_B

                # ================== COLORIZATION ==================
                # The last frame
                I_last_ab_predict, I_last_nonlocal_lab_predict = frame_colorization(
                    IA_l=I_last_l,
                    IB_lab=I_reference_lab,
                    IA_last_lab=placeholder_lab,
                    features_B=features_B,
                    embed_net=embed_net,
                    colornet=colornet,
                    nonlocal_net=nonlocal_net,
                    luminance_noise=opt.luminance_noise,
                )
                I_last_lab_predict = torch.cat((I_last_l, I_last_ab_predict), dim=1)
                # The current frame, conditioned on the previous prediction
                I_current_ab_predict, I_current_nonlocal_lab_predict = frame_colorization(
                    IA_l=I_current_l,
                    IB_lab=I_reference_lab,
                    IA_last_lab=I_last_lab_predict,
                    features_B=features_B,
                    embed_net=embed_net,
                    colornet=colornet,
                    nonlocal_net=nonlocal_net,
                    luminance_noise=opt.luminance_noise,
                )
                # Full predicted Lab image for the current frame: current luminance + predicted chrominance
                I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)

                # ================ UPDATE GENERATOR ================
                if opt.weight_gan > 0:
                    optimizer_g.zero_grad()
                    optimizer_d.zero_grad()
                    fake_data_lab = torch.cat(
                        (
                            uncenter_l(I_current_l),
                            I_current_ab_predict,
                            uncenter_l(I_last_l),
                            I_last_ab_predict,
                        ),
                        dim=1,
                    )
                    real_data_lab = torch.cat(
                        (
                            uncenter_l(I_current_l),
                            I_current_ab,
                            uncenter_l(I_last_l),
                            I_last_ab,
                        ),
                        dim=1,
                    )
                    if opt.permute_data:
                        batch_index = torch.arange(-1, opt.batch_size - 1, dtype=torch.long)
                        real_data_lab = real_data_lab[batch_index, ...]

                    discriminator_loss = discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator)
                    discriminator_loss.backward()
                    optimizer_d.step()

                optimizer_g.zero_grad()
                optimizer_d.zero_grad()

                # ================== COMPUTE LOSS ==================
                # L1 loss
                l1_loss = l1_loss_fn(I_current_ab, I_current_ab_predict) * opt.weigth_l1

                # Generator loss. TODO: freeze this to train the discriminator alone for the first epochs
                if epoch_num > opt.epoch_train_discriminator:
                    generator_loss = generator_loss_fn(real_data_lab, fake_data_lab, discriminator, opt.weight_gan, device)

                # Perceptual loss
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
                _, pred_feat_1, pred_feat_2, pred_feat_3 = embed_net(I_predict_rgb)
                I_current_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab), dim=1))
                A_feat_0, _, _, A_feat_3 = embed_net(I_current_rgb)
                perceptual_loss = perceptual_loss_fn(A_feat_3, pred_feat_3)

                # Contextual loss
                contextual_style5_1 = torch.mean(contextual_forward_loss(pred_feat_3, B_feat_3.detach())) * 8
                contextual_style4_1 = torch.mean(contextual_forward_loss(pred_feat_2, B_feat_2.detach())) * 4
                contextual_style3_1 = torch.mean(contextual_forward_loss(pred_feat_1, B_feat_1.detach())) * 2
                # if opt.use_feature_transform:
                #     contextual_style3_1 = (
                #         torch.mean(
                #             contextual_forward_loss(
                #                 downsampling_by2(pred_feat_1),
                #                 downsampling_by2(),
                #             )
                #         )
                #         * 2
                #     )
                # else:
                #     contextual_style3_1 = (
                #         torch.mean(
                #             contextual_forward_loss(
                #                 pred_feat_1,
                #                 B_feat_1.detach(),
                #             )
                #         )
                #         * 2
                #     )
                contextual_loss_total = (
                    contextual_style5_1 + contextual_style4_1 + contextual_style3_1
                ) * opt.weight_contextual

                # Consistency loss
                consistent_loss = consistent_loss_fn(
                    I_current_lab_predict,
                    I_last_ab_predict,
                    I_current_nonlocal_lab_predict,
                    I_last_nonlocal_lab_predict,
                    flow_forward,
                    mask,
                    warping_layer,
                    weight_consistent=opt.weight_consistent,
                    weight_nonlocal_consistent=opt.weight_nonlocal_consistent,
                    device=device,
                )

                # Smoothness loss
                smoothness_loss = smoothness_loss_fn(
                    I_current_l,
                    I_current_lab,
                    I_current_ab_predict,
                    A_feat_0,
                    weighted_layer_color,
                    nonlocal_weighted_layer,
                    weight_smoothness=opt.weight_smoothness,
                    weight_nonlocal_smoothness=opt.weight_nonlocal_smoothness,
                    device=device,
                )

                # Total loss
                total_loss = l1_loss + perceptual_loss + contextual_loss_total + consistent_loss + smoothness_loss
                if epoch_num > opt.epoch_train_discriminator:
                    total_loss += generator_loss

                # Add losses to the loss handler
                loss_handler.add_loss(key="total_loss", loss=total_loss.item())
                loss_handler.add_loss(key="l1_loss", loss=l1_loss.item())
                loss_handler.add_loss(key="perceptual_loss", loss=perceptual_loss.item())
                loss_handler.add_loss(key="contextual_loss", loss=contextual_loss_total.item())
                loss_handler.add_loss(key="consistent_loss", loss=consistent_loss.item())
                loss_handler.add_loss(key="smoothness_loss", loss=smoothness_loss.item())
                loss_handler.add_loss(key="discriminator_loss", loss=discriminator_loss.item())
                if epoch_num > opt.epoch_train_discriminator:
                    loss_handler.add_loss(key="generator_loss", loss=generator_loss.item())
                loss_handler.count_one_sample()

                total_loss.backward()
                optimizer_g.step()
                step_optim_scheduler_g.step()
                step_optim_scheduler_d.step()

                _forward_model_time = timer_handler.compute_time("forward_model")
                timer_handler.compute_time("training_logger")
                training_logger()
                _training_logger_time = timer_handler.compute_time("training_logger")

                pbar.set_description(
                    f"Epochs: {epoch_num}, Load_sample: {_load_sample_time:.3f}s, "
                    f"Forward: {_forward_model_time:.3f}s, log: {_training_logger_time:.3f}s"
                )
                pbar.update(1)
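
# Example invocation (a hedged sketch: the script filename and the dataset paths below are
# illustrative assumptions, not part of the repository; adjust them to your own setup):
#
#   python train.py \
#       --video_data_root_list dataset/clips_a,dataset/clips_b \
#       --flow_data_root_list flow/clips_a,flow/clips_b \
#       --mask_data_root_list mask/clips_a,mask/clips_b \
#       --data_root_imagenet imagenet \
#       --annotation_file_path dataset/annotation.csv \
#       --imagenet_pairs_file imagenet_pairs.txt \
#       --gpu_ids 0 --batch_size 2 --epoch 40 \
#       --vit_version vit_tiny_patch16_384 --head_out_idx 8,9,10,11
#
# The comma-separated *_root_list flags must contain the same number of entries, since they
# are zipped together when the VideosDataset instances are built in load_data().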