File size: 7,446 Bytes
a71961f
 
 
 
b8b8f72
a71961f
 
 
 
 
 
 
 
 
 
 
 
440a9f4
 
a71961f
440a9f4
 
 
 
6d8e609
 
a71961f
 
440a9f4
 
 
 
 
 
6d8e609
 
 
440a9f4
6d8e609
a71961f
d220929
 
440a9f4
a71961f
440a9f4
6d8e609
440a9f4
a71961f
440a9f4
 
6d8e609
440a9f4
a71961f
 
 
 
 
d220929
 
 
a71961f
 
440a9f4
a71961f
 
 
 
440a9f4
a71961f
b8b8f72
440a9f4
d220929
a71961f
 
 
 
440a9f4
a71961f
440a9f4
 
 
 
 
 
 
 
a71961f
d220929
a71961f
d220929
440a9f4
d220929
 
440a9f4
a71961f
 
d220929
440a9f4
b8b8f72
 
440a9f4
 
a71961f
d220929
440a9f4
a71961f
440a9f4
 
d220929
 
a71961f
d220929
 
a71961f
440a9f4
 
 
d220929
 
440a9f4
 
d220929
440a9f4
 
 
 
 
 
d220929
 
 
 
440a9f4
 
 
 
d220929
440a9f4
 
 
 
 
d220929
440a9f4
a71961f
 
440a9f4
a71961f
440a9f4
a71961f
 
d220929
a71961f
 
d220929
 
 
 
 
 
 
a71961f
 
 
 
 
 
 
b8b8f72
a71961f
d220929
 
 
 
440a9f4
d220929
 
440a9f4
d220929
a71961f
d220929
 
a71961f
d220929
b8b8f72
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer


