# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads dataset for the sentence prediction (classification) task.""" | |
import dataclasses | |
import functools | |
from typing import List, Mapping, Optional, Tuple | |
import tensorflow as tf, tf_keras | |
import tensorflow_hub as hub | |
from official.common import dataset_fn | |
from official.core import config_definitions as cfg | |
from official.core import input_reader | |
from official.nlp import modeling | |
from official.nlp.data import data_loader | |
from official.nlp.data import data_loader_factory | |
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32} | |


@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  input_path: str = ''
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128
  label_type: str = 'int'
  # Whether to include the example id number.
  include_example_id: bool = False
  label_field: str = 'label_ids'
  # Maps the key in TfExample to feature name.
  # E.g., 'label_ids' to 'next_sentence_labels'.
  label_name: Optional[Tuple[str, str]] = None
  # Either tfrecord, sstable, or recordio.
  file_type: str = 'tfrecord'
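
# Illustrative sketch (the input path is hypothetical): a config that remaps
# the decoded 'label_ids' feature to the key expected by a
# next-sentence-prediction head, per the `label_name` comment above.
#
#   config = SentencePredictionDataConfig(
#       input_path='/tmp/train.tf_record',  # hypothetical path
#       global_batch_size=32,
#       seq_length=128,
#       label_name=('label_ids', 'next_sentence_labels'))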


@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader(data_loader.DataLoader):
  """A class to load dataset for sentence prediction (classification) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length
    self._include_example_id = params.include_example_id
    self._label_field = params.label_field
    if params.label_name:
      self._label_name_mapping = dict([params.label_name])
    else:
      self._label_name_mapping = dict()

  def name_to_features_spec(self):
    """Defines features to decode. Subclass may override to append features."""
    label_type = LABEL_TYPES_MAP[self._params.label_type]
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        self._label_field: tf.io.FixedLenFeature([], label_type),
    }
    if self._include_example_id:
      name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
    return name_to_features
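
  # Illustrative sketch: a tf.train.Example matching this spec carries int64
  # lists of length `seq_length` plus a scalar label. The values below are
  # made up, assuming seq_length=128:
  #
  #   features = tf.train.Features(feature={
  #       'input_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[101] + [0] * 127)),
  #       'input_mask': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[1] + [0] * 127)),
  #       'segment_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[0] * 128)),
  #       'label_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[1])),
  #   })
  #   serialized = tf.train.Example(features=features).SerializeToString()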

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    example = tf.io.parse_single_example(record, self.name_to_features_spec())
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t
    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    key_mapping = {
        'input_ids': 'input_word_ids',
        'input_mask': 'input_mask',
        'segment_ids': 'input_type_ids'
    }
    ret = {}
    for record_key in record:
      if record_key in key_mapping:
        ret[key_mapping[record_key]] = record[record_key]
      else:
        ret[record_key] = record[record_key]
    if self._label_field in self._label_name_mapping:
      ret[self._label_name_mapping[self._label_field]] = record[
          self._label_field]
    return ret
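
  # For example (hypothetical tensors), a decoded record
  #   {'input_ids': ..., 'input_mask': ..., 'segment_ids': ..., 'label_ids': ...}
  # is returned as
  #   {'input_word_ids': ..., 'input_mask': ..., 'input_type_ids': ...,
  #    'label_ids': ...}
  # with one extra key added when `label_name` remaps `label_field`.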

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
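
# Example usage (a minimal, illustrative sketch; the file path is
# hypothetical):
#
#   data_config = SentencePredictionDataConfig(
#       input_path='/tmp/train.tf_record',  # hypothetical path
#       global_batch_size=32,
#       is_training=True,
#       seq_length=128)
#   dataset = SentencePredictionDataLoader(data_config).load()
#   # `dataset` yields dicts keyed by 'input_word_ids', 'input_mask',
#   # 'input_type_ids', and 'label_ids'.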


@dataclasses.dataclass
class SentencePredictionTextDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task with raw text."""
  # Either set `input_path`...
  input_path: str = ''
  # Either `int` or `float`.
  label_type: str = 'int'
  # ...or `tfds_name` and `tfds_split` to specify input.
  tfds_name: str = ''
  tfds_split: str = ''
  # The name of the text feature fields. The text features will be
  # concatenated in order.
  text_fields: Optional[List[str]] = None
  label_field: str = 'label'
  global_batch_size: int = 32
  seq_length: int = 128
  is_training: bool = True
  # Either build preprocessing with Python code by specifying these values
  # for modeling.layers.BertTokenizer()/SentencepieceTokenizer()...
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
  # file if tokenization is SentencePiece.
  vocab_file: str = ''
  lower_case: bool = True
  # ...or load preprocessing from a SavedModel at this location.
  preprocessing_hub_module_url: str = ''
  # Either tfrecord, sstable, or recordio.
  file_type: str = 'tfrecord'
  include_example_id: bool = False
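
# Illustrative sketch: a raw-text config reading from TFDS. The GLUE/MRPC
# names below are assumptions for the example, not requirements of this
# module, and the vocab path is hypothetical.
#
#   text_config = SentencePredictionTextDataConfig(
#       tfds_name='glue/mrpc',
#       tfds_split='train',
#       text_fields=['sentence1', 'sentence2'],
#       label_field='label',
#       vocab_file='/tmp/vocab.txt',  # hypothetical WordPiece vocab
#       lower_case=True)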


class TextProcessor(tf.Module):
  """Text features processing for sentence prediction task."""

  def __init__(self,
               seq_length: int,
               vocab_file: Optional[str] = None,
               tokenization: Optional[str] = None,
               lower_case: Optional[bool] = True,
               preprocessing_hub_module_url: Optional[str] = None):
    if preprocessing_hub_module_url:
      self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
      self._tokenizer = self._preprocessing_hub_module.tokenize
      self._pack_inputs = functools.partial(
          self._preprocessing_hub_module.bert_pack_inputs,
          seq_length=seq_length)
      return

    if tokenization == 'WordPiece':
      self._tokenizer = modeling.layers.BertTokenizer(
          vocab_file=vocab_file, lower_case=lower_case)
    elif tokenization == 'SentencePiece':
      self._tokenizer = modeling.layers.SentencepieceTokenizer(
          model_file_path=vocab_file,
          lower_case=lower_case,
          strip_diacritics=True)  # Strip diacritics to follow ALBERT model.
    else:
      raise ValueError('Unsupported tokenization: %s' % tokenization)
    self._pack_inputs = modeling.layers.BertPackInputs(
        seq_length=seq_length,
        special_tokens_dict=self._tokenizer.get_special_tokens_dict())

  def __call__(self, segments):
    segments = [self._tokenizer(s) for s in segments]
    # BertTokenizer returns a RaggedTensor with shape [batch, word, subword];
    # SentencepieceTokenizer returns a RaggedTensor with shape
    # [batch, sentencepiece]. Merge the word/subword dims so both become
    # [batch, token] int32 tensors before packing.
    segments = [
        tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
        for x in segments
    ]
    return self._pack_inputs(segments)
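
# Example usage (an illustrative sketch; the vocab path is hypothetical):
#
#   processor = TextProcessor(
#       seq_length=128,
#       vocab_file='/tmp/vocab.txt',  # hypothetical path
#       tokenization='WordPiece',
#       lower_case=True)
#   inputs = processor([tf.constant(['hello world']),
#                       tf.constant(['goodbye world'])])
#   # `inputs` is a dict with 'input_word_ids', 'input_mask', and
#   # 'input_type_ids', each shaped [batch, seq_length].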


@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
class SentencePredictionTextDataLoader(data_loader.DataLoader):
  """Loads dataset with raw text for sentence prediction task."""

  def __init__(self, params):
    if bool(params.tfds_name) != bool(params.tfds_split):
      raise ValueError('`tfds_name` and `tfds_split` should be specified or '
                       'unspecified at the same time.')
    if bool(params.tfds_name) == bool(params.input_path):
      raise ValueError('Must specify either `tfds_name` and `tfds_split` '
                       'or `input_path`.')
    if not params.text_fields:
      raise ValueError('Unexpected empty text fields.')
    if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
      raise ValueError('Must specify exactly one of vocab_file (with matching '
                       'lower_case flag) or preprocessing_hub_module_url.')

    self._params = params
    self._text_fields = params.text_fields
    self._label_field = params.label_field
    self._label_type = params.label_type
    self._include_example_id = params.include_example_id
    self._text_processor = TextProcessor(
        seq_length=params.seq_length,
        vocab_file=params.vocab_file,
        tokenization=params.tokenization,
        lower_case=params.lower_case,
        preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
    """Runs BERT preprocessing on the text fields of a record."""
    segments = [record[x] for x in self._text_fields]
    model_inputs = self._text_processor(segments)
    for key in record:
      if key not in self._text_fields:
        model_inputs[key] = record[key]
    return model_inputs

  def name_to_features_spec(self):
    """Defines features to decode. Subclass may override to append features."""
    name_to_features = {}
    for text_field in self._text_fields:
      name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
    label_type = LABEL_TYPES_MAP[self._label_type]
    name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
    if self._include_example_id:
      name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
    return name_to_features

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    example = tf.io.parse_single_example(record, self.name_to_features_spec())
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t
    return example

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode if self._params.input_path else None,
        params=self._params,
        postprocess_fn=self._bert_preprocess)
    return reader.read(input_context)
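
# Example usage (an illustrative sketch; the TFDS names and vocab path are
# hypothetical, as above):
#
#   loader = SentencePredictionTextDataLoader(
#       SentencePredictionTextDataConfig(
#           tfds_name='glue/mrpc',
#           tfds_split='train',
#           text_fields=['sentence1', 'sentence2'],
#           vocab_file='/tmp/vocab.txt'))  # hypothetical path
#   dataset = loader.load()
#   # Batches carry 'input_word_ids', 'input_mask', and 'input_type_ids',
#   # plus the passthrough 'label' field.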