# Source: onnx-base/src/sscd.py (1.02 kB)
# Commit bfa11e6 by m3: "chore: add sscd"
from torchvision import transforms
import torch
from PIL import Image
import torch.nn.functional as F
from matplotlib import pyplot as plt
##### Global variables
# Per-channel mean/std normalisation (the commonly used ImageNet values).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
)
# Inference preprocessing: resize the shorter edge to 288 px (aspect ratio
# kept), convert to a float tensor, then normalise.
# NOTE: the misspelled name is kept because other code refers to it.
preproccess = transforms.Compose([
    transforms.Resize(288),
    transforms.ToTensor(),
    normalize,
])
# Load the TorchScript model onto the CPU -- the batches built from
# `preproccess` are CPU tensors, so a GPU-serialised model would otherwise
# fail at inference time. The path is relative to the current working
# directory.
model = torch.jit.load("./msscddiscmixup.torchscript.pt", map_location="cpu")
# Make sure any batch-norm / dropout layers run in inference mode regardless
# of the mode the model was saved in (a no-op if it was saved in eval mode).
model.eval()
def visualize(path):
    """Display the image at *path* in a new 10x10-inch figure with axes hidden.

    The file is opened in a ``with`` block and its pixels are copied into
    memory, so the underlying file handle is released immediately instead of
    lingering until garbage collection.

    Args:
        path: File path (or file-like object) accepted by ``PIL.Image.open``.
    """
    with Image.open(path) as src:
        # copy() forces the pixel data to load and detaches it from the file,
        # so the image stays usable after the file is closed.
        image = src.copy()
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(image)
def extract_feature(img_path):
    """Compute the model's descriptor for a single image file.

    The image is decoded as RGB, passed through the module-level
    ``preproccess`` pipeline, and fed to the module-level TorchScript
    ``model`` as a batch of one. The first (only) row of the output is
    returned.

    Args:
        img_path: Path (or file-like object) accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor: ``model(batch)[0, :]`` for this image, computed without
        autograd tracking.
    """
    # convert() returns a new, fully loaded image, so the source file can be
    # closed as soon as decoding is done.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    batch = preproccess(img).unsqueeze(0)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        return model(batch)[0, :]
def simi(img1, img2, *, show=True):
    """Print and return the cosine similarity between two images' descriptors.

    Args:
        img1: First image path, accepted by ``extract_feature``/``visualize``.
        img2: Second image path, accepted by ``extract_feature``/``visualize``.
        show: When True (the default, matching earlier behaviour), each image
            is displayed with ``visualize`` before comparison.

    Returns:
        float: Cosine similarity of the two descriptors, in [-1, 1].
    """
    if show:
        visualize(img1)
        visualize(img2)
    vec1 = extract_feature(img1)
    vec2 = extract_feature(img2)
    # cosine_similarity divides each vector by its L2 norm internally; no
    # mean subtraction or other post-processing is applied here.
    cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
    # Author's note (translated): a cosine-similarity score above 0.75
    # corresponds to ~90% matching precision. Not verified in this file --
    # confirm against the model's published guidance before using it as a
    # decision threshold.
    score = cos_sim.item()
    print('similarity:', score)
    return score