# SwinTExCo/train_swin_224.py
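# Training script for exemplar-based video colorization with a Swin-based feature encoder.
# It jointly trains the embedding network, the non-local warping network, the colorization
# network and an image discriminator on video clips (with optical flow and occlusion masks)
# mixed with ImageNet reference pairs.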
import os
import sys
import wandb
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from zoneinfo import ZoneInfo
from time import gmtime, strftime
from collections import OrderedDict
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision.transforms import CenterCrop
from torch.utils.data import ConcatDataset, DataLoader
import torchvision.transforms as torch_transforms
from torchvision.utils import make_grid
from src.losses import (
ContextualLoss,
ContextualLoss_forward,
Perceptual_loss,
consistent_loss_fn,
discriminator_loss_fn,
generator_loss_fn,
l1_loss_fn,
smoothness_loss_fn,
)
from src.models.CNN.GAN_models import Discriminator_x64_224
from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import WeightedAverage_color, NonlocalWeightedAverage, GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.data import transforms
from src.data.dataloader import VideosDataset, VideosDataset_ImageNet
from src.utils import CenterPad_threshold
from src.utils import (
TimeHandler,
RGB2Lab,
ToTensor,
Normalize,
LossHandler,
WarpingLayer,
uncenter_l,
tensor_lab2rgb,
print_num_params,
)
from src.scheduler import PolynomialLR
parser = argparse.ArgumentParser()
parser.add_argument("--video_data_root_list", type=str, default="dataset")
parser.add_argument("--flow_data_root_list", type=str, default="flow")
parser.add_argument("--mask_data_root_list", type=str, default="mask")
parser.add_argument("--data_root_imagenet", default="imagenet", type=str)
parser.add_argument("--annotation_file_path", default="dataset/annotation.csv", type=str)
parser.add_argument("--imagenet_pairs_file", default="imagenet_pairs.txt", type=str)
parser.add_argument("--gpu_ids", type=str, default="0,1,2,3", help="separate by comma")
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--image_size", type=int, default=[384, 384])
parser.add_argument("--ic", type=int, default=7)
parser.add_argument("--epoch", type=int, default=40)
parser.add_argument("--resume_epoch", type=int, default=0)
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--load_pretrained_model", type=bool, default=False)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--beta1", type=float, default=0.5)
parser.add_argument("--lr_step", type=int, default=1)
parser.add_argument("--lr_gamma", type=float, default=0.9)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--checkpoint_step", type=int, default=500)
parser.add_argument("--real_reference_probability", type=float, default=0.7)
parser.add_argument("--nonzero_placeholder_probability", type=float, default=0.0)
parser.add_argument("--domain_invariant", type=bool, default=False)
parser.add_argument("--weigth_l1", type=float, default=2.0)
parser.add_argument("--weight_contextual", type=float, default="0.5")
parser.add_argument("--weight_perceptual", type=float, default="0.02")
parser.add_argument("--weight_smoothness", type=float, default="5.0")
parser.add_argument("--weight_gan", type=float, default="0.5")
parser.add_argument("--weight_nonlocal_smoothness", type=float, default="0.0")
parser.add_argument("--weight_nonlocal_consistent", type=float, default="0.0")
parser.add_argument("--weight_consistent", type=float, default="0.05")
parser.add_argument("--luminance_noise", type=float, default="2.0")
parser.add_argument("--permute_data", type=bool, default=True)
parser.add_argument("--contextual_loss_direction", type=str, default="forward", help="forward or backward matching")
parser.add_argument("--batch_accum_size", type=int, default=10)
parser.add_argument("--epoch_train_discriminator", type=int, default=3)
parser.add_argument("--vit_version", type=str, default="vit_tiny_patch16_384")
parser.add_argument("--use_dummy", type=bool, default=False)
parser.add_argument("--use_wandb", type=bool, default=False)
parser.add_argument("--use_feature_transform", type=bool, default=False)
parser.add_argument("--head_out_idx", type=str, default="8,9,10,11")
parser.add_argument("--wandb_token", type=str, default="")
parser.add_argument("--wandb_name", type=str, default="")
def load_data():
transforms_video = [
CenterCrop(opt.image_size),
RGB2Lab(),
ToTensor(),
Normalize(),
]
train_dataset_videos = [
VideosDataset(
video_data_root=video_data_root,
flow_data_root=flow_data_root,
mask_data_root=mask_data_root,
imagenet_folder=opt.data_root_imagenet,
annotation_file_path=opt.annotation_file_path,
image_size=opt.image_size,
image_transform=transforms.Compose(transforms_video),
real_reference_probability=opt.real_reference_probability,
nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
)
for video_data_root, flow_data_root, mask_data_root in zip(
opt.video_data_root_list, opt.flow_data_root_list, opt.mask_data_root_list
)
]
transforms_imagenet = [CenterPad_threshold(opt.image_size), RGB2Lab(), ToTensor(), Normalize()]
extra_reference_transform = [
torch_transforms.RandomHorizontalFlip(0.5),
torch_transforms.RandomResizedCrop(480, (0.98, 1.0), ratio=(0.8, 1.2)),
]
train_dataset_imagenet = VideosDataset_ImageNet(
imagenet_data_root=opt.data_root_imagenet,
pairs_file=opt.imagenet_pairs_file,
image_size=opt.image_size,
transforms_imagenet=transforms_imagenet,
distortion_level=4,
brightnessjitter=5,
nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
extra_reference_transform=extra_reference_transform,
real_reference_probability=opt.real_reference_probability,
)
# video_training_length = sum([len(dataset) for dataset in train_dataset_videos])
# imagenet_training_length = len(train_dataset_imagenet)
# dataset_training_length = sum([dataset.real_len for dataset in train_dataset_videos]) + +train_dataset_imagenet.real_len
dataset_combined = ConcatDataset(train_dataset_videos + [train_dataset_imagenet])
# sampler=[]
# seed_sampler=int.from_bytes(os.urandom(4),"big")
# random.seed(seed_sampler)
# for idx in range(opt.epoch):
# sampler = sampler + random.sample(range(dataset_training_length),dataset_training_length)
# wandb.log({"Sampler_Seed":seed_sampler})
# sampler = sampler+WeightedRandomSampler([1] * video_training_length + [1] * imagenet_training_length, dataset_training_length*opt.epoch)
# video_training_length = sum([len(dataset) for dataset in train_dataset_videos])
# dataset_training_length = sum([dataset.real_len for dataset in train_dataset_videos])
# dataset_combined = ConcatDataset(train_dataset_videos)
# sampler = WeightedRandomSampler([1] * video_training_length, dataset_training_length * opt.epoch)
data_loader = DataLoader(dataset_combined, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers)
return data_loader
def training_logger():
if (total_iter % opt.checkpoint_step == 0) or (total_iter == len(data_loader)):
train_loss_dict = {"train/" + str(k): v / loss_handler.count_sample for k, v in loss_handler.loss_dict.items()}
train_loss_dict["train/opt_g_lr_1"] = step_optim_scheduler_g.get_last_lr()[0]
train_loss_dict["train/opt_g_lr_2"] = step_optim_scheduler_g.get_last_lr()[1]
train_loss_dict["train/opt_d_lr"] = step_optim_scheduler_d.get_last_lr()[0]
alert_text = f"l1_loss: {l1_loss.item()}\npercep_loss: {perceptual_loss.item()}\nctx_loss: {contextual_loss_total.item()}\ncst_loss: {consistent_loss.item()}\nsm_loss: {smoothness_loss.item()}\ntotal: {total_loss.item()}"
if opt.use_wandb:
wandb.log(train_loss_dict)
wandb.alert(title=f"Progress training #{total_iter}", text=alert_text)
for idx in range(I_predict_rgb.shape[0]):
concated_I = make_grid(
[(I_predict_rgb[idx] * 255), (I_reference_rgb[idx] * 255), (I_current_rgb[idx] * 255)], nrow=3
)
wandb_concated_I = wandb.Image(
concated_I,
caption="[LEFT] Predict, [CENTER] Reference, [RIGHT] Ground truth\n[REF] {}, [FRAME] {}".format(
ref_path[idx], curr_frame_path[idx]
),
)
wandb.log({f"example_{idx}": wandb_concated_I})
torch.save(
nonlocal_net.state_dict(),
os.path.join(opt.checkpoint_dir, "nonlocal_net_iter.pth"),
)
torch.save(
colornet.state_dict(),
os.path.join(opt.checkpoint_dir, "colornet_iter.pth"),
)
torch.save(
discriminator.state_dict(),
os.path.join(opt.checkpoint_dir, "discriminator_iter.pth"),
)
torch.save(embed_net.state_dict(), os.path.join(opt.checkpoint_dir, "embed_net_iter.pth"))
loss_handler.reset()
def load_params(ckpt_file):
    # Load a checkpoint state dict onto the CPU; load_state_dict then copies tensors to the model's device.
    params = torch.load(ckpt_file, map_location="cpu")
    return OrderedDict(params.items())
def parse(parser, save=True):
opt = parser.parse_args()
args = vars(opt)
print("------------------------------ Options -------------------------------")
for k, v in sorted(args.items()):
print("%s: %s" % (str(k), str(v)))
print("-------------------------------- End ---------------------------------")
if save:
        file_name = "opt.txt"
with open(file_name, "wt") as opt_file:
opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
opt_file.write("------------------------------ Options -------------------------------\n")
for k, v in sorted(args.items()):
opt_file.write("%s: %s\n" % (str(k), str(v)))
opt_file.write("-------------------------------- End ---------------------------------\n")
return opt
def gpu_setup():
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
cudnn.benchmark = True
torch.cuda.set_device(opt.gpu_ids[0])
device = torch.device("cuda")
print("running on GPU", opt.gpu_ids)
return device
if __name__ == "__main__":
############################################## SETUP ###############################################
torch.multiprocessing.set_start_method("spawn", force=True)
# =============== GET PARSER OPTION ================
opt = parse(parser)
opt.video_data_root_list = opt.video_data_root_list.split(",")
opt.flow_data_root_list = opt.flow_data_root_list.split(",")
opt.mask_data_root_list = opt.mask_data_root_list.split(",")
opt.gpu_ids = list(map(int, opt.gpu_ids.split(",")))
opt.head_out_idx = list(map(int, opt.head_out_idx.split(",")))
n_dim_output = 3 if opt.use_feature_transform else 4
assert len(opt.head_out_idx) == 4, "Size of head_out_idx must be 4"
os.makedirs(opt.checkpoint_dir, exist_ok=True)
# =================== INIT WANDB ===================
if opt.use_wandb:
print("Save images to Wandb")
if opt.wandb_token != "":
            try:
                wandb.login(key=opt.wandb_token)
            except Exception as e:
                print(f"wandb login failed: {e}")
wandb.init(
project="video-colorization",
name=f"{opt.wandb_name} {datetime.now(tz=ZoneInfo('Asia/Ho_Chi_Minh')).strftime('%Y/%m/%d_%H-%M-%S')}",
)
# ================== SETUP DEVICE ==================
# torch.multiprocessing.set_start_method("spawn", force=True)
# device = gpu_setup()
device = "cuda" if torch.cuda.is_available() else "cpu"
############################################ LOAD DATA #############################################
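    # --use_dummy replaces the real dataloader with a single batch of random tensors shaped like
    # one training sample, which is convenient for smoke-testing the training loop.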
if opt.use_dummy:
H, W = 224, 224
I_last_lab = torch.rand(opt.batch_size, 3, H, W)
I_current_lab = torch.rand(opt.batch_size, 3, H, W)
I_reference_lab = torch.rand(opt.batch_size, 3, H, W)
flow_forward = torch.rand(opt.batch_size, 2, H, W)
mask = torch.rand(opt.batch_size, 1, H, W)
placeholder_lab = torch.rand(opt.batch_size, 3, H, W)
self_ref_flag = torch.rand(opt.batch_size, 3, H, W)
data_loader = [
[I_last_lab, I_current_lab, I_reference_lab, flow_forward, mask, placeholder_lab, self_ref_flag, None, None, None]
for _ in range(1)
]
else:
data_loader = load_data()
########################################## DEFINE NETWORK ##########################################
colornet = GeneralColorVidNet(opt.ic).to(device)
nonlocal_net = GeneralWarpNet(feature_channel=256).to(device) # change to 128 in swin tiny
discriminator = Discriminator_x64_224(ndf=64).to(device)
weighted_layer_color = WeightedAverage_color().to(device)
nonlocal_weighted_layer = NonlocalWeightedAverage().to(device)
warping_layer = WarpingLayer(device=device).to(device)
embed_net = GeneralEmbedModel(pretrained_model="swin-small", device=device).to(device)
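    # Network roles: embed_net (Swin-Small backbone) extracts multi-scale reference features,
    # nonlocal_net warps reference colors onto the target frame, colornet predicts the ab channels,
    # and the discriminator provides the adversarial signal when opt.weight_gan > 0.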
print("-" * 59)
print("| TYPE | Model name | Num params |")
print("-" * 59)
colornet_params = print_num_params(colornet)
nonlocal_net_params = print_num_params(nonlocal_net)
discriminator_params = print_num_params(discriminator)
weighted_layer_color_params = print_num_params(weighted_layer_color)
nonlocal_weighted_layer_params = print_num_params(nonlocal_weighted_layer)
warping_layer_params = print_num_params(warping_layer)
embed_net_params = print_num_params(embed_net)
print("-" * 59)
    total_params = (colornet_params + nonlocal_net_params + discriminator_params + weighted_layer_color_params
                    + nonlocal_weighted_layer_params + warping_layer_params + embed_net_params)
    print(f"| TOTAL | | {format(total_params, ',').rjust(10)} |")
print("-" * 59)
if opt.use_wandb:
wandb.watch(discriminator, log="all", log_freq=opt.checkpoint_step, idx=0)
wandb.watch(embed_net, log="all", log_freq=opt.checkpoint_step, idx=1)
wandb.watch(colornet, log="all", log_freq=opt.checkpoint_step, idx=2)
wandb.watch(nonlocal_net, log="all", log_freq=opt.checkpoint_step, idx=3)
# ============= USE PRETRAINED OR NOT ==============
if opt.load_pretrained_model:
# pretrained_path = "/workspace/video_colorization/ckpt_folder_ver_1_vit_small_patch16_384"
nonlocal_net.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "nonlocal_net_iter.pth")))
colornet.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "colornet_iter.pth")))
discriminator.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "discriminator_iter.pth")))
        embed_net.load_state_dict(load_params(os.path.join(opt.checkpoint_dir, "embed_net_iter.pth")))
###################################### DEFINE LOSS FUNCTIONS #######################################
perceptual_loss_fn = Perceptual_loss(opt.domain_invariant, opt.weight_perceptual)
contextual_loss = ContextualLoss().to(device)
contextual_forward_loss = ContextualLoss_forward().to(device)
######################################## DEFINE OPTIMIZERS #########################################
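    # One AdamW optimizer for the generator side (three parameter groups; colornet runs at twice
    # the base learning rate) and a separate AdamW optimizer for the discriminator.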
optimizer_g = optim.AdamW(
[
{"params": nonlocal_net.parameters(), "lr": opt.lr},
{"params": colornet.parameters(), "lr": 2 * opt.lr},
{"params": embed_net.parameters(), "lr": opt.lr},
],
betas=(0.5, 0.999),
eps=1e-5,
amsgrad=True,
)
optimizer_d = optim.AdamW(
filter(lambda p: p.requires_grad, discriminator.parameters()),
lr=opt.lr,
betas=(0.5, 0.999),
amsgrad=True,
)
step_optim_scheduler_g = PolynomialLR(
optimizer_g,
step_size=opt.lr_step,
iter_warmup=0,
iter_max=len(data_loader) * opt.epoch,
power=0.9,
min_lr=1e-8,
)
step_optim_scheduler_d = PolynomialLR(
optimizer_d,
step_size=opt.lr_step,
iter_warmup=0,
iter_max=len(data_loader) * opt.epoch,
power=0.9,
min_lr=1e-8,
)
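    # Both schedulers decay the learning rate polynomially (power 0.9) over the full
    # len(data_loader) * opt.epoch iteration budget, down to a floor of 1e-8.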
########################################## DEFINE OTHERS ###########################################
downsampling_by2 = nn.AvgPool2d(kernel_size=2).to(device)
timer_handler = TimeHandler()
loss_handler = LossHandler() # Handle loss value
############################################## TRAIN ###############################################
total_iter = 0
for epoch_num in range(1, opt.epoch + 1):
# if opt.use_wandb:
# wandb.log({"Current_trainning_epoch": epoch_num})
with tqdm(total=len(data_loader), position=0, leave=True) as pbar:
for iter, sample in enumerate(data_loader):
timer_handler.compute_time("load_sample")
total_iter += 1
# =============== LOAD DATA SAMPLE ================
(
I_last_lab, ######## (3, H, W)
I_current_lab, ##### (3, H, W)
I_reference_lab, ### (3, H, W)
flow_forward, ###### (2, H, W)
mask, ############## (1, H, W)
placeholder_lab, ### (3, H, W)
self_ref_flag, ##### (3, H, W)
prev_frame_path,
curr_frame_path,
ref_path,
) = sample
I_last_lab = I_last_lab.to(device)
I_current_lab = I_current_lab.to(device)
I_reference_lab = I_reference_lab.to(device)
flow_forward = flow_forward.to(device)
mask = mask.to(device)
placeholder_lab = placeholder_lab.to(device)
self_ref_flag = self_ref_flag.to(device)
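                # Split each Lab tensor into its luminance (L) and chrominance (ab) channels.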
I_last_l = I_last_lab[:, 0:1, :, :]
I_last_ab = I_last_lab[:, 1:3, :, :]
I_current_l = I_current_lab[:, 0:1, :, :]
I_current_ab = I_current_lab[:, 1:3, :, :]
I_reference_l = I_reference_lab[:, 0:1, :, :]
I_reference_ab = I_reference_lab[:, 1:3, :, :]
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1))
_load_sample_time = timer_handler.compute_time("load_sample")
timer_handler.compute_time("forward_model")
features_B = embed_net(I_reference_rgb)
B_feat_0, B_feat_1, B_feat_2, B_feat_3 = features_B
# ================== COLORIZATION ==================
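                # Two-pass colorization: the previous frame is colorized first (with placeholder_lab
                # standing in for its own history), then the current frame is colorized conditioned
                # on that prediction so the consistency loss can compare the two along the flow.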
# The last frame
I_last_ab_predict, I_last_nonlocal_lab_predict = frame_colorization(
IA_l=I_last_l,
IB_lab=I_reference_lab,
IA_last_lab=placeholder_lab,
features_B=features_B,
embed_net=embed_net,
colornet=colornet,
nonlocal_net=nonlocal_net,
luminance_noise=opt.luminance_noise,
)
I_last_lab_predict = torch.cat((I_last_l, I_last_ab_predict), dim=1)
# The current frame
I_current_ab_predict, I_current_nonlocal_lab_predict = frame_colorization(
IA_l=I_current_l,
IB_lab=I_reference_lab,
IA_last_lab=I_last_lab_predict,
features_B=features_B,
embed_net=embed_net,
colornet=colornet,
nonlocal_net=nonlocal_net,
luminance_noise=opt.luminance_noise,
)
                I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)  # pair the current frame's luminance with its predicted ab
# ================ UPDATE GENERATOR ================
if opt.weight_gan > 0:
optimizer_g.zero_grad()
optimizer_d.zero_grad()
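                    # The discriminator scores 6-channel stacks: uncentered L plus (predicted or
                    # ground-truth) ab for both the current and the previous frame.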
fake_data_lab = torch.cat(
(
uncenter_l(I_current_l),
I_current_ab_predict,
uncenter_l(I_last_l),
I_last_ab_predict,
),
dim=1,
)
real_data_lab = torch.cat(
(
uncenter_l(I_current_l),
I_current_ab,
uncenter_l(I_last_l),
I_last_ab,
),
dim=1,
)
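                    # Optionally roll the real batch by one sample (index -1 wraps to the last element)
                    # so the real and fake examples in a pair do not come from the same clip.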
if opt.permute_data:
batch_index = torch.arange(-1, opt.batch_size - 1, dtype=torch.long)
real_data_lab = real_data_lab[batch_index, ...]
discriminator_loss = discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator)
discriminator_loss.backward()
optimizer_d.step()
optimizer_g.zero_grad()
optimizer_d.zero_grad()
# ================== COMPUTE LOSS ==================
# L1 loss
l1_loss = l1_loss_fn(I_current_ab, I_current_ab_predict) * opt.weigth_l1
                # Generator loss: the adversarial term is added to the generator objective only after
                # opt.epoch_train_discriminator warm-up epochs.
if epoch_num > opt.epoch_train_discriminator:
generator_loss = generator_loss_fn(real_data_lab, fake_data_lab, discriminator, opt.weight_gan, device)
# Perceptual Loss
I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
pred_feat_0, pred_feat_1, pred_feat_2, pred_feat_3 = embed_net(I_predict_rgb)
I_current_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab), dim=1))
A_feat_0, _, _, A_feat_3 = embed_net(I_current_rgb)
perceptual_loss = perceptual_loss_fn(A_feat_3, pred_feat_3)
# Contextual Loss
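                # Deeper reference features get larger weights (x8, x4, x2, x1) before the sum is
                # scaled by opt.weight_contextual.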
contextual_style5_1 = torch.mean(contextual_forward_loss(pred_feat_3, B_feat_3.detach())) * 8
contextual_style4_1 = torch.mean(contextual_forward_loss(pred_feat_2, B_feat_2.detach())) * 4
contextual_style3_1 = torch.mean(contextual_forward_loss(pred_feat_1, B_feat_1.detach())) * 2
contextual_style2_1 = torch.mean(contextual_forward_loss(pred_feat_0, B_feat_0.detach()))
# if opt.use_feature_transform:
# contextual_style3_1 = (
# torch.mean(
# contextual_forward_loss(
# downsampling_by2(pred_feat_1),
# downsampling_by2(),
# )
# )
# * 2
# )
# else:
# contextual_style3_1 = (
# torch.mean(
# contextual_forward_loss(
# pred_feat_1,
# B_feat_1.detach(),
# )
# )
# * 2
# )
contextual_loss_total = (
contextual_style5_1 + contextual_style4_1 + contextual_style3_1 + contextual_style2_1
) * opt.weight_contextual
# Consistent Loss
consistent_loss = consistent_loss_fn(
I_current_lab_predict,
I_last_ab_predict,
I_current_nonlocal_lab_predict,
I_last_nonlocal_lab_predict,
flow_forward,
mask,
warping_layer,
weight_consistent=opt.weight_consistent,
weight_nonlocal_consistent=opt.weight_nonlocal_consistent,
device=device,
)
# Smoothness loss
smoothness_loss = smoothness_loss_fn(
I_current_l,
I_current_lab,
I_current_ab_predict,
A_feat_0,
weighted_layer_color,
nonlocal_weighted_layer,
weight_smoothness=opt.weight_smoothness,
weight_nonlocal_smoothness=opt.weight_nonlocal_smoothness,
device=device,
)
# Total loss
total_loss = l1_loss + perceptual_loss + contextual_loss_total + consistent_loss + smoothness_loss
if epoch_num > opt.epoch_train_discriminator:
total_loss += generator_loss
# Add loss to loss handler
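                # Accumulate every term for periodic logging; training_logger() averages over
                # loss_handler.count_sample and then resets the handler.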
loss_handler.add_loss(key="total_loss", loss=total_loss.item())
loss_handler.add_loss(key="l1_loss", loss=l1_loss.item())
loss_handler.add_loss(key="perceptual_loss", loss=perceptual_loss.item())
loss_handler.add_loss(key="contextual_loss", loss=contextual_loss_total.item())
loss_handler.add_loss(key="consistent_loss", loss=consistent_loss.item())
loss_handler.add_loss(key="smoothness_loss", loss=smoothness_loss.item())
loss_handler.add_loss(key="discriminator_loss", loss=discriminator_loss.item())
if epoch_num > opt.epoch_train_discriminator:
loss_handler.add_loss(key="generator_loss", loss=generator_loss.item())
loss_handler.count_one_sample()
total_loss.backward()
optimizer_g.step()
step_optim_scheduler_g.step()
step_optim_scheduler_d.step()
_forward_model_time = timer_handler.compute_time("forward_model")
timer_handler.compute_time("training_logger")
training_logger()
_training_logger_time = timer_handler.compute_time("training_logger")
pbar.set_description(
f"Epochs: {epoch_num}, Load_sample: {_load_sample_time:.3f}s, Forward: {_forward_model_time:.3f}s, log: {_training_logger_time:.3f}s"
)
pbar.update(1)