|
import torch |
|
import argparse |
|
import os |
|
from models.anime_gan import GeneratorV1 |
|
from models.anime_gan_v2 import GeneratorV2 |
|
from models.anime_gan_v3 import GeneratorV3 |
|
from models.anime_gan import Discriminator |
|
from datasets import AnimeDataSet |
|
from utils.common import load_checkpoint |
|
from trainer import Trainer |
|
from utils.logger import get_logger |
|
|
|
|
|
def parse_args(argv=None):
    """Declare the training command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with one attribute per option declared below.
        ``--local-rank`` / ``--world-size`` are exposed as ``local_rank`` /
        ``world_size``.
    """
    parser = argparse.ArgumentParser()

    # ---- Data / experiment setup ----
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    # `choices` rejects a typo here, at parse time, instead of letting it
    # surface later as an unbound generator in main().
    parser.add_argument('--model', type=str, default='v1', choices=['v1', 'v2', 'v3'],
                        help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')

    # ---- Checkpoint resuming ----
    # The three --resume_* options take a checkpoint path; the literal string
    # 'False' (any case) is the "not given" marker that main() checks for.
    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
    parser.add_argument('--resume_G_init', type=str, default='False')
    parser.add_argument('--resume_G', type=str, default='False')
    parser.add_argument('--resume_D', type=str, default='False')

    # ---- Runtime ----
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true')
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument(
        '--imgsz', type=int, nargs="+", default=[256],
        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
    parser.add_argument(
        '--resize_method', type=str, default="crop",
        help="Resize image method if origin photo larger than imgsz")

    # ---- Optimisation / adversarial loss ----
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")

    # ---- Loss weights / discriminator shape ----
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')

    # ---- Distributed training ----
    parser.add_argument('--ddp', action='store_true')
    # Launchers differ on whether they pass the hyphenated or the underscore
    # spelling, so accept both; the destination attribute is the same.
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank", default=0, type=int)
    parser.add_argument("--world-size", "--world_size", dest="world_size", default=2, type=int)

    return parser.parse_args(argv)
|
|
|
|
|
def check_params(args):
    """Derive ``args.dataset`` and validate the selected GAN loss.

    ``args.dataset`` is set to ``"<real dir name>_<anime dir name>"``, built
    from the last path component of ``args.real_image_dir`` and
    ``args.anime_image_dir``.

    Args:
        args: Parsed CLI namespace; mutated in place.

    Raises:
        ValueError: If ``args.gan_loss`` is not 'lsgan', 'hinge' or 'bce'.
    """
    # basename() of a path ending in a separator is '', so normalise first;
    # otherwise "dataset/Hayao/" would contribute an empty name.
    real_name = os.path.basename(os.path.normpath(args.real_image_dir))
    anime_name = os.path.basename(os.path.normpath(args.anime_image_dir))
    args.dataset = f"{real_name}_{anime_name}"

    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if args.gan_loss not in {'lsgan', 'hinge', 'bce'}:
        raise ValueError(f'{args.gan_loss} is not supported')
|
|
|
|
|
def main(args, logger):
    """Build the models, optionally restore checkpoints, and run training.

    Args:
        args: Parsed CLI namespace. It is mutated in place: ``check_params``
            adds ``dataset``; ``device``, ``debug_samples`` and ``batch_size``
            are overridden when CUDA is unavailable; ``init_epochs`` is set
            to 0 when full G/D checkpoints are restored.
        logger: Logger used for progress messages.

    Raises:
        ValueError: If ``args.model`` is not 'v1', 'v2' or 'v3'.
    """
    check_params(args)

    if not torch.cuda.is_available():
        logger.info("CUDA not found, use CPU")
        # Without CUDA, fall back to a small CPU configuration.
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    else:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")

    # Only the v2 generator switches the discriminator to "layer" norm.
    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)
    else:
        # Fail with a clear message rather than a NameError on `G` below.
        raise ValueError(
            f"Unsupported model '{args.model}', expected one of 'v1', 'v2', 'v3'"
        )

    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )

    start_e = 0
    start_e_init = 0

    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )

    # Resume priority: init-phase G checkpoint, then explicit G+D paths,
    # then the trainer's own checkpoint paths (--resume).
    if args.resume_G_init.lower() != 'false':
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if args.local_rank == 0:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if args.local_rank == 0:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if args.local_rank == 0:
                logger.info(f"D weight loaded from {args.resume_D}")
            # Both networks restored: skip the init phase.
            args.init_epochs = 0
        except Exception as e:
            # G may have loaded before D failed; reset the epoch counter so
            # "train from scratch" really restarts from epoch 0.
            start_e = 0
            logger.info(f"Could not load checkpoint, train from scratch {e}")
    elif args.resume:
        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        args.init_epochs = 0

    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if args.local_rank == 0:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)
|
|
|
if __name__ == '__main__':
    args = parse_args()

    # The experiment directory is "<exp_dir>_<real dir name>_<anime dir name>".
    # normpath() drops a trailing separator first; otherwise basename() would
    # return '' and different datasets could collide on a name like "runs__".
    real_name = os.path.basename(os.path.normpath(args.real_image_dir))
    anime_name = os.path.basename(os.path.normpath(args.anime_image_dir))
    args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"

    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))

    # Only rank 0 dumps the resolved configuration.
    if args.local_rank == 0:
        logger.info("# ==== Train Config ==== #")
        for arg in vars(args):
            logger.info(f"{arg} {getattr(args, arg)}")
        logger.info("==========================")

    main(args, logger)
|
|