"""Image-similarity helpers built on a TorchScript feature-extraction model.

Each image is resized, normalized, passed through the model loaded from
``./msscddiscmixup.torchscript.pt``, and two images are compared with the
cosine similarity of their feature vectors.
"""

from torchvision import transforms
import torch
from PIL import Image
import torch.nn.functional as F
from matplotlib import pyplot as plt

# ---- Module-level globals ----

# Per-channel normalization; these are the standard ImageNet mean/std values.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Resize so the shorter side is 288 px, convert to a float CHW tensor,
# then apply the channel normalization above.
preproccess = transforms.Compose([
    transforms.Resize(288),
    transforms.ToTensor(),
    normalize,
])
# Correctly spelled alias; the original (misspelled) name is kept so existing
# callers that import ``preproccess`` keep working.
preprocess = preproccess

# Loaded once at import time. The path is resolved relative to the current
# working directory, not relative to this file.
model = torch.jit.load("./msscddiscmixup.torchscript.pt")
# Inference only: make sure layers such as dropout / batch-norm are not left
# in training mode if the module happened to be saved that way.
model.eval()


def visualize(path):
    """Draw the image at *path* on a new 10x10-inch figure with axes hidden.

    The figure is only created and drawn; it is not shown or saved here
    (call ``plt.show()`` or use an inline notebook backend to display it).
    """
    # Use a context manager so the file handle is released; imshow copies
    # the pixel data into an array before the handle is closed.
    with Image.open(path) as image:
        plt.figure(figsize=(10, 10))
        plt.axis('off')
        plt.imshow(image)


def extract_feature(img_path):
    """Return the model's feature vector for the image at *img_path*.

    The image is converted to RGB, preprocessed, wrapped as a batch of one,
    and run through ``model`` without gradient tracking.

    NOTE(review): assumes the model output is (batch, feature_dim); the
    ``[0, :]`` index selects the single sample's features — confirm against
    the exported model.
    """
    with Image.open(img_path) as img:
        batch = preproccess(img.convert('RGB')).unsqueeze(0)
    # No autograd graph is needed for feature extraction.
    with torch.no_grad():
        return model(batch)[0, :]


def simi(img1, img2):
    """Display two images, then print and return their cosine similarity.

    Args:
        img1: Path to the first image.
        img2: Path to the second image.

    Returns:
        A 0-d tensor holding the cosine similarity of the two feature vectors.
    """
    visualize(img1)
    visualize(img2)
    vec1 = extract_feature(img1)
    vec2 = extract_feature(img2)
    # Plain cosine similarity along the feature dimension (no mean subtraction).
    cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
    # Original author's note: a cosine-similarity score above 0.75 gave about
    # 90% matching accuracy (empirical claim; not verified here).
    print('similarity:', cos_sim)
    return cos_sim