File size: 3,629 Bytes
68cd8f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
from functools import lru_cache
from pathlib import Path
from typing import Union

from . import modeling_finetune, utils
from PIL import Image
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models import create_model
from torchvision import transforms
from transformers import XLMRobertaTokenizer

# Directory containing this file; used to resolve the bundled BEiT-3
# checkpoint and sentencepiece assets shipped next to this module.
# (Removed leftover debug `print(CWD)` — it ran on every import.)
CWD = Path(__file__).parent


class Preprocess:
    """Convert raw text or a PIL image into model-ready tensors for BEiT-3.

    Text is tokenized, wrapped in BOS/EOS, truncated and padded to a fixed
    length; images are resized/normalized to the 384x384 Inception stats
    expected by the ``beit3_*_384`` checkpoints.
    """

    def __init__(self, tokenizer):
        # Maximum token sequence length, including the BOS/EOS markers.
        self.max_len = 64
        # Square input resolution matching the *_384 checkpoint family.
        self.input_size = 384

        self.tokenizer = tokenizer
        self.transform = transforms.Compose(
            [
                # interpolation=3 is PIL bicubic; kept numeric for
                # compatibility with older torchvision versions.
                transforms.Resize((self.input_size, self.input_size), interpolation=3),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
                ),
            ]
        )

        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def preprocess(self, input: "Union[str, Image.Image]"):
        """Preprocess a single input.

        Returns:
            For ``str``: a ``(token_ids, padding_mask, num_tokens)`` tuple
            where both tensors have shape ``(1, max_len)`` and the mask is
            0 for real tokens, 1 for padding.
            For ``PIL.Image.Image``: a ``(1, C, H, W)`` float tensor.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(input, str):
            tokens = self.tokenizer.tokenize(input)
            tokens = self.tokenizer.convert_tokens_to_ids(tokens)

            # BUG FIX: truncate so BOS + tokens + EOS never exceeds
            # max_len. Previously long inputs produced sequences longer
            # than max_len with a malformed (empty) padding tail.
            tokens = (
                [self.bos_token_id] + tokens[: self.max_len - 2] + [self.eos_token_id]
            )
            num_tokens = len(tokens)
            padding_mask = [0] * num_tokens + [1] * (self.max_len - num_tokens)

            return (
                torch.LongTensor(
                    tokens + [self.pad_token_id] * (self.max_len - num_tokens)
                ).unsqueeze(0),
                torch.Tensor(padding_mask).unsqueeze(0),
                num_tokens,
            )
        elif isinstance(input, Image.Image):
            return self.transform(input).unsqueeze(0)
        else:
            # TypeError is more precise than bare Exception and is still
            # caught by any existing `except Exception` handlers.
            raise TypeError("Invalid input type")


class Beit3Model:
    """Inference-only wrapper around a BEiT-3 retrieval checkpoint.

    Produces embeddings for either a text string or a PIL image via
    :meth:`get_embedding`.
    """

    def __init__(
        self,
        model_name: str = "beit3_base_patch16_384_retrieval",
        model_path: "str | None" = None,
        device: str = "cuda",
    ):
        # Resolve the default checkpoint lazily (at call time, not at
        # class-definition time); passing an explicit path still works.
        if model_path is None:
            model_path = os.path.join(
                CWD,
                "beit3_model/beit3_base_patch16_384_f30k_retrieval.pth",
            )
        self.device = device
        self._load_model(model_name, model_path, device)

    def _load_model(self, model_name, model_path, device: str = "cpu"):
        """Build the timm model, load checkpoint weights, and prepare the
        tokenizer/transform preprocessor."""
        self.model = create_model(
            model_name,
            pretrained=False,
            drop_path_rate=0.1,
            vocab_size=64010,
            checkpoint_activations=False,
        )

        # BUG FIX: gate weight loading on model_path (was `if model_name:`,
        # which is always truthy here, making the guard useless).
        if model_path:
            utils.load_model_and_may_interpolate(
                model_path, self.model, "model|module", ""
            )

        self.preprocessor = Preprocess(
            XLMRobertaTokenizer(os.path.join(CWD, "beit3_model/beit3.spm"))
        )
        self.model.to(device)
        # Inference-only wrapper: freeze dropout / norm statistics.
        self.model.eval()

    def get_embedding(self, input: "Union[str, Image.Image]"):
        """Embed a text string or PIL image.

        Returns a float32 numpy array for text; a torch tensor for images
        (asymmetry kept for backward compatibility — see NOTE below).

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(input, str):
            token_ids, padding_mask, _ = self.preprocessor.preprocess(input)
            # BUG FIX: move text tensors to the model's device — the image
            # branch already did this, but text inputs crashed on CUDA.
            token_ids = token_ids.to(self.device)
            padding_mask = padding_mask.to(self.device)

            with torch.no_grad():
                _, vector = self.model(
                    text_description=token_ids,
                    padding_mask=padding_mask,
                    only_infer=True,
                )
            return vector.cpu().detach().numpy().astype("float32")
        elif isinstance(input, Image.Image):
            image_input = self.preprocessor.preprocess(input).to(self.device)
            with torch.no_grad():
                vector, _ = self.model(image=image_input, only_infer=True)
            # NOTE(review): image branch returns a torch tensor while the
            # text branch returns float32 numpy; preserved as-is so
            # existing callers keep working — confirm whether this
            # asymmetry is intentional.
            return vector
        else:
            raise TypeError("Invalid input type")