File size: 5,018 Bytes
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# -*- coding: utf-8 -*-

import  sys
import os
import torch

# import important files
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from architecture.grl import GRL
from architecture.discriminator import UNetDiscriminatorSN, MultiScaleDiscriminator
from train_code.train_master import train_master



class train_grlgan(train_master):
    def __init__(self, options, args) -> None:
        super().__init__(options, args, "grlgan", True)     # Pass a model name unique code


    def loss_init(self):

        # prepare pixel loss (Generator)
        self.pixel_loss_load()

        # prepare perceptual loss
        self.GAN_loss_load()


    def call_model(self):
        # Generator: GRL Small
        patch_size = 144
        if opt['model_size'] == "small":
            # GRL small model 
            self.generator = GRL(
                upscale = opt['scale'],              
                img_size = patch_size,
                window_size = 8,
                depths = [4, 4, 4, 4],
                embed_dim = 128,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "pixelshuffle",
            ).cuda()

        elif opt['model_size'] == "tiny":
            # GRL tiny model
            self.generator = GRL(
                upscale = opt['scale'],
                img_size = 64,
                window_size = 8,
                depths = [4, 4, 4, 4],
                embed_dim = 64,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "pixelshuffledirect",
            ).cuda()

        elif opt['model_size'] == "tiny2":
            # GRL tiny model
            self.generator = GRL(
                upscale = opt['scale'],
                img_size = 64,
                window_size = 8,
                depths = [4, 4, 4, 4],
                embed_dim = 64,
                num_heads_window = [2, 2, 2, 2],
                num_heads_stripe = [2, 2, 2, 2],
                mlp_ratio = 2,
                qkv_proj_type = "linear",
                anchor_proj_type = "avgpool",
                anchor_window_down_factor = 2,
                out_proj_type = "linear",
                conv_type = "1conv",
                upsampler = "nearest+conv",     # Change
            ).cuda()
            
        else:
            raise NotImplementedError("We don't support such model size in GRL model")
        # self.generator = torch.compile(self.generator).cuda()

        # Discriminator
        if opt['discriminator_type'] == "PatchDiscriminator":
            self.discriminator = MultiScaleDiscriminator(3).cuda()
        elif opt['discriminator_type'] == "UNetDiscriminator":
            self.discriminator = UNetDiscriminatorSN(3).cuda()
        
        self.generator.train(); self.discriminator.train()


    def run(self):
        self.master_run()
                        


    def calculate_loss(self, gen_hr, imgs_hr):

        ######################  We have 3 losses on Generator  ######################
        # Generator Pixel loss (l1 loss):  generated vs. GT
        l_g_pix = self.cri_pix(gen_hr, imgs_hr)
        self.generator_loss += l_g_pix
        self.weight_store["pixel_loss"] = l_g_pix


        # Generator perceptual loss:        generated vs. perceptual
        l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr)
        l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr)
        l_g_percep = l_g_percep_danbooru + l_g_percep_vgg 
        self.generator_loss += l_g_percep
        self.weight_store["perceptual_loss"] = l_g_percep


        # Generator GAN loss               label correction
        fake_g_preds = self.discriminator(gen_hr)
        l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False) # loss_weight (self.gan_loss_weight) is included
        self.generator_loss += l_g_gan
        self.weight_store["gan_loss"] = l_g_gan # Already with gan_loss_weight (0.2/1)


    def tensorboard_report(self, iteration):
        self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
        self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration)
        self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration)