from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import numpy as np
import torch


class VitBase:
    """Wraps a pretrained ViT backbone (google/vit-base-patch16-224-in21k)
    and extracts flattened last-hidden-state features from PIL images."""

    def __init__(self):
        # ViTFeatureExtractor is deprecated in newer transformers releases in
        # favour of ViTImageProcessor; both expose the same call signature here.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    def extract_feature(self, imgs):
        """Return a list of 1-D feature vectors, one per input image."""
        features = []
        for img in imgs:
            # Resize/normalise the image and build the pixel_values tensor.
            inputs = self.feature_extractor(images=img, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            # last_hidden_state has shape (1, num_patches + 1, hidden_size);
            # drop the batch dimension and flatten to a single vector.
            last_hidden_states = outputs.last_hidden_state
            features.append(np.squeeze(last_hidden_states.numpy()).flatten())
        return features
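

# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of how VitBase might be driven, assuming "example.jpg" is a
# placeholder path supplied by the caller.
if __name__ == "__main__":
    extractor = VitBase()
    # Any RGB PIL image works as input.
    imgs = [Image.open("example.jpg").convert("RGB")]
    feats = extractor.extract_feature(imgs)
    # For vit-base-patch16-224, each vector has length
    # (196 patches + 1 CLS token) * 768 hidden dims = 151296.
    print(len(feats), feats[0].shape)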