File size: 4,257 Bytes
2e34814
 
 
 
 
 
 
 
 
 
 
33dd132
 
2e34814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
from tqdm import tqdm
from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
from color_transfer_loss import ColorTransferLoss
import copy


class SingleIDCoach(BaseCoach):
    """Pivotal-tuning coach that inverts and fine-tunes the generator for
    one identity (image) at a time.

    For each image yielded by ``data_loader`` it obtains a pivot latent
    ``w`` (reloaded from disk or freshly computed), then runs PTI steps so
    the generator reconstructs the real image at that pivot, optionally
    regularized by a per-year color-transfer loss computed against
    "sibling" generators.
    """

    def __init__(self, data_loader, in_year, use_wandb):
        """Forward construction to BaseCoach (loader, target year, wandb flag)."""
        super().__init__(data_loader, in_year, use_wandb)

    def train(self):
        """Run pivotal tuning over every image in the data loader.

        Side effects:
            * saves each pivot latent to
              ``<embedding_base_dir>/<input_data_id>/<pti_results_keyword>/<image_name>/0.pt``
            * saves the tuned generator per image to
              ``<checkpoints_dir>/model_<run_name>_<image_name>.pt``
            * increments ``global_config.training_step`` and
              ``self.image_counter``.
        """
        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            # Reset generator/optimizer state so each identity starts from
            # the original (pre-tuning) weights.
            self.restart_training()

            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            embedding_dir = (
                f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
            )
            os.makedirs(embedding_dir, exist_ok=True)

            # Reload a previously computed pivot when requested; fall back
            # to a fresh inversion when reloading is disabled OR the reload
            # found nothing.  BUGFIX: this second branch was an `elif`, so
            # with use_last_w_pivots=True a failed reload (None) was never
            # recomputed and `.to(device)` below crashed on None.
            w_pivot = None
            if hyperparameters.use_last_w_pivots:
                w_pivot = self.load_inversions(w_path_dir, image_name)
            if not hyperparameters.use_last_w_pivots or w_pivot is None:
                w_pivot = self.calc_inversions(image, image_name)

            w_pivot = w_pivot.to(global_config.device)
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

            log_images_counter = 0
            real_images_batch = image.to(global_config.device)

            # Snapshot each sibling generator's RGB outputs at the pivot;
            # these act as the fixed reference targets for the color loss.
            if hyperparameters.color_transfer_lambda > 0:
                self.color_losses = {}
                for y in self.years:
                    _, init_rgbs = self.forward_sibling(
                        self.siblings[y].synthesis, w_pivot
                    )
                    self.color_losses[y] = ColorTransferLoss(init_rgbs)

            for i in tqdm(range(hyperparameters.max_pti_steps)):
                rgbs = {}
                if hyperparameters.color_transfer_lambda > 0:
                    # Transplant the current PTI weight delta (tuned G minus
                    # original G) onto a fresh copy of each sibling, then
                    # record that copy's RGB outputs for the color loss.
                    for y in self.years:
                        G_sibling_aug = copy.deepcopy(self.siblings[y])
                        for p_pti, p_orig, p in zip(
                            self.G.synthesis.parameters(),
                            self.original_G.synthesis.parameters(),
                            G_sibling_aug.synthesis.parameters(),
                        ):
                            # NOTE(review): in-place update of a parameter
                            # tensor; assumes these params do not require
                            # grad here — confirm, else wrap this loop in
                            # torch.no_grad().
                            delta = p_pti - p_orig
                            p += delta
                        rgbs[y] = self.forward_sibling(
                            G_sibling_aug.synthesis, w_pivot
                        )[1]

                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(
                    generated_images,
                    real_images_batch,
                    image_name,
                    self.G,
                    use_ball_holder,
                    w_pivot,
                    rgbs,
                )

                self.optimizer.zero_grad()

                # Early exit once perceptual reconstruction is good enough.
                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                # Apply locality regularization only every N-th global step.
                use_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                if (
                    self.use_wandb
                    and log_images_counter % global_config.image_rec_result_log_snapshot
                    == 0
                ):
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            self.image_counter += 1

            # Persist the tuned generator for this identity.
            torch.save(
                self.G,
                f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt",
            )