deanna-emery's picture
updates
93528c6
raw
history blame
18.9 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
import dataclasses
from typing import Dict, List, Optional, Text
import tensorflow as tf, tf_keras
import tensorflow_text as tf_text
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Casts every tf.int64 feature to tf.int32 and passes the rest through.

  tf.Example only supports tf.int64, but the TPU only supports tf.int32.

  Args:
    features: Mapping from feature name to tensor.

  Returns:
    A new dictionary with the same keys, in which tf.int64 tensors have been
    cast to tf.int32 and every other tensor is the original object.
  """
  return {
      key: tf.cast(value, tf.int32) if value.dtype == tf.int64 else value
      for key, value in features.items()
  }
class SentencePrediction(export_base.ExportModule):
  """Export module that serves a sentence prediction model.

  Serving entry points accept pre-tokenized id tensors (`serve`,
  `serve_probability`), serialized tf.Example protos holding token ids
  (`serve_examples`), or serialized tf.Example protos holding raw text
  (`serve_text_examples`).
  """

  @dataclasses.dataclass
  class Params(base_config.Config):
    # When True, only the word ids are parsed/fed; the mask and the type ids
    # are derived from the word ids at serving time.
    inputs_only: bool = True
    # Fixed length of the id features parsed from tf.Example protos.
    parse_sequence_length: Optional[int] = None
    # Selects "input_word_ids"/"input_type_ids" (True) or
    # "input_ids"/"segment_ids" (False) as tf.Example feature names.
    use_v2_feature_names: bool = True
    # For text input processing.
    text_fields: Optional[List[str]] = None
    # Either specify these values for preprocessing by Python code...
    tokenization: str = "WordPiece"  # WordPiece or SentencePiece
    # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
    # file if tokenization is SentencePiece.
    vocab_file: str = ""
    lower_case: bool = True
    # ...or load preprocessing from a SavedModel at this location.
    preprocessing_hub_module_url: str = ""

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    """Sets up feature names and, when text fields are given, a text processor.

    Args:
      params: A `Params` instance configuring this module.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.
    """
    super().__init__(params, model, inference_step)
    v2_names = params.use_v2_feature_names
    self.input_word_ids_field = "input_word_ids" if v2_names else "input_ids"
    self.input_type_ids_field = (
        "input_type_ids" if v2_names else "segment_ids")
    # The text processor is only created when raw-text serving is configured.
    if params.text_fields:
      self._text_processor = sentence_prediction_dataloader.TextProcessor(
          seq_length=params.parse_sequence_length,
          vocab_file=params.vocab_file,
          tokenization=params.tokenization,
          lower_case=params.lower_case,
          preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  def _serve_tokenized_input(self,
                             input_word_ids,
                             input_mask=None,
                             input_type_ids=None) -> tf.Tensor:
    """Runs the inference step, filling in missing mask and type ids.

    Args:
      input_word_ids: Token id tensor.
      input_mask: Optional mask tensor; when None it is derived from
        `input_word_ids` by treating id 0 as padding.
      input_type_ids: Optional type id tensor; when None all zeros are used.

    Returns:
      Whatever `self.inference_step` returns for the assembled inputs.
    """
    if input_type_ids is None:
      # Requires CLS token is the first token of inputs.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      is_padding = tf.equal(input_word_ids, 0)
      input_mask = tf.where(is_padding, tf.zeros_like(input_word_ids),
                            tf.ones_like(input_word_ids))
    model_inputs = {
        "input_word_ids": input_word_ids,
        "input_mask": input_mask,
        "input_type_ids": input_type_ids,
    }
    return self.inference_step(model_inputs)

  @tf.function
  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    """Serves tokenized inputs; the model result is returned under "outputs"."""
    model_outputs = self._serve_tokenized_input(input_word_ids, input_mask,
                                                input_type_ids)
    return {"outputs": model_outputs}

  @tf.function
  def serve_probability(self,
                        input_word_ids,
                        input_mask=None,
                        input_type_ids=None) -> Dict[str, tf.Tensor]:
    """Like `serve`, but applies `tf.nn.softmax` to the model result."""
    model_outputs = self._serve_tokenized_input(input_word_ids, input_mask,
                                                input_type_ids)
    return {"outputs": tf.nn.softmax(model_outputs)}

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    """Parses serialized tf.Example protos of token ids and serves them.

    Args:
      inputs: Serialized tf.Example protos.

    Returns:
      The result of `serve` on the parsed features.
    """
    seq_len = self.params.parse_sequence_length
    only_word_ids = self.params.inputs_only
    feature_spec = {
        self.input_word_ids_field: tf.io.FixedLenFeature([seq_len], tf.int64),
    }
    if not only_word_ids:
      feature_spec["input_mask"] = tf.io.FixedLenFeature([seq_len], tf.int64)
      feature_spec[self.input_type_ids_field] = tf.io.FixedLenFeature(
          [seq_len], tf.int64)
    parsed = features_to_int32(tf.io.parse_example(inputs, feature_spec))

    # With inputs_only the mask and type ids stay None so that `serve`
    # derives them from the word ids.
    mask = None
    type_ids = None
    if not only_word_ids:
      mask = parsed["input_mask"]
      type_ids = parsed[self.input_type_ids_field]
    return self.serve(
        parsed[self.input_word_ids_field],
        input_mask=mask,
        input_type_ids=type_ids)

  @tf.function
  def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
    """Parses serialized tf.Example protos of raw text and serves them.

    Each name in `params.text_fields` is read as a scalar string feature; the
    resulting segments are passed through the text processor built in
    `__init__`.

    Args:
      inputs: Serialized tf.Example protos.

    Returns:
      The result of `serve` on the processed text.
    """
    text_fields = self.params.text_fields
    feature_spec = {
        field: tf.io.FixedLenFeature([], tf.string) for field in text_fields
    }
    parsed = tf.io.parse_example(inputs, feature_spec)
    segments = [parsed[field] for field in text_fields]
    model_inputs = self._text_processor(segments)
    if not self.params.inputs_only:
      return self.serve(**model_inputs)
    return self.serve(input_word_ids=model_inputs["input_word_ids"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds concrete serving functions for the requested keys.

    Args:
      function_keys: Maps a function key ("serve", "serve_examples" or
        "serve_text_examples") to the signature key to export it under.

    Returns:
      A dictionary from signature key to concrete function.

    Raises:
      ValueError: If a function key is not one of the valid keys.
    """
    valid_keys = ("serve", "serve_examples", "serve_text_examples")

    def _ids_spec(name):
      # [batch, length] int32 spec with both dimensions left dynamic.
      return tf.TensorSpec(shape=[None, None], dtype=tf.int32, name=name)

    def _examples_spec():
      # 1-D batch of serialized tf.Example strings.
      return tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")

    signatures = {}
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        serve_specs = {"input_word_ids": _ids_spec("input_word_ids")}
        if not self.params.inputs_only:
          serve_specs["input_mask"] = _ids_spec("input_mask")
          serve_specs["input_type_ids"] = _ids_spec("input_type_ids")
        signatures[signature_key] = self.serve.get_concrete_function(
            **serve_specs)
      elif func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            _examples_spec())
      elif func_key == "serve_text_examples":
        signatures[signature_key] = (
            self.serve_text_examples.get_concrete_function(_examples_spec()))
    return signatures
