SeyedAli commited on
Commit
b82e756
·
1 Parent(s): 644361c

Create ViTMS.py

Browse files
src/similarity/model_implements/ViTMS.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, ViTMSNModel
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+
6
+ class ViTMS():
7
+
8
+ def __init__(self):
9
+ self.feature_extractor = AutoImageProcessor.from_pretrained('facebook/vit-msn-small')
10
+ self.model = ViTMSNModel.from_pretrained('facebook/vit-msn-small')
11
+
12
+ def extract_feature(self, imgs):
13
+ features = []
14
+ for img in imgs:
15
+ inputs = self.feature_extractor(images=img, return_tensors="pt")
16
+ with torch.no_grad():
17
+ outputs = self.model(**inputs)
18
+ last_hidden_states = outputs.last_hidden_state
19
+ features.append(np.squeeze(last_hidden_states.numpy()).flatten())
20
+ return features