import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image
import os
from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
from model import Autoencoder
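
# Lightning wrapper around the Autoencoder. The training loss combines image
# reconstruction (MSE) with three feature-space terms: an intra-image triplet
# loss, a transformation-equivariance loss, and a batch-separation loss.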

class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0, initial_threshold=2.0, save_interval=100, output_dir="output_images"):
        super().__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # interval, in batches, between image outputs
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        img, mat, _ = batch
        # img: (batch_size, 2, 3, size, size) -- two augmented views per sample.
        # Flatten so the views of sample i land at indices 2i and 2i+1.
        batch_size, _, _, size, _ = img.shape
        img = img.view(batch_size*2, 3, size, size)
        mat = mat.view(batch_size*2, 3, 3)
        
        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)
        
        # Intra-image triplets: for each anchor pixel, sample two nearby points
        # and later treat the closer one as positive, the farther as negative.
        num_anchor_sets = 2**12
        trip_loss = 0
        std_list = [2.5 * 1.025**self.current_epoch, 5 * 1.025**self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size*2, num_anchor_sets, 1, 2))
            coords = anchors + torch.normal(0, std, (batch_size*2, num_anchor_sets, 2, 2)).long()
            # Zero out anchor sets whose two offset points are not both inside the image.
            invalid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[invalid_coords_idx] = 0
            anchors[invalid_coords_idx] = 0
            
            # Pick the closer sampled point as the positive; the other is the negative.
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
            
            # Extract the feature vectors at the triplet coordinates from dec5_output.
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
        
        trip_loss /= len(std_list)
        # Adaptive margin: it grows while the triplet loss stays below its initial value.
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()
        
        # Transformation equivariance: features at corresponding points of the
        # two affine-transformed views should match.
        num_samples = 2**20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)
        
        # Batch direction: features from different source images should stay apart.
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)
        
        # Total loss. Note that 1.0**current_epoch == 1.0, so the transformation
        # term currently carries a constant weight of 0.001.
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # Free VRAM held by large intermediates.
        del img, output
        torch.cuda.empty_cache()
        
        return loss


    def _get_triplet_coordinates(self, anchors, coords, idx):
        # idx picks, per anchor set, which of the two sampled points is the
        # positive (the closer one); the other point becomes the negative.
        anchors = anchors.squeeze(2)
        batch_idx = torch.arange(coords.size(0))[:, None, None]
        set_idx = torch.arange(coords.size(1))[None, :, None]
        positives = coords[batch_idx, set_idx, idx[:, :, None]].squeeze(2)
        negatives = coords[batch_idx, set_idx, (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives

    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        # Gather per-pixel feature vectors from dec5_output (B*2, C, H, W) at the
        # (y, x) locations of each triplet member via advanced indexing.
        def gather(points):
            y = points[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
            x = points[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
            return dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y, x]

        return gather(anchors), gather(positives), gather(negatives)

    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
        # Pair up the two views of each sampled image (views of sample i sit at
        # indices 2i and 2i+1 after the reshape in training_step). Both views
        # share the same source coordinates, mapped through each view's own
        # transform, so corresponding points can be compared.
        anchor_indices = (torch.randint(batch_size, (num_samples, 1), device=self.device) * 2
                          + torch.arange(2, device=self.device)).reshape(num_samples*2)
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size/2, size/2])(anchor_coords)

        # Keep only coordinate pairs that land inside the image in both views.
        anchor_vectors = torch.zeros([num_samples*2, self.feature_dim], device=self.device)
        inner_idx_flat = ((0 <= tf_anchor_coords[:, 0]) & (tf_anchor_coords[:, 0] < size)) & ((0 <= tf_anchor_coords[:, 1]) & (tf_anchor_coords[:, 1] < size))
        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]

        inner_idx_and = inner_idx_flat.view(num_samples, 2).all(dim=1)
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        return pairwise_distance_squared(anchor_vectors[:, 0], anchor_vectors[:, 1]).mean()

    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        # Compare random anchor pixels against pixels sampled from other source
        # images. The exclusion works at the image-pair level, so neither view
        # of the anchor's own image can be drawn as an "other".
        N = 2**12
        anchor_pairs = torch.randint(0, batch_size, (N,))
        anchor_indices = anchor_pairs * 2 + torch.randint(0, 2, (N,))
        anchor_coords = torch.randint(0, size, (N, 2))
        other_pairs = torch.randint(0, batch_size - 1, (N, 2))
        other_pairs += (other_pairs >= anchor_pairs.unsqueeze(1)).long()  # skip the anchor's pair
        other_indices = other_pairs * 2 + torch.randint(0, 2, (N, 2))
        other_coords = torch.randint(0, size, (N, 2, 2))

        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        # Mean of the distances below the threshold; the epsilon guards against
        # an empty selection.
        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
        return [optimizer], [scheduler]

    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")
        
    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)

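
# Minimal smoke-test sketch. Assumptions (not part of this repository): the
# dataloader yields (img, mat, label) with img of shape (batch, 2, 3, size, size)
# -- two augmented views per sample -- and mat holding each view's 3x3 affine
# matrix; random tensors stand in for real data here.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    size, n = 64, 8
    imgs = torch.rand(n, 2, 3, size, size)
    mats = torch.eye(3).expand(n, 2, 3, 3).contiguous()
    labels = torch.zeros(n)
    loader = DataLoader(TensorDataset(imgs, mats, labels), batch_size=4)

    module = AutoencoderModule(feature_dim=64)
    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(module, loader)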