class MaskedLM(export_base.ExportModule):
  """The export module for the Bert Pretrain (MaskedLM) task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    # Key of the model output that `serve` returns as "classification".
    cls_head_name: str = "next_sentence"
    # Selects "input_word_ids"/"input_type_ids" (True) or
    # "input_ids"/"segment_ids" (False) as tf.Example feature names.
    use_v2_feature_names: bool = True
    # Fixed length of the id/mask features parsed from tf.Example protos.
    parse_sequence_length: Optional[int] = None
    # Fixed length of the "masked_lm_positions" feature.
    max_predictions_per_seq: Optional[int] = None

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    """Stores the tf.Example feature names selected by `params`.

    Args:
      params: A `Params` instance configuring this module.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.
    """
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  @tf.function
  def serve(self, input_word_ids, input_mask, input_type_ids,
            masked_lm_positions) -> Dict[str, tf.Tensor]:
    """Runs inference and returns the configured classification head output.

    Args:
      input_word_ids: Token id tensor.
      input_mask: Mask tensor.
      input_type_ids: Type (segment) id tensor.
      masked_lm_positions: Positions of the masked tokens.

    Returns:
      A dictionary with the single key "classification", holding the model
      output stored under `params.cls_head_name`.
    """
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
        masked_lm_positions=masked_lm_positions)
    outputs = self.inference_step(inputs)
    return dict(classification=outputs[self.params.cls_head_name])

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    """Parses serialized tf.Example protos and serves them.

    Args:
      inputs: Serialized tf.Example protos.

    Returns:
      The result of `serve` on the parsed features.
    """
    sequence_length = self.params.parse_sequence_length
    max_predictions_per_seq = self.params.max_predictions_per_seq
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "masked_lm_positions":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        # The type ids must come from the type-ids feature; the word-ids
        # feature was previously passed here by mistake.
        input_type_ids=features[self.input_type_ids_field],
        masked_lm_positions=features["masked_lm_positions"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds concrete serving functions for the requested keys.

    Args:
      function_keys: Maps a function key ("serve" or "serve_examples") to the
        signature key to export it under.

    Returns:
      A dictionary from signature key to concrete function.

    Raises:
      ValueError: If a function key is not one of the valid keys.
    """
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"),
            masked_lm_positions=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="masked_lm_positions"))
      elif func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
class QuestionAnswering(export_base.ExportModule):
  """Export module that serves a question answering model."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    # Fixed length of the features parsed from tf.Example protos.
    parse_sequence_length: Optional[int] = None
    # Selects "input_word_ids"/"input_type_ids" (True) or
    # "input_ids"/"segment_ids" (False) as tf.Example feature names.
    use_v2_feature_names: bool = True

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    """Stores the tf.Example feature names selected by `params`.

    Args:
      params: A `Params` instance configuring this module.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.
    """
    super().__init__(params, model, inference_step)
    v2_names = params.use_v2_feature_names
    self.input_word_ids_field = "input_word_ids" if v2_names else "input_ids"
    self.input_type_ids_field = (
        "input_type_ids" if v2_names else "segment_ids")

  @tf.function
  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    """Runs inference and returns the start and end logits.

    Args:
      input_word_ids: Token id tensor.
      input_mask: Optional mask tensor; when None it is derived from
        `input_word_ids` by treating id 0 as padding.
      input_type_ids: Type (segment) id tensor, forwarded unchanged.

    Returns:
      A dictionary with "start_logits" and "end_logits", taken from elements
      0 and 1 of the inference step's result.
    """
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      is_padding = tf.equal(input_word_ids, 0)
      input_mask = tf.where(is_padding, tf.zeros_like(input_word_ids),
                            tf.ones_like(input_word_ids))
    model_inputs = {
        "input_word_ids": input_word_ids,
        "input_mask": input_mask,
        "input_type_ids": input_type_ids,
    }
    outputs = self.inference_step(model_inputs)
    return {"start_logits": outputs[0], "end_logits": outputs[1]}

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    """Parses serialized tf.Example protos and serves them.

    Args:
      inputs: Serialized tf.Example protos.

    Returns:
      The result of `serve` on the parsed features.
    """
    seq_len = self.params.parse_sequence_length
    feature_names = (self.input_word_ids_field, "input_mask",
                     self.input_type_ids_field)
    feature_spec = {
        name: tf.io.FixedLenFeature([seq_len], tf.int64)
        for name in feature_names
    }
    parsed = features_to_int32(tf.io.parse_example(inputs, feature_spec))
    return self.serve(
        input_word_ids=parsed[self.input_word_ids_field],
        input_mask=parsed["input_mask"],
        input_type_ids=parsed[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds concrete serving functions for the requested keys.

    Args:
      function_keys: Maps a function key ("serve" or "serve_examples") to the
        signature key to export it under.

    Returns:
      A dictionary from signature key to concrete function.

    Raises:
      ValueError: If a function key is not one of the valid keys.
    """
    valid_keys = ("serve", "serve_examples")

    def _ids_spec(name):
      # [batch, length] int32 spec with both dimensions left dynamic.
      return tf.TensorSpec(shape=[None, None], dtype=tf.int32, name=name)

    signatures = {}
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=_ids_spec("input_word_ids"),
            input_mask=_ids_spec("input_mask"),
            input_type_ids=_ids_spec("input_type_ids"))
      elif func_key == "serve_examples":
        examples_spec = tf.TensorSpec(
            shape=[None], dtype=tf.string, name="examples")
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            examples_spec)
    return signatures
