|
import torch |
|
import random |
|
|
|
|
|
|
|
|
|
|
|
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Build a multi-bandwidth Gaussian (RBF) kernel matrix over source and target.

    The source and target samples are stacked into one matrix ``total`` with
    ``n + m`` rows. The kernel is evaluated between every pair of rows. The
    result is the sum of ``kernel_num`` Gaussian kernels whose bandwidths are
    ``base * kernel_mul ** (i - kernel_num // 2)`` for ``i`` in
    ``range(kernel_num)``.

    Params:
        source: source-domain samples, a 2-D tensor of shape (n, d).
        target: target-domain samples, a 2-D tensor of shape (m, d). It must
            have the same number of columns as ``source``, because the two
            are concatenated along dim 0.
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of Gaussian kernels to sum.
        fix_sigma: optional base bandwidth. It is used directly as the
            denominator in ``exp(-||a - b||^2 / bandwidth)``, so it plays the
            role of 2*sigma^2, not sigma. It is still spread by
            ``kernel_mul ** (kernel_num // 2)`` like the data-driven base.
            When falsy (None or 0), the base is the mean squared distance
            over all off-diagonal pairs.

    Return:
        A tensor of shape (n + m, n + m): the sum of the kernel matrices.
    """
    total = torch.cat([source, target], dim=0)
    n_samples = int(total.size(0))

    # Pairwise squared Euclidean distances: entry [i, j] = ||total[j] - total[i]||^2.
    # Broadcasting yields the same (N, N, d) difference tensor as an explicit expand.
    diff = total.unsqueeze(0) - total.unsqueeze(1)
    L2_distance = (diff ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean over the N*(N-1) off-diagonal pairs; the diagonal distances are 0.
        # With a single sample there are no such pairs, so guard the denominator.
        off_diag_pairs = max(n_samples ** 2 - n_samples, 1)
        # detach(): keep gradients from flowing through the bandwidth estimate.
        bandwidth = torch.sum(L2_distance.detach()) / off_diag_pairs
        # A zero mean means every distance is zero, so every kernel entry is
        # exp(0) = 1 for any positive bandwidth. Substitute 1 to avoid 0/0 = NaN.
        bandwidth = torch.where(bandwidth > 0, bandwidth, torch.ones_like(bandwidth))

    # Geometric series of bandwidths: base * kernel_mul ** (i - kernel_num // 2).
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    kernel_val = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
    return sum(kernel_val)
|
|
|
def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Compute the MMD loss between source and target samples.

    Uses the multi-bandwidth Gaussian kernel matrix from ``guassian_kernel``.
    Returns the biased (V-statistic) estimate of squared MMD:
    ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``, where the diagonal
    self-similarity terms are included in ``K_ss`` and ``K_tt``.

    Params:
        source: source-domain samples, a 2-D tensor of shape (n, d).
        target: target-domain samples, a 2-D tensor of shape (m, d). ``m`` may
            differ from ``n``. Both must be non-empty; an empty block averages
            to NaN.
        kernel_mul: ratio between consecutive kernel bandwidths.
        kernel_num: number of Gaussian kernels summed.
        fix_sigma: optional fixed base bandwidth, forwarded to
            ``guassian_kernel``; falsy means it is estimated from the data.

    Return:
        loss: 0-d tensor holding the MMD loss.
    """
    batch_size = int(source.size(0))
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    XX = kernels[:batch_size, :batch_size]  # (n, n) source-source
    YY = kernels[batch_size:, batch_size:]  # (m, m) target-target
    XY = kernels[:batch_size, batch_size:]  # (n, m) source-target
    YX = kernels[batch_size:, :batch_size]  # (m, n) target-source

    # Average each block on its own. The blocks have different shapes when
    # n != m, so summing them elementwise before the mean would fail to
    # broadcast. When n == m this equals mean(XX + YY - XY - YX).
    loss = torch.mean(XX) + torch.mean(YY) - torch.mean(XY) - torch.mean(YX)
    return loss