# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for pre-processing classification data."""

from absl import logging

from official.legacy.xlnet import data_utils

# Segment ids assigned to the tokens of the first / second input sequence.
SEG_ID_A = 0
SEG_ID_B = 1


class PaddingInputExample(object):
  """Fake example so the num input examples is a multiple of the batch size.

  When running eval/predict on the TPU, we need to pad the number of examples
  to be a multiple of the batch size, because the TPU requires a fixed batch
  size. The alternative is to drop the last batch, which is bad because it
  means the entire output data won't be generated.

  We use this class instead of `None` because treating `None` as padding
  batches could cause silent errors.
  """


class InputFeatures(object):
  """A single set of features of data.

  Attributes:
    input_ids: Sequence of token ids.
    input_mask: Sequence with 0 for real tokens and 1 for padding tokens.
    segment_ids: Sequence of per-token segment ids.
    label_id: Label of the example (index into the label list, or the raw
      label when no label list was supplied).
    is_real_example: False when the features were produced from a
      `PaddingInputExample`, True otherwise.
  """

  def __init__(self,
               input_ids,
               input_mask,
               segment_ids,
               label_id,
               is_real_example=True):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length.

  Args:
    tokens_a: First token list; mutated in place.
    tokens_b: Second token list; mutated in place.
    max_length: Maximum allowed value of `len(tokens_a) + len(tokens_b)`.
  """
  # This is a simple heuristic which will always truncate the longer sequence
  # one token at a time. This makes more sense than truncating an equal percent
  # of tokens from each, since if one sequence is very short then each token
  # that's truncated likely contains more information than a longer sequence.
  while len(tokens_a) + len(tokens_b) > max_length:
    # On a tie the second sequence is the one that gets shortened.
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()


def convert_single_example(example_index, example, label_list, max_seq_length,
                           tokenize_fn, use_bert_format):
  """Converts a single `InputExample` into a single `InputFeatures`.

  Args:
    example_index: Position of the example in the dataset; the first five
      examples are logged for inspection.
    example: The example to convert. Either a `PaddingInputExample` or an
      object exposing `guid`, `text_a`, `text_b` and `label`.
    label_list: Optional list of all labels. When given, `example.label` is
      mapped to its index in this list; when `None`, `example.label` is used
      as the label id unchanged.
    max_seq_length: Length of every returned feature sequence.
    tokenize_fn: Callable applied to `example.text_a` / `example.text_b`.
      Its result is used directly as token ids, so it is presumably expected
      to return a list of ids -- confirm against the caller.
    use_bert_format: If true, the CLS token is placed first and padding is
      appended on the right. Otherwise the CLS token is placed last and
      padding is prepended on the left.

  Returns:
    An `InputFeatures` whose `input_ids`, `input_mask` and `segment_ids` all
    have length `max_seq_length`.

  Raises:
    KeyError: If `label_list` is given and `example.label` is not in it.
    AssertionError: If the assembled sequences do not have length
      `max_seq_length`.
  """
  if isinstance(example, PaddingInputExample):
    # Padding examples are fully masked out (mask value 1 marks padding).
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[1] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  if label_list is not None:
    label_map = {label: i for i, label in enumerate(label_list)}

  tokens_a = tokenize_fn(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenize_fn(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for two [SEP] & one [CLS] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  else:
    # Account for one [SEP] & one [CLS] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
      tokens_a = tokens_a[:max_seq_length - 2]

  # First sequence followed by its separator, all tagged as segment A.
  tokens = list(tokens_a)
  tokens.append(data_utils.SEP_ID)
  segment_ids = [SEG_ID_A] * len(tokens)

  if tokens_b:
    # Second sequence followed by its separator, all tagged as segment B.
    tokens.extend(tokens_b)
    tokens.append(data_utils.SEP_ID)
    segment_ids.extend([SEG_ID_B] * (len(tokens) - len(segment_ids)))

  if use_bert_format:
    tokens.insert(0, data_utils.CLS_ID)
    segment_ids.insert(0, data_utils.SEG_ID_CLS)
  else:
    tokens.append(data_utils.CLS_ID)
    segment_ids.append(data_utils.SEG_ID_CLS)

  input_ids = tokens

  # The mask has 0 for real tokens and 1 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [0] * len(input_ids)

  # Zero-pad up to the sequence length.
  if len(input_ids) < max_seq_length:
    delta_len = max_seq_length - len(input_ids)
    if use_bert_format:
      # Pad on the right.
      input_ids = input_ids + [0] * delta_len
      input_mask = input_mask + [1] * delta_len
      segment_ids = segment_ids + [data_utils.SEG_ID_PAD] * delta_len
    else:
      # Pad on the left.
      input_ids = [0] * delta_len + input_ids
      input_mask = [1] * delta_len + input_mask
      segment_ids = [data_utils.SEG_ID_PAD] * delta_len + segment_ids

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  if label_list is not None:
    label_id = label_map[example.label]
  else:
    label_id = example.label

  if example_index < 5:
    logging.info("*** Example ***")
    logging.info("guid: %s", (example.guid))
    logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
    logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
    logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
    # `%s` rather than `%d`: without a label list `label_id` is the raw label,
    # which may be a float (it would be truncated by `%d`) or non-numeric.
    logging.info("label: %s (id = %s)", example.label, label_id)

  feature = InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id)
  return feature