import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image

from model import Autoencoder
from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger


class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0,
                 initial_threshold=2.0, save_interval=100, output_dir="output_images"):
        super().__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # interval (in batches) between saved output images
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        img, mat, _ = batch
        batch_size, _, _, size, _ = img.shape
        # Each sample holds an image pair; flatten the pairs into the batch dimension.
        img = img.view(batch_size * 2, 3, size, size)
        mat = mat.view(batch_size * 2, 3, 3)
        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)

        # Intra-image triplet loss
        num_anchor_sets = 2**12
        trip_loss = 0
        # Sampling radius shrinks over training, since std = size / c and c grows with the epoch.
        std_list = [2.5 * 1.025**self.current_epoch, 5 * 1.025**self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size * 2, num_anchor_sets, 1, 2))
            coords = anchors + torch.normal(0, std, (batch_size * 2, num_anchor_sets, 2, 2)).long()
            # Zero out anchor sets where either sampled point falls outside the image.
            invalid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[invalid_coords_idx] = 0
            anchors[invalid_coords_idx] = 0

            # The sampled point closer to the anchor becomes the positive, the other the negative.
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)

            # Extract feature vectors from dec5_output at the sampled coordinates.
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(
                dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
        trip_loss /= len(std_list)
        # Adaptive margin: grows whenever the triplet loss falls below its initial value.
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()

        # Transformation consistency loss
        num_samples = 2**20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)

        # Batch-direction loss
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)

        # Total loss
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.0**self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # Free VRAM
        del img, output
        torch.cuda.empty_cache()
        return loss

    def _get_triplet_coordinates(self, anchors, coords, idx):
        # For each anchor set, pick the closer sampled point as positive and the other as negative.
        anchors = anchors.squeeze(2)
        positives = coords[torch.arange(coords.size(0))[:, None, None],
                           torch.arange(coords.size(1))[None, :, None],
                           idx[:, :, None]].squeeze(2)
        negatives = coords[torch.arange(coords.size(0))[:, None, None],
                           torch.arange(coords.size(1))[None, :, None],
                           (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives

    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        # Gather per-pixel feature vectors of shape (batch*2, num_anchor_sets, feature_dim)
        # from the (batch*2, feature_dim, H, W) feature map at the given (y, x) coordinates.
        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        batch_idx = torch.arange(batch_size * 2)[:, None, None]
        feature_idx = torch.arange(self.feature_dim)
        anchor_vectors = dec5_output[batch_idx, feature_idx, y_anchors, x_anchors]
        positive_vectors = dec5_output[batch_idx, feature_idx, y_positives, x_positives]
        negative_vectors = dec5_output[batch_idx, feature_idx, y_negatives, x_negatives]
        return anchor_vectors, positive_vectors, negative_vectors

    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
        # Sample random pixel coordinates, map them through each sample's transformation matrix,
        # and penalise feature differences between the paired transformed locations.
        anchor_indices = torch.randint(batch_size, (num_samples, 1), device=self.device).repeat(1, 2).reshape(num_samples * 2)
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size / 2, size / 2])(anchor_coords)
        anchor_vectors = torch.zeros([num_samples * 2, self.feature_dim], device=self.device)
        # Keep only coordinates that remain inside the image after the transformation.
        inner_idx_flat = (
            ((0 <= tf_anchor_coords[:, 0]) & (tf_anchor_coords[:, 0] < size))
            & ((0 <= tf_anchor_coords[:, 1]) & (tf_anchor_coords[:, 1] < size))
        )
        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :,
                                                     tf_anchor_coords[inner_idx_flat, 0],
                                                     tf_anchor_coords[inner_idx_flat, 1]]
        # A pair contributes only when both of its transformed coordinates are in bounds.
        inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        return pairwise_distance_squared(anchor_vectors[:, 0], anchor_vectors[:, 1]).mean()

    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        # Compare features of random anchor pixels against pixels drawn from other image pairs.
        N = 2**12
        anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
        anchor_coords = torch.randint(0, size, (N, 2))
        # Sample from the remaining batch_size - 1 pairs and shift past the anchor's own pair.
        other_indices = torch.randint(0, batch_size - 1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
        other_indices += ((other_indices // 2) >= (anchor_indices // 2).unsqueeze(1)).long() * 2
        other_coords = torch.randint(0, size, (N, 2, 2))
        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        # Average only over distances below the threshold.
        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
        return [optimizer], [scheduler]

    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")

    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
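

# Minimal training sketch (not part of the original module). It wires the LightningModule to a
# pl.Trainer using a synthetic random dataset, assuming batches shaped like those unpacked in
# training_step: img (B, 2, 3, size, size), mat (B, 2, 3, 3), plus a dummy label. It also assumes
# the Autoencoder's dec5 output keeps the input spatial resolution, as training_step's indexing
# implies. The dataset below is a hypothetical stand-in, not the project's real data pipeline.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _RandomPairDataset(Dataset):
        """Hypothetical stand-in for the real paired-image dataset."""

        def __init__(self, length=16, size=64):
            self.length = length
            self.size = size

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            img = torch.rand(2, 3, self.size, self.size)  # image pair
            mat = torch.eye(3).expand(2, 3, 3).clone()    # identity transformation matrices
            return img, mat, 0

    loader = DataLoader(_RandomPairDataset(), batch_size=4, shuffle=True)
    module = AutoencoderModule(feature_dim=64)
    # fast_dev_run executes a single batch to smoke-test the training loop.
    trainer = pl.Trainer(max_epochs=1, fast_dev_run=True)
    trainer.fit(module, loader)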