|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from numba import jit |
|
|
|
import torch |
|
|
|
|
|
@jit
def time_warp(costs):
    """Fill the cumulative dynamic-time-warping cost matrix for ``costs``.

    Cell (0, 0) is left at 0, so ``costs[0, 0]`` is not added. The rest of
    row 0 and column 0 is set to +inf, so every warping path is anchored at
    (0, 0). Each interior cell adds its local cost to the cheapest of its
    upper, left and upper-left neighbours.

    Args:
        costs: 2-D array of local costs.
            NOTE(review): assumed to have a floating dtype, because the
            border is filled with ``np.inf``. Confirm callers never pass
            integer arrays.

    Returns:
        An array shaped and typed like ``costs`` holding cumulative costs.
        ``result[-1, -1]`` is the cheapest cumulative cost of a path ending
        at the bottom-right cell.
    """
    dtw = np.zeros_like(costs)
    # Block row 0 / column 0 (except the origin) so paths must start at (0, 0).
    dtw[0, 1:] = np.inf
    dtw[1:, 0] = np.inf

    for i in range(1, costs.shape[0]):
        for j in range(1, costs.shape[1]):
            dtw[i, j] = costs[i, j] + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])

    return dtw
|
|
|
|
|
def align_from_distances(distance_matrix, debug=False, return_mindist=False):
    """Backtrack the cheapest DTW path through ``distance_matrix``.

    The cumulative cost matrix is computed by ``time_warp``. The walk starts
    at the bottom-right cell. It repeatedly steps to whichever of the up,
    left and up-left neighbours has the lowest cumulative cost, and stops
    once it reaches row 0 or column 0. Ties go to up, then left, then
    diagonal.

    Args:
        distance_matrix: 2-D array of local costs (rows x columns).
        debug: if True, draw the recovered path with matplotlib and show it.
        return_mindist: if True, also return the cumulative cost stored at
            the bottom-right cell.

    Returns:
        A list with one entry per row of ``distance_matrix``. Entry ``r`` is
        the column recorded when the walk last visited row ``r``, which is
        the smallest column seen for that row. Entries start at 0, and
        row 0 is never overwritten. When ``return_mindist`` is True, returns
        the tuple ``(list, cumulative_cost)``.
    """
    acc = time_warp(distance_matrix)

    row = distance_matrix.shape[0] - 1
    col = distance_matrix.shape[1] - 1
    path = [0 for _ in range(distance_matrix.shape[0])]

    def _cumulative(cell):
        return acc[cell[0], cell[1]]

    while row > 0 and col > 0:
        path[row] = col
        up = (row - 1, col)
        left = (row, col - 1)
        diag = (row - 1, col - 1)
        # min() keeps the first minimum, so candidate order fixes tie-breaking.
        row, col = min((up, left, diag), key=_cumulative)

    if debug:
        canvas = np.zeros_like(acc)
        canvas[range(len(path)), path] = 1
        plt.matshow(canvas)
        plt.show()

    return (path, acc[-1, -1]) if return_mindist else path
|
|
|
|
|
def get_local_context(input_f, max_window=32, scale_factor=1.):
    """Collect a fixed-size window of neighbouring values around every frame.

    For each index ``t`` of ``input_f``, the window holds ``input_f[t + k]``
    for ``k`` in ``-max_window .. max_window - 1``. That is
    ``2 * max_window`` entries, with the current frame at position
    ``max_window``. Positions outside the sequence are filled with the
    integer 0.

    Args:
        input_f: indexable sequence exposing ``.shape[0]`` (e.g. a NumPy
            array or tensor). NOTE(review): presumably a 1-D per-frame
            feature track. Confirm with callers.
        max_window: half-width of the window.
        scale_factor: accepted but not used by this function.

    Returns:
        A list with one inner list per frame. Each inner list holds the
        window values in offset order.
    """
    n_frames = input_f.shape[0]

    def _value_at(pos):
        # Zero-pad anything outside [0, n_frames).
        if 0 <= pos < n_frames:
            return input_f[pos]
        return 0

    offsets = range(-max_window, max_window)
    return [[_value_at(t + off) for off in offsets] for t in range(n_frames)]
|
|
|
|
|
def cal_localnorm_dist(src, tgt, src_len, tgt_len):
    """Pairwise Euclidean distances between locally mean-normalised windows.

    Each frame of ``src`` and ``tgt`` is replaced by its context window from
    ``get_local_context``. Each window then has its own mean (over the last
    axis) subtracted, and ``torch.cdist`` compares every source window with
    every target window.

    Args:
        src: source sequence accepted by ``get_local_context``.
            NOTE(review): presumably a 1-D per-frame feature track. Confirm.
        tgt: target sequence, same kind as ``src``.
        src_len: used only to form ``scale_factor=tgt_len / src_len`` for the
            target windows. Must be non-zero.
        tgt_len: see ``src_len``.

    Returns:
        A distance tensor. For 1-D inputs its shape is
        ``(1, len(src), len(tgt))``.
    """
    local_src = torch.tensor(get_local_context(src))
    local_tgt = torch.tensor(get_local_context(tgt, scale_factor=tgt_len / src_len))

    # Integer frame data (mixed with the integer zero padding) yields an
    # integer tensor, and Tensor.mean rejects integer dtypes. Promote both
    # sides to one shared floating dtype so mean/cdist get matching operands.
    common = torch.promote_types(local_src.dtype, local_tgt.dtype)
    if not common.is_floating_point:
        common = torch.get_default_dtype()
    local_src = local_src.to(common)
    local_tgt = local_tgt.to(common)

    # Remove each window's own mean (local normalisation).
    local_norm_src = local_src - local_src.mean(dim=-1, keepdim=True)
    local_norm_tgt = local_tgt - local_tgt.mean(dim=-1, keepdim=True)

    # Add a batch dimension of 1, as torch.cdist expects batched inputs here.
    return torch.cdist(local_norm_src.unsqueeze(0), local_norm_tgt.unsqueeze(0))
|
|
|
|
|
|
|
def LoNDTWDistance(src, tgt):
    """Align ``tgt`` to ``src`` with DTW over locally normalised windows.

    The pairwise distance matrix comes from ``cal_localnorm_dist``, where
    rows are ``src`` frames and columns are ``tgt`` frames. It is transposed
    so that rows index ``tgt`` frames, then backtracked with
    ``align_from_distances``.

    Args:
        src: source sequence exposing ``.shape[0]`` (its frame count).
        tgt: target sequence exposing ``.shape[0]``.

    Returns:
        ``(alignment, min_distance)``. ``alignment[k]`` is the ``src`` column
        recorded for ``tgt`` frame ``k``. ``min_distance`` is the cumulative
        DTW cost at the final cell.
    """
    n_src = src.shape[0]
    n_tgt = tgt.shape[0]
    pairwise = cal_localnorm_dist(src, tgt, n_src, n_tgt).squeeze(0)

    # Rows = target frames, columns = source frames.
    tgt_by_src = pairwise.detach().cpu().numpy().T

    return align_from_distances(tgt_by_src, return_mindist=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|