m3 commited on
Commit
bfa11e6
·
1 Parent(s): 22fe4c5

chore: add sscd

Browse files
Files changed (1) hide show
  1. src/sscd.py +42 -0
src/sscd.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ import torch
3
+ from PIL import Image
4
+ import torch.nn.functional as F
5
+ from matplotlib import pyplot as plt
6
+
7
+ ##### Global variable
8
+ normalize = transforms.Normalize(
9
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
10
+ )
11
+ preproccess = transforms.Compose([
12
+ transforms.Resize(288),
13
+ transforms.ToTensor(),
14
+ normalize,
15
+ ])
16
+
17
+ model = torch.jit.load("./msscddiscmixup.torchscript.pt")
18
+
19
+
20
+ def visualize(path):
21
+ image = Image.open(path)
22
+ plt.figure(figsize=(10, 10))
23
+ plt.axis('off')
24
+ plt.imshow(image)
25
+
26
+
27
+ def extract_feature(img_path):
28
+ img = Image.open(img_path).convert('RGB')
29
+ batch = preproccess(img).unsqueeze(0)
30
+ return model(batch)[0, :]
31
+
32
+
33
+ def simi(img1, img2):
34
+ visualize(img1)
35
+ visualize(img2)
36
+ vec1 = extract_feature(img1)
37
+ vec2 = extract_feature(img2)
38
+ # subtract the mean and then L2 norm
39
+ cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
40
+ # 余弦相似度得分大于0.75,匹配准确度90%
41
+ print('similarity:', cos_sim)
42
+