class Transformer(nn.Module):
    """Sentence-transformers style module wrapping a multimodal HF model.

    Text samples are tokenized with ``AutoTokenizer`` and embedded with
    ``model.get_text_features``; image samples are preprocessed with
    ``AutoImageProcessor`` and embedded with ``model.get_image_features``.
    ``tokenize`` records the kind of every input sample so that ``forward``
    can interleave both kinds of embeddings back into the original order.
    """

    # Seconds to wait for an image URL before giving up. Without a timeout
    # ``requests.get`` may block indefinitely on an unresponsive server.
    image_download_timeout: float = 60.0

    # Per-sample markers stored in ``image_text_info``.
    _IMAGE = 0
    _TEXT = 1

    def __init__(
        self,
        model_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        image_processor_name_or_path: Optional[str] = None,
        max_seq_length: Optional[int] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        image_processor_kwargs: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        **_,
    ) -> None:
        """Load the model config, model, tokenizer and image processor.

        Args:
            model_name_or_path: Hub id or local directory of the model. Also
                used for the tokenizer / image processor when no dedicated
                location is given.
            tokenizer_name_or_path: Optional separate tokenizer location. When
                given, its class name is recorded in ``model.config``.
            image_processor_name_or_path: Optional separate image processor
                location.
            max_seq_length: Maximum number of text tokens. When ``None`` it is
                inferred as ``min(model.config.max_position_embeddings,
                tokenizer.model_max_length)`` if both attributes exist,
                otherwise it stays ``None``.
            config_kwargs, model_kwargs, tokenizer_kwargs,
            image_processor_kwargs: Extra keyword arguments forwarded to the
                corresponding ``from_pretrained`` call. They are copied, so
                the caller's dicts are never modified.
            cache_dir: Cache directory forwarded to every ``from_pretrained``.
            **_: Any other keyword arguments (e.g. extra keys from a saved
                config file passed by ``load``) are ignored.
        """
        super().__init__()

        # Copy instead of aliasing: ``model_max_length`` may be injected into
        # the tokenizer kwargs below and must not leak into the caller's dict.
        config_kwargs = dict(config_kwargs or {})
        model_kwargs = dict(model_kwargs or {})
        tokenizer_kwargs = dict(tokenizer_kwargs or {})
        image_processor_kwargs = dict(image_processor_kwargs or {})

        config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_kwargs
        )
        # An explicit max_seq_length caps the tokenizer unless the caller
        # already chose a model_max_length themselves.
        if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
            tokenizer_kwargs['model_max_length'] = max_seq_length

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path or model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            image_processor_name_or_path or model_name_or_path,
            cache_dir=cache_dir,
            **image_processor_kwargs,
        )

        # No max_seq_length set: try to infer it from the model and tokenizer.
        if max_seq_length is None:
            if (
                hasattr(self.model, 'config')
                and hasattr(self.model.config, 'max_position_embeddings')
                and hasattr(self.tokenizer, 'model_max_length')
            ):
                max_seq_length = min(
                    self.model.config.max_position_embeddings,
                    self.tokenizer.model_max_length,
                )
        self.max_seq_length = max_seq_length
        if tokenizer_name_or_path is not None:
            self.model.config.tokenizer_class = self.tokenizer.__class__.__name__

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        """Decode a ``data:image/...,<payload>`` URI into a PIL image.

        Everything up to the first comma (the media-type header) is discarded
        and the remainder is always base64-decoded.

        Raises:
            ValueError: if there is no comma separating header and payload,
                or (as ``binascii.Error``) if the payload is not valid base64.
        """
        _, separator, payload = data_image_str.partition(',')
        if not separator:
            raise ValueError('Malformed data URI: missing "," before the payload')
        return Image.open(BytesIO(base64.b64decode(payload)))

    def _download_image(self, url: str) -> Image.Image:
        """Fetch ``url`` and return its content as an RGB PIL image."""
        response = requests.get(url, timeout=self.image_download_timeout)
        # Surface HTTP failures (404, 500, ...) directly instead of as a
        # confusing "cannot identify image file" error from PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')

    @staticmethod
    def _try_open_local_image(path: str) -> Optional[Image.Image]:
        """Return ``path`` opened as an RGB image, or ``None`` on any failure."""
        try:
            # ``with`` closes the underlying file handle; ``convert`` returns
            # a new, fully loaded image that does not depend on it.
            with Image.open(path) as image:
                return image.convert('RGB')
        except Exception:
            # Deliberately broad: any failure (missing file, not an image,
            # string that is not a valid path, ...) means "treat as text".
            return None

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, Any]:
        """Encode a mixed batch of text and image samples.

        Each sample is classified as follows:

        * ``str`` starting with ``http://`` or ``https://``: downloaded and
          treated as an image;
        * ``str`` starting with ``data:image/``: decoded as a data URI image;
        * any other ``str``: opened as a local image file if possible,
          otherwise tokenized as text;
        * ``PIL.Image.Image``: used directly.

        All images are converted to RGB.

        Args:
            texts: The samples to encode.
            padding: Forwarded to the tokenizer's ``padding`` argument.

        Returns:
            A dict with ``input_ids`` (only when there is text),
            ``pixel_values`` (only when there are images) and
            ``image_text_info``: one marker per sample in input order
            (0 = image, 1 = text), used by ``forward`` to restore the order.

        Raises:
            TypeError: if a sample is neither ``str`` nor ``PIL.Image.Image``
                (silently dropping it would misalign outputs with inputs).
            requests.HTTPError: if an image URL responds with an error status.
        """
        images = []
        text_samples = []
        image_text_info = []
        for sample in texts:
            if isinstance(sample, Image.Image):
                images.append(sample.convert('RGB'))
                image_text_info.append(self._IMAGE)
                continue
            if not isinstance(sample, str):
                raise TypeError(
                    'Expected str or PIL.Image.Image samples, '
                    f'got {type(sample).__name__}'
                )

            if sample.startswith(('http://', 'https://')):
                images.append(self._download_image(sample))
                image_text_info.append(self._IMAGE)
            elif sample.startswith('data:image/'):
                images.append(self._decode_data_image(sample).convert('RGB'))
                image_text_info.append(self._IMAGE)
            else:
                local_image = self._try_open_local_image(sample)
                if local_image is None:
                    text_samples.append(sample)
                    image_text_info.append(self._TEXT)
                else:
                    images.append(local_image)
                    image_text_info.append(self._IMAGE)

        encoding = {}
        if text_samples:
            encoding['input_ids'] = self.tokenizer(
                text_samples,
                padding=padding,
                truncation='longest_first',
                return_tensors='pt',
                max_length=self.max_seq_length,
            ).input_ids

        if images:
            encoding['pixel_values'] = self.image_processor(
                images, return_tensors='pt'
            ).pixel_values

        encoding['image_text_info'] = image_text_info
        return encoding

    def forward(self, features: Dict[str, Any]) -> Dict[str, Any]:
        """Embed a batch produced by ``tokenize``.

        Image and text embeddings are computed separately, then interleaved
        back into input order according to ``features['image_text_info']``.
        The stacked result is cast to float32 and stored under
        ``features['sentence_embedding']``; ``features`` is updated in place
        and returned.
        """
        image_embeddings = []
        text_embeddings = []
        if 'pixel_values' in features:
            image_embeddings = self.model.get_image_features(features['pixel_values'])
        if 'input_ids' in features:
            text_embeddings = self.model.get_text_features(features['input_ids'])

        # Iterating each result yields one entry per sample of that kind,
        # consumed in order (presumably one row per sample; TODO confirm
        # against the model's get_*_features output).
        image_rows = iter(image_embeddings)
        text_rows = iter(text_embeddings)
        sentence_embedding = []
        for input_type in features['image_text_info']:
            rows = image_rows if input_type == self._IMAGE else text_rows
            sentence_embedding.append(next(rows))

        features['sentence_embedding'] = torch.stack(sentence_embedding).float()
        return features

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save model, tokenizer and image processor into ``output_path``."""
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)
        self.image_processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> 'Transformer':
        """Recreate a ``Transformer`` from a directory written by ``save``.

        The constructor arguments are read from the first existing
        sentence-transformers config file in ``input_path``. Any
        ``trust_remote_code`` entry inside the stored kwargs dicts is removed,
        so a saved config cannot enable remote code execution.

        Raises:
            FileNotFoundError: if none of the known config files exists.
        """
        # Old classes used other config names than 'sentence_bert_config.json'.
        config_names = [
            'sentence_bert_config.json',
            'sentence_roberta_config.json',
            'sentence_distilbert_config.json',
            'sentence_camembert_config.json',
            'sentence_albert_config.json',
            'sentence_xlm-roberta_config.json',
            'sentence_xlnet_config.json',
        ]
        for config_name in config_names:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break
        else:
            raise FileNotFoundError(
                f'No sentence-transformers config found in {input_path!r}; '
                f'looked for: {", ".join(config_names)}'
            )

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)

        # Don't allow configs to set trust_remote_code. Non-dict entries
        # (e.g. JSON null) are left alone instead of crashing here.
        for kwargs_key in (
            'config_kwargs',
            'model_kwargs',
            'tokenizer_kwargs',
            'image_processor_kwargs',
        ):
            stored_kwargs = config.get(kwargs_key)
            if isinstance(stored_kwargs, dict):
                stored_kwargs.pop('trust_remote_code', None)

        return Transformer(model_name_or_path=input_path, **config)