class Tagging(export_base.ExportModule):
  """Export module that serves a tagging model."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    # Fixed length of the features parsed from tf.Example protos.
    parse_sequence_length: Optional[int] = None
    # Selects "input_word_ids"/"input_type_ids" (True) or
    # "input_ids"/"segment_ids" (False) as feature and signature input names.
    use_v2_feature_names: bool = True
    # When True, `serve` also returns the model's "encoder_outputs".
    output_encoder_outputs: bool = False

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    """Stores the feature names selected by `params`.

    Args:
      params: A `Params` instance configuring this module.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.
    """
    super().__init__(params, model, inference_step)
    v2_names = params.use_v2_feature_names
    self.input_word_ids_field = "input_word_ids" if v2_names else "input_ids"
    self.input_type_ids_field = (
        "input_type_ids" if v2_names else "segment_ids")

  @tf.function
  def serve(self, input_word_ids, input_mask,
            input_type_ids) -> Dict[str, tf.Tensor]:
    """Runs inference and returns the logits.

    Args:
      input_word_ids: Token id tensor.
      input_mask: Mask tensor.
      input_type_ids: Type (segment) id tensor.

    Returns:
      A dictionary with "logits" and, if `params.output_encoder_outputs` is
      set, also "encoder_outputs".
    """
    model_inputs = {
        "input_word_ids": input_word_ids,
        "input_mask": input_mask,
        "input_type_ids": input_type_ids,
    }
    outputs = self.inference_step(model_inputs)
    served = {"logits": outputs["logits"]}
    if self.params.output_encoder_outputs:
      served["encoder_outputs"] = outputs["encoder_outputs"]
    return served

  @tf.function
  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    """Parses serialized tf.Example protos and serves them.

    Args:
      inputs: Serialized tf.Example protos.

    Returns:
      The result of `serve` on the parsed features.
    """
    seq_len = self.params.parse_sequence_length
    feature_names = (self.input_word_ids_field, "input_mask",
                     self.input_type_ids_field)
    feature_spec = {
        name: tf.io.FixedLenFeature([seq_len], tf.int64)
        for name in feature_names
    }
    parsed = features_to_int32(tf.io.parse_example(inputs, feature_spec))
    return self.serve(
        input_word_ids=parsed[self.input_word_ids_field],
        input_mask=parsed["input_mask"],
        input_type_ids=parsed[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds concrete serving functions for the requested keys.

    The "serve" signature names its word-id and type-id inputs after the
    configured feature names rather than fixed strings.

    Args:
      function_keys: Maps a function key ("serve" or "serve_examples") to the
        signature key to export it under.

    Returns:
      A dictionary from signature key to concrete function.

    Raises:
      ValueError: If a function key is not one of the valid keys.
    """
    valid_keys = ("serve", "serve_examples")

    def _ids_spec(name):
      # [batch, length] int32 spec with both dimensions left dynamic.
      return tf.TensorSpec(shape=[None, None], dtype=tf.int32, name=name)

    signatures = {}
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=_ids_spec(self.input_word_ids_field),
            input_mask=_ids_spec("input_mask"),
            input_type_ids=_ids_spec(self.input_type_ids_field))
      elif func_key == "serve_examples":
        examples_spec = tf.TensorSpec(
            shape=[None], dtype=tf.string, name="examples")
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            examples_spec)
    return signatures
class Translation(export_base.ExportModule):
  """The export module for the translation task."""

  @dataclasses.dataclass
  class Params(base_config.Config):
    # Path of the serialized SentencePiece model read in `__init__`.
    sentencepiece_model_path: str = ""
    # Needs to be specified if padded_decode is True/on TPUs.
    batch_size: Optional[int] = None

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    """Loads the SentencePiece tokenizer and records the EOS token id.

    Args:
      params: A `Params` instance configuring this module.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.

    Raises:
      ValueError: If tokenizing an empty string raises
        `tf.errors.InternalError`.
    """
    super().__init__(params, model, inference_step)
    # Read the model inside a context manager so the file handle is closed.
    with tf.io.gfile.GFile(params.sentencepiece_model_path, "rb") as sp_file:
      serialized_sp_model = sp_file.read()
    self._sp_tokenizer = tf_text.SentencepieceTokenizer(
        model=serialized_sp_model, add_eos=True)
    # With add_eos=True an empty string is expected to tokenize to the EOS
    # token alone; that id is kept as `_eos_id`.
    try:
      empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
    except tf.errors.InternalError as err:
      raise ValueError(
          "EOS token not in tokenizer vocab. "
          "Please make sure the tokenizer generates a single token for an "
          "empty string.") from err
    # `.item()` only succeeds when the array holds exactly one element.
    self._eos_id = empty_str_tokenized.item()
    self._batch_size = params.batch_size

  @tf.function
  def serve(self, inputs) -> Dict[str, tf.Tensor]:
    """Forwards `inputs` to the inference step and returns its result."""
    return self.inference_step(inputs)

  @tf.function
  def serve_text(self, text: tf.Tensor) -> tf.Tensor:
    """Translates raw text end to end.

    Args:
      text: String tensor of input sentences.

    Returns:
      The detokenized "outputs" entry of `serve`, i.e. a tensor rather than a
      dictionary.
    """
    # Tokenize and densify the result, padding with id 0.
    tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
    return self._sp_tokenizer.detokenize(
        self.serve({"inputs": tokenized})["outputs"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds concrete serving functions for the requested keys.

    Args:
      function_keys: Maps a function key (only "serve_text" is valid) to the
        signature key to export it under.

    Returns:
      A dictionary from signature key to concrete function.

    Raises:
      ValueError: If a function key is not one of the valid keys.
    """
    signatures = {}
    # Must be a real tuple: with a bare string, `in` is a substring test and
    # keys such as "serve" would pass validation and then be silently dropped.
    valid_keys = ("serve_text",)
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve_text":
        signatures[signature_key] = self.serve_text.get_concrete_function(
            tf.TensorSpec(shape=[self._batch_size],
                          dtype=tf.string, name="text"))
    return signatures