"""
This is a modified version which only extract text embedding in HF Space.
See https://github.com/baaivision/Uni3D for source code.
Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings.
"""
import os
import sys

import open_clip
import torch
from huggingface_hub import hf_hub_download

sys.path.append('')  # make repo-root packages (feature_extractors, utils) importable when run from the repo root
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer


class Uni3dEmbeddingEncoder(FeatureExtractor):
    def __init__(self, cache_dir, **kwargs) -> None:
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")

        # Download the EVA02-E-14-plus CLIP checkpoint on first use.
        if not os.path.exists(clip_path):
            hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                            "open_clip_pytorch_model.bin",
                            cache_dir=cache_dir,
                            local_dir=os.path.join(cache_dir, "Uni3D"))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path)
        self.clip_model.to(self.device)
        
    @torch.no_grad()
    def encode_3D(self, data):
        raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")

    @torch.no_grad()
    def encode_text(self, input_text):
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        if texts.dim() < 2:  # add a batch dimension for a single query
            texts = texts[None, ...]
        # L2-normalize so embeddings can be compared via cosine similarity.
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        # L2-normalize image features to match the text-embedding space.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        return self.encode_text(query_list)
    
    def get_img_transform(self):
        return self.preprocess
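

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. Assumptions: it is
# run from the repo root so `feature_extractors`/`utils` resolve, a writable
# ./cache directory exists for the checkpoint download, and "example.png" is
# a hypothetical image path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    encoder = Uni3dEmbeddingEncoder(cache_dir="cache")

    # Text queries -> L2-normalized embeddings, shape [batch, embed_dim].
    text_feat = encoder.encode_query(["a 3D model of a wooden chair"])

    # Images -> embeddings in the same space; preprocess, then add a batch dim.
    img = encoder.get_img_transform()(Image.open("example.png").convert("RGB"))
    img_feat = encoder.encode_image(img.unsqueeze(0))

    # Because both embeddings are unit-norm, the dot product is the cosine
    # similarity between the query and the image.
    print((text_feat @ img_feat.T).item())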