import random

import torch
def augment_domains(self, groups_feature_maps):
    """Augment latent feature maps per domain with a randomly chosen latent-space transform."""

    def hard_example_interpolation(z_i, hard_example, lambda_1):
        # Move z_i towards the hardest example of its domain.
        return z_i + lambda_1 * (hard_example - z_i)

    def hard_example_extrapolation(z_i, mean_latent, lambda_2):
        # Push z_i away from the domain mean.
        return z_i + lambda_2 * (z_i - mean_latent)

    def add_gaussian_noise(z_i, sigma, lambda_3):
        # Perturb z_i with scaled Gaussian noise.
        epsilon = torch.randn_like(z_i) * sigma
        return z_i + lambda_3 * epsilon

    def difference_transform(z_i, z_j, z_k, lambda_4):
        # Shift z_i along the difference direction between two other samples.
        return z_i + lambda_4 * (z_j - z_k)

    def distance(z_i, z_j):
        # Euclidean distance between two latent vectors.
        return torch.norm(z_i - z_j)

    domain_number = len(groups_feature_maps[0])

    # Mean latent per domain, computed over all groups.
    domain_means = []
    for domain_idx in range(domain_number):
        all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0)
        domain_mean = torch.mean(all_samples_in_domain, dim=0)
        domain_means.append(domain_mean)

    # Hard example per domain: the sample furthest from its domain mean.
    hard_examples = []
    for domain_idx in range(domain_number):
        all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0)
        distances = torch.stack([distance(z, domain_means[domain_idx]) for z in all_samples_in_domain])
        hard_example = all_samples_in_domain[torch.argmax(distances)]
        hard_examples.append(hard_example)

    augmented_groups = []
    for group_feature_maps in groups_feature_maps:
        augmented_domains = []
        for domain_idx, domain_feature_maps in enumerate(group_feature_maps):
            # Candidate augmentations, each with a freshly drawn random strength.
            augmentations = [
                lambda z: hard_example_interpolation(z, hard_examples[domain_idx], random.random()),
                lambda z: hard_example_extrapolation(z, domain_means[domain_idx], random.random()),
                lambda z: add_gaussian_noise(z, random.random(), random.random()),
                lambda z: difference_transform(z, domain_feature_maps[0], domain_feature_maps[1], random.random()),
            ]
            # Apply one randomly chosen augmentation to every sample of this domain.
            chosen_aug = random.choice(augmentations)
            augmented = torch.stack([chosen_aug(z) for z in domain_feature_maps])
            augmented_domains.append(augmented)

        augmented_domains = torch.stack(augmented_domains)
        augmented_groups.append(augmented_domains)

    return torch.stack(augmented_groups)
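

# Hedged usage sketch: `_demo_augment_domains` is a hypothetical helper added for
# illustration only, not part of the original module. The input layout
# (num_groups, num_domains, samples_per_domain, feat_dim) and the concrete sizes are
# assumptions; the only constraint the code itself implies is at least two samples per
# domain, because difference_transform indexes positions [0] and [1]. `self` is unused
# by augment_domains, so None is passed for it here.
def _demo_augment_domains():
    groups = torch.randn(3, 4, 2, 512)  # 3 groups, 4 domains, 2 samples, 512-d latents
    augmented = augment_domains(None, groups)
    print(augmented.shape)  # expected: torch.Size([3, 4, 2, 512])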


def mixup_in_latent_space(self, data):
    """Mix each sample's domain features with a random permutation of its own domains."""
    bs, num_domains, _, _, _ = data.shape

    mixed_data = torch.zeros_like(data)

    for i in range(bs):
        # Random pairing of domains within this sample.
        shuffled_idxs = torch.randperm(num_domains)

        # Mixing coefficient drawn from a Beta distribution whose concentration is
        # itself sampled uniformly from [0.5, 2.0].
        alpha = torch.rand(1) * 1.5 + 0.5
        lambda_ = torch.distributions.beta.Beta(alpha, alpha).sample().to(data.device)

        # Convex combination of the original and the shuffled domain features.
        mixed_data[i] = lambda_ * data[i] + (1 - lambda_) * data[i, shuffled_idxs]

    return mixed_data
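

# Hedged usage sketch: `_demo_mixup_in_latent_space` is a hypothetical helper added for
# illustration only, not part of the original module. The 5-D layout
# (batch, num_domains, C, H, W) follows the unpacking `bs, num_domains, _, _, _ = data.shape`;
# the concrete sizes are assumptions. `self` is unused by mixup_in_latent_space, so None
# is passed for it here.
def _demo_mixup_in_latent_space():
    latents = torch.randn(8, 4, 64, 8, 8)  # batch of 8, 4 domains, 64x8x8 feature maps
    mixed = mixup_in_latent_space(None, latents)
    print(mixed.shape)  # expected: torch.Size([8, 4, 64, 8, 8])


if __name__ == "__main__":
    _demo_augment_domains()
    _demo_mixup_in_latent_space()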