Spaces:
Sleeping
Sleeping
import random | |
import pytest | |
import torch | |
from ding.torch_utils.metric import levenshtein_distance, hamming_distance | |
class TestMetric:
    """
    Overview:
        Unit tests for ``levenshtein_distance`` and ``hamming_distance``
        imported from ``ding.torch_utils.metric``.
    """

    def test_levenshtein_distance(self):
        """
        Overview:
            Compare one fixed prediction sequence against several targets and
            check the returned Levenshtein distance, first with the default
            cost and then with a custom ``extra_fn``.
        """
        pred = torch.LongTensor([1, 4, 6, 4, 1])

        # (target values, expected distance) for the plain two-argument call.
        plain_cases = (
            ([1, 6, 4, 4, 1], 2),
            ([], 5),  # empty target: distance equals len(pred)
            ([6, 4, 1], 2),
        )
        for target_values, expected in plain_cases:
            target = torch.LongTensor(target_values)
            result = levenshtein_distance(pred, target)
            assert result.item() == expected

        def add(x, y):
            return x + y

        # The sequences themselves are passed again as the 3rd/4th arguments
        # together with ``extra_fn``.
        # NOTE(review): presumably these are per-token "extra" values consumed
        # by ``extra_fn`` -- confirm against ding.torch_utils.metric.
        extra_cases = (
            ([6, 4, 1], 13),
            ([1, 4, 1], 14),
        )
        for target_values, expected in extra_cases:
            target = torch.LongTensor(target_values)
            result = levenshtein_distance(pred, target, pred, target, extra_fn=add)
            assert result.item() == expected

    def test_hamming_distance(self):
        """
        Overview:
            Build two length-8 0/1 vectors, each with four randomly chosen
            positions set to 1, and check that ``hamming_distance`` equals the
            number of positions set in exactly one of the two vectors.
        """
        length = 8
        positions = list(range(length))
        zeros = torch.zeros(length, dtype=torch.long)

        for _ in range(2):
            # Sample prediction indices first, then target indices.
            pred_idx = random.sample(positions, 4)
            target_idx = random.sample(positions, 4)

            pred = zeros.clone()
            target = zeros.clone()
            pred[pred_idx] = 1
            target[target_idx] = 1

            # Add a leading dimension of size 1 before calling the metric.
            distance = hamming_distance(pred.unsqueeze(0), target.unsqueeze(0))

            # Symmetric difference: indices present in exactly one of the sets.
            expected = len(set(pred_idx) ^ set(target_idx))
            assert distance.item() == expected