Spaces:
Running
on
T4
Running
on
T4
File size: 5,018 Bytes
561c629 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# -*- coding: utf-8 -*-
import sys
import os
import torch
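# Put the current working directory on sys.path so the project-local packages imported below resolve (assumes the script is launched from the repository root)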
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from architecture.grl import GRL
from architecture.discriminator import UNetDiscriminatorSN, MultiScaleDiscriminator
from train_code.train_master import train_master
class train_grlgan(train_master):
    """GAN training pipeline that pairs a GRL generator with a discriminator.

    The class plugs model construction, loss setup, per-batch generator loss
    computation and TensorBoard reporting into the ``train_master`` base
    class, which drives the actual training loop via ``master_run``.
    """

    # Settings shared by every supported GRL generator size.
    _GENERATOR_COMMON = {
        "window_size": 8,
        "depths": [4, 4, 4, 4],
        "num_heads_window": [2, 2, 2, 2],
        "num_heads_stripe": [2, 2, 2, 2],
        "mlp_ratio": 2,
        "qkv_proj_type": "linear",
        "anchor_proj_type": "avgpool",
        "anchor_window_down_factor": 2,
        "out_proj_type": "linear",
        "conv_type": "1conv",
    }

    # Per-size overrides, keyed by opt['model_size'].
    _GENERATOR_VARIANTS = {
        "small": {"img_size": 144, "embed_dim": 128, "upsampler": "pixelshuffle"},
        "tiny": {"img_size": 64, "embed_dim": 64, "upsampler": "pixelshuffledirect"},
        "tiny2": {"img_size": 64, "embed_dim": 64, "upsampler": "nearest+conv"},
    }

    def __init__(self, options, args) -> None:
        """Initialise the base trainer under the model name code "grlgan".

        Args:
            options: Forwarded unchanged to ``train_master``.
            args: Forwarded unchanged to ``train_master``.
        """
        # The meaning of the trailing ``True`` flag is defined by train_master.
        super().__init__(options, args, "grlgan", True)

    def loss_init(self):
        """Load the loss criteria used by this trainer."""
        # Pixel loss for the generator.
        self.pixel_loss_load()
        # GAN loss. The perceptual criteria used in calculate_loss are
        # presumably prepared by one of these loaders or by train_master --
        # TODO confirm.
        self.GAN_loss_load()

    def call_model(self):
        """Build the generator and discriminator on CUDA and set train mode.

        Raises:
            NotImplementedError: If ``opt['model_size']`` or
                ``opt['discriminator_type']`` is not a supported value.
        """
        # Generator: pick the GRL variant requested in the options.
        model_size = opt['model_size']
        if model_size not in self._GENERATOR_VARIANTS:
            raise NotImplementedError("We don't support such model size in GRL model")

        generator_kwargs = {
            "upscale": opt['scale'],
            **self._GENERATOR_COMMON,
            **self._GENERATOR_VARIANTS[model_size],
        }
        # Copy the list-valued settings so the model never shares the
        # class-level lists.
        for key in ("depths", "num_heads_window", "num_heads_stripe"):
            generator_kwargs[key] = list(generator_kwargs[key])
        self.generator = GRL(**generator_kwargs).cuda()
        # NOTE: torch.compile on the generator is intentionally left disabled.
        # self.generator = torch.compile(self.generator).cuda()

        # Discriminator
        discriminator_type = opt['discriminator_type']
        if discriminator_type == "PatchDiscriminator":
            self.discriminator = MultiScaleDiscriminator(3).cuda()
        elif discriminator_type == "UNetDiscriminator":
            self.discriminator = UNetDiscriminatorSN(3).cuda()
        else:
            # Fail with a clear message instead of an AttributeError on the
            # missing self.discriminator below.
            raise NotImplementedError(
                f"We don't support such discriminator type in GRL model: {discriminator_type!r}"
            )

        self.generator.train()
        self.discriminator.train()

    def run(self):
        """Start training by delegating to the base-class loop."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Accumulate the three generator losses for one batch.

        Each term is added to ``self.generator_loss`` and stored in
        ``self.weight_store`` for reporting. Nothing is returned.

        Args:
            gen_hr: Generator output for the batch (presumably the
                super-resolved images -- confirm against the caller).
            imgs_hr: Target the output is compared against (presumably the
                ground-truth high-resolution images -- confirm).
        """
        ###################### We have 3 losses on Generator ######################

        # 1) Pixel loss: generated vs. ground truth.
        l_g_pix = self.cri_pix(gen_hr, imgs_hr)
        self.generator_loss += l_g_pix
        self.weight_store["pixel_loss"] = l_g_pix

        # 2) Perceptual loss: sum of the danbooru and VGG perceptual criteria.
        l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr)
        l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr)
        l_g_percep = l_g_percep_danbooru + l_g_percep_vgg
        self.generator_loss += l_g_percep
        self.weight_store["perceptual_loss"] = l_g_percep

        # 3) GAN loss: the generator wants the discriminator to call its
        # output real, hence the target label True with is_disc=False.
        fake_g_preds = self.discriminator(gen_hr)
        # Per the original author's note, cri_gan already applies the GAN
        # loss weight (self.gan_loss_weight).
        l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False)
        self.generator_loss += l_g_gan
        self.weight_store["gan_loss"] = l_g_gan

    def tensorboard_report(self, iteration):
        """Write the current generator loss terms to the summary writer.

        Args:
            iteration: Step value passed to every ``add_scalar`` call.
        """
        self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
        self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration)
        # NOTE(review): this tag says "Discriminator_Loss" but the value is the
        # generator's GAN loss from calculate_loss; tag kept for continuity
        # with existing TensorBoard runs -- confirm whether a rename is wanted.
        self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration)
|