File size: 2,930 Bytes
2e34814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os

import torch
from tqdm import tqdm
from color_transfer_loss import ColorTransferLoss

from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w


class MultiIDCoach(BaseCoach):
    """Coach that fine-tunes one shared generator against several images.

    Each image from ``self.data_loader`` is first inverted to its own pivot
    latent via ``self.get_inversion``. A single ``self.G`` is then optimized
    against every (image, pivot) pair in turn. The resulting generator module
    is saved once at the end of ``train``.
    """

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Invert all images, tune ``self.G`` on them, then save the generator.

        Side effects:
          * puts ``self.G.synthesis`` and ``self.G.mapping`` in train mode;
          * creates the embedding directories under
            ``paths_config.embedding_base_dir`` and the checkpoint directory;
          * mutates ``self.image_counter`` and ``global_config.training_step``;
          * optionally logs images to wandb;
          * writes ``model_<run_name>_multi_id.pt`` with ``torch.save``.
        """
        self.G.synthesis.train()
        self.G.mapping.train()

        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        # The first optimization step always passes use_ball_holder=True.
        # Later steps recompute it from global_config.training_step below.
        use_ball_holder = True
        w_pivots = []
        images = []

        # Phase 1: compute one pivot latent per input image.
        # NOTE(review): image_counter is not reset here, so this relies on
        # BaseCoach initialising it (presumably to 0) — confirm.
        for fname, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            # Only fname[0] is used; this assumes the loader yields batches
            # of size 1 — TODO confirm against the data loader setup.
            image_name = fname[0]
            if hyperparameters.first_inv_type == "w+":
                embedding_dir = (
                    f"{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}"
                )
            else:
                embedding_dir = (
                    f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
                )
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = self.get_inversion(w_path_dir, image_name, image)
            w_pivots.append(w_pivot)
            images.append((image_name, image))
            self.image_counter += 1

        # Phase 2: each outer step takes one optimizer step per
        # (image, pivot) pair, all updating the same generator.
        for _ in tqdm(range(hyperparameters.max_pti_steps)):
            self.image_counter = 0

            for (image_name, image), w_pivot in zip(images, w_pivots):
                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                real_images_batch = image.to(global_config.device)

                generated_images = self.forward(w_pivot)

                # NOTE(review): the trailing {} is an extra positional argument
                # required by this project's calc_loss; its meaning is not
                # visible here — confirm in BaseCoach.
                loss, _, _ = self.calc_loss(
                    generated_images,
                    real_images_batch,
                    image_name,
                    self.G,
                    use_ball_holder,
                    w_pivot,
                    {},
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Evaluated against the step counter *before* it is
                # incremented. Presumably this gates the locality regularizer
                # inside calc_loss, given the interval's name.
                use_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                global_config.training_step += 1
                self.image_counter += 1

        if self.use_wandb:
            log_images_from_w(w_pivots, self.G, [name for name, _ in images])

        # Create the checkpoint directory first so a missing directory
        # cannot discard the result of the whole tuning run at the last step.
        os.makedirs(paths_config.checkpoints_dir, exist_ok=True)
        torch.save(
            self.G,
            f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt",
        )