|
from torchvision import transforms |
|
import torch |
|
from PIL import Image |
|
import torch.nn.functional as F |
|
from matplotlib import pyplot as plt |
|
|
|
|
|
# Standard ImageNet per-channel mean/std. Presumably the statistics the
# TorchScript model loaded below was trained with -- TODO confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# PIL image -> normalized float tensor of shape (C, H, W).
# Resize(288) with a single int scales the shorter edge to 288 pixels and
# keeps the aspect ratio, so H and W vary from image to image.
preprocess = transforms.Compose([
    transforms.Resize(288),
    transforms.ToTensor(),
    normalize,
])

# Backward-compatible alias: the pipeline was originally published under this
# misspelled name, and existing code still references it.
preproccess = preprocess
|
|
|
MODEL_PATH = "./msscddiscmixup.torchscript.pt"

# Load the TorchScript feature extractor once, at import time.
# map_location="cpu" keeps the weights on the CPU, the same device as the
# input tensors produced by the preprocessing pipeline, even if the archive
# was saved from a GPU.
model = torch.jit.load(MODEL_PATH, map_location="cpu")
# This script only runs inference. Switch off training-mode behaviour
# (dropout, batch-norm statistic updates) in case the module was exported
# in train mode.
model.eval()
|
|
|
|
|
def visualize(path):
    """Draw the image stored at ``path`` in a new 10x10-inch figure, axes hidden.

    This function only creates the figure and does not call ``plt.show()``.
    Displaying it is left to the caller or the environment (e.g. a notebook).

    Args:
        path: File path (or file object) accepted by ``PIL.Image.open``.
    """
    # Image.open is lazy and keeps the underlying file open. copy() forces a
    # full decode into memory, so the handle can be released right away.
    with Image.open(path) as src:
        image = src.copy()
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(image)
|
|
|
|
|
def extract_feature(img_path):
    """Return the model's feature vector for the image at ``img_path``.

    The image is converted to RGB, passed through ``preproccess`` and fed to
    ``model`` as a batch of one. The first (only) row of the output is returned.

    Args:
        img_path: File path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor: ``model(batch)[0, :]``. This is a 1-D descriptor,
        assuming the model returns a 2-D (batch, dim) tensor -- TODO confirm.
        The tensor is computed without autograd tracking.
    """
    # Close the file deterministically. convert() returns a new, fully loaded
    # image that stays valid after the source image is closed.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    batch = preproccess(img).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        return model(batch)[0, :]
|
|
|
|
|
def simi(img1, img2):
    """Show two images, then print and return the cosine similarity of their features.

    Args:
        img1: Path to the first image.
        img2: Path to the second image.

    Returns:
        float: Cosine similarity of the two feature vectors, in [-1, 1].
        Earlier versions returned ``None``, so adding a return value is
        backward compatible.
    """
    visualize(img1)
    visualize(img2)
    vec1 = extract_feature(img1)
    vec2 = extract_feature(img2)

    # dim=0 compares the vectors along their first axis. This assumes the
    # features are 1-D, so the result is a single-element tensor.
    cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
    # Convert to a plain float. This gives readable output (no tensor repr or
    # grad_fn) and a value callers can use directly.
    score = cos_sim.item()

    print('similarity:', score)
    return score
|
|
|
|