""" This is a modified version which only extract text embedding in HF Space. See https://github.com/baaivision/Uni3D for source code. Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings. """ import os import sys import open_clip import torch from huggingface_hub import hf_hub_download sys.path.append('') from feature_extractors import FeatureExtractor from utils.tokenizer import SimpleTokenizer class Uni3dEmbeddingEncoder(FeatureExtractor): def __init__(self, cache_dir, **kwargs) -> None: bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz" clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin") if not os.path.exists(clip_path): hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin", cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.tokenizer = SimpleTokenizer(bpe_path) self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path) self.clip_model.to(self.device) @torch.no_grad() def encode_3D(self, data): raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py") @torch.no_grad() def encode_text(self, input_text): texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True) if len(texts.shape) < 2: texts = texts[None, ...] class_embeddings = self.clip_model.encode_text(texts) class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) return class_embeddings.float() @torch.no_grad() def encode_image(self, img_tensor_list): image = img_tensor_list.to(device=self.device, non_blocking=True) image_features = self.clip_model.encode_image(image) image_features = image_features / image_features.norm(dim=-1, keepdim=True) return image_features.float() def encode_query(self, query_list): return self.encode_text(query_list) def get_img_transform(self): return self.preprocess