sudy-super committed on
Commit 1fc3b01
1 Parent(s): aee9709

Upload tokenization_co_encoder.py

Files changed (1)
  1. tokenization_co_encoder.py +213 -0
tokenization_co_encoder.py ADDED
@@ -0,0 +1,213 @@
# coding=utf-8
"""Tokenization classes for CoEncoder"""

import os
import json
from typing import List, Union, Optional

from transformers import AutoTokenizer
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature

logger = logging.get_logger(__name__)


class CoEncoderDualTokenizer(ProcessorMixin):
    r"""
    CoEncoderDualTokenizer is a tokenizer for the CoEncoder model. It processes both the context and the main text.

    Args:
        context_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer for the context.
        text_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer for the main text.
    """

    attributes = ["context_tokenizer", "text_tokenizer"]
    context_tokenizer_class = "AutoTokenizer"
    text_tokenizer_class = "AutoTokenizer"

    def __init__(self, context_tokenizer=None, text_tokenizer=None):
        super().__init__(context_tokenizer, text_tokenizer)

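    # Construction sketch (hypothetical tokenizer choices): the processor wraps
    # two independent tokenizers, one for the context and one for the main text.
    #
    #     ctx_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     txt_tok = AutoTokenizer.from_pretrained("gpt2")
    #     tokenizer = CoEncoderDualTokenizer(context_tokenizer=ctx_tok, text_tokenizer=txt_tok)
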
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load both context and text tokenizers from a given repository.

        Args:
            pretrained_model_name_or_path (str): The name or path of the Hugging Face repository.

        Returns:
            CoEncoderDualTokenizer: An instance of the tokenizer class.
        """
        # Load context_tokenizer from the 'context_tokenizer' subfolder
        context_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="context_tokenizer",
            **kwargs
        )

        # Load text_tokenizer from the 'text_tokenizer' subfolder
        text_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_tokenizer",
            **kwargs
        )

        # Return a new instance of CoEncoderDualTokenizer with both tokenizers loaded
        return cls(context_tokenizer=context_tokenizer, text_tokenizer=text_tokenizer)

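    # Loading sketch (hypothetical repo id; assumes the repo contains
    # 'context_tokenizer' and 'text_tokenizer' subfolders, as saved by
    # `save_pretrained` below):
    #
    #     tokenizer = CoEncoderDualTokenizer.from_pretrained("org/co-encoder")
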
    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save the tokenizer to a directory, so that it can be reloaded using the `from_pretrained` class method.

        Args:
            save_directory (str): Directory to which to save.
        """
        # Save context tokenizer
        context_save_dir = os.path.join(save_directory, 'context_tokenizer')
        self.context_tokenizer.save_pretrained(context_save_dir, **kwargs)

        # Save text tokenizer
        text_save_dir = os.path.join(save_directory, 'text_tokenizer')
        self.text_tokenizer.save_pretrained(text_save_dir, **kwargs)

        # Save top-level tokenizer config
        tokenizer_config = {
            "tokenizer_class": self.__class__.__name__,
        }

        with open(os.path.join(save_directory, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_config, f, ensure_ascii=False)

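    # Round-trip sketch (hypothetical local path): saving then reloading
    # restores both sub-tokenizers from their respective subfolders.
    #
    #     tokenizer.save_pretrained("./co_encoder_tokenizer")
    #     reloaded = CoEncoderDualTokenizer.from_pretrained("./co_encoder_tokenizer")
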
    def __call__(
        self,
        context: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        return_tensors: Optional[str] = None,
        **kwargs
    ) -> BatchFeature:
        """
        Main method to prepare inputs for the CoEncoder model.

        Args:
            context: Context text input.
            text: Main text input.
            return_tensors: Type of tensors to return.

        Returns:
            BatchFeature: A BatchFeature object containing model inputs.
        """
        if context is None and text is None:
            raise ValueError("You must provide either context or text.")

        features = {}

        if context is not None:
            context_features = self.context_tokenizer(
                context,
                return_tensors=return_tensors,
                **kwargs
            )
            # Prefix context features so they do not collide with text features
            features.update({f"context_{k}": v for k, v in context_features.items()})

        if text is not None:
            text_features = self.text_tokenizer(
                text,
                return_tensors=return_tensors,
                **kwargs
            )
            features.update(text_features)

        return BatchFeature(data=features, tensor_type=return_tensors)

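    # Call sketch: context keys come back prefixed with "context_", while text
    # keys keep their usual names, e.g.
    #
    #     batch = tokenizer(context="long document ...", text="question ...", return_tensors="pt")
    #     # batch keys: context_input_ids, context_attention_mask, input_ids, attention_mask
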
    def pad(
        self,
        encoded_inputs,
        padding=True,
        max_length=None,
        return_tensors=None,
        **kwargs
    ):
        """
        Pads the encoded inputs to the maximum length in the batch.

        Args:
            encoded_inputs: A list of dictionaries containing context and text features.
            padding: Whether to pad sequences.
            max_length: Maximum length for padding.
            return_tensors: Type of tensors to return.

        Returns:
            BatchFeature: A BatchFeature with padded sequences.
        """
        # Separate context and text features
        context_features = []
        text_features = []

        for feature in encoded_inputs:
            # Context features carry the "context_" prefix; strip it before padding
            context_feature = {
                k[len("context_"):]: v
                for k, v in feature.items()
                if k.startswith("context_")
            }
            if context_feature:
                context_features.append(context_feature)
            # Everything else belongs to the text tokenizer
            text_feature = {
                k: v
                for k, v in feature.items()
                if not k.startswith("context_")
            }
            if text_feature:
                text_features.append(text_feature)

        # Pad context features and restore the "context_" prefix
        if context_features:
            context_padded = self.context_tokenizer.pad(
                context_features,
                padding=padding,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs.get("context_kwargs", {})
            )
            context_padded = {f"context_{k}": v for k, v in context_padded.items()}
        else:
            context_padded = {}

        # Pad text features
        if text_features:
            text_padded = self.text_tokenizer.pad(
                text_features,
                padding=padding,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs.get("text_kwargs", {})
            )
        else:
            text_padded = {}

        # Combine padded features
        padded_features = {**context_padded, **text_padded}

        return BatchFeature(data=padded_features, tensor_type=return_tensors)

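    # Collation sketch: `pad` can serve as a collate_fn over examples produced
    # by `__call__` without `return_tensors` (names below are hypothetical):
    #
    #     examples = [tokenizer(context=c, text=t) for c, t in pairs]
    #     batch = tokenizer.pad(examples, padding=True, return_tensors="pt")
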
    def batch_decode(self, *args, **kwargs):
        """
        Calls the batch_decode method of the text_tokenizer.
        """
        return self.text_tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Calls the decode method of the text_tokenizer.
        """
        return self.text_tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Returns the model input names, merged from both tokenizers with duplicates removed.
        """
        return list(dict.fromkeys(self.context_tokenizer.model_input_names + self.text_tokenizer.model_input_names))