# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
from typing import Dict, List, Optional, Text

import tensorflow as tf, tf_keras
import tensorflow_text as tf_text

from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader


def features_to_int32(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Converts tf.int64 features to tf.int32, keeping other features the same.

  tf.Example only supports tf.int64, but the TPU only supports tf.int32.

  Args:
    features: Input tensor dictionary.

  Returns:
    Features with tf.int64 converted to tf.int32.
  """
  converted_features = {}
  for name, tensor in features.items():
    if tensor.dtype == tf.int64:
      converted_features[name] = tf.cast(tensor, tf.int32)
    else:
      converted_features[name] = tensor
  return converted_features
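

# Example (illustrative): tf.train.Example stores integers as int64, so parsed
# features must be cast before feeding a TPU-compiled model. Token ids below
# are made up.
#
#   features = {"input_word_ids": tf.constant([[101, 2023, 102]], tf.int64)}
#   features = features_to_int32(features)
#   assert features["input_word_ids"].dtype == tf.int32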


class SentencePrediction(export_base.ExportModule):
  """The export module for the sentence prediction task."""

  class Params(base_config.Config):
    inputs_only: bool = True
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

    # For text input processing.
    text_fields: Optional[List[str]] = None
    # Either specify these values for preprocessing by Python code...
    tokenization: str = "WordPiece"  # WordPiece or SentencePiece
    # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
    # file if tokenization is SentencePiece.
    vocab_file: str = ""
    lower_case: bool = True
    # ...or load preprocessing from a SavedModel at this location.
    preprocessing_hub_module_url: str = ""

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

    if params.text_fields:
      self._text_processor = sentence_prediction_dataloader.TextProcessor(
          seq_length=params.parse_sequence_length,
          vocab_file=params.vocab_file,
          tokenization=params.tokenization,
          lower_case=params.lower_case,
          preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  def _serve_tokenized_input(self,
                             input_word_ids,
                             input_mask=None,
                             input_type_ids=None) -> tf.Tensor:
    if input_type_ids is None:
      # Requires that the [CLS] token is the first token of the inputs.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    return self.inference_step(inputs)
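
  # Worked example of the defaults above (values illustrative): the padding id
  # is assumed to be 0, and all-zero type ids mark a single segment.
  #
  #   input_word_ids = [[101, 2023, 102, 0, 0]]
  #   derived input_mask     -> [[1, 1, 1, 0, 0]]
  #   derived input_type_ids -> [[0, 0, 0, 0, 0]]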

  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    return dict(
        outputs=self._serve_tokenized_input(input_word_ids, input_mask,
                                            input_type_ids))

  def serve_probability(self,
                        input_word_ids,
                        input_mask=None,
                        input_type_ids=None) -> Dict[str, tf.Tensor]:
    return dict(
        outputs=tf.nn.softmax(
            self._serve_tokenized_input(input_word_ids, input_mask,
                                        input_type_ids)))

  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    inputs_only = self.params.inputs_only
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
    }
    if not inputs_only:
      name_to_features.update({
          "input_mask":
              tf.io.FixedLenFeature([sequence_length], tf.int64),
          self.input_type_ids_field:
              tf.io.FixedLenFeature([sequence_length], tf.int64)
      })
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=None if inputs_only else features["input_mask"],
        input_type_ids=None
        if inputs_only else features[self.input_type_ids_field])
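
  # Example (illustrative): one serialized tf.Example that serve_examples can
  # parse, assuming parse_sequence_length=4 and inputs_only=True; `module` is
  # a hypothetical instance of this class and the token ids are made up.
  #
  #   example = tf.train.Example(features=tf.train.Features(feature={
  #       "input_word_ids": tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[101, 2023, 102, 0])),
  #   }))
  #   outputs = module.serve_examples(
  #       tf.constant([example.SerializeToString()]))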

  def serve_text_examples(self, inputs) -> Dict[str, tf.Tensor]:
    name_to_features = {}
    for text_field in self.params.text_fields:
      name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
    features = tf.io.parse_example(inputs, name_to_features)
    segments = [features[x] for x in self.params.text_fields]
    model_inputs = self._text_processor(segments)
    if self.params.inputs_only:
      return self.serve(input_word_ids=model_inputs["input_word_ids"])
    return self.serve(**model_inputs)

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples", "serve_text_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        if self.params.inputs_only:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"))
        else:
          signatures[signature_key] = self.serve.get_concrete_function(
              input_word_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_word_ids"),
              input_mask=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_mask"),
              input_type_ids=tf.TensorSpec(
                  shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
      if func_key == "serve_text_examples":
        signatures[
            signature_key] = self.serve_text_examples.get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
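

# Example (illustrative): exporting this module as a SavedModel. The model,
# checkpoint path, and output directory are hypothetical, and the exact
# export helper in official.core.export_base is assumed to accept these
# arguments.
#
#   params = SentencePrediction.Params(
#       parse_sequence_length=128, inputs_only=True)
#   export_module = SentencePrediction(params=params, model=classifier_model)
#   export_base.export(
#       export_module,
#       function_keys={"serve": "serving_default"},
#       export_savedmodel_dir="/tmp/sentence_prediction",
#       checkpoint_path="/tmp/ckpt")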


class MaskedLM(export_base.ExportModule):
  """The export module for the BERT pretraining (MaskedLM) task."""

  class Params(base_config.Config):
    cls_head_name: str = "next_sentence"
    use_v2_feature_names: bool = True
    parse_sequence_length: Optional[int] = None
    max_predictions_per_seq: Optional[int] = None

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  def serve(self, input_word_ids, input_mask, input_type_ids,
            masked_lm_positions) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
        masked_lm_positions=masked_lm_positions)
    outputs = self.inference_step(inputs)
    return dict(classification=outputs[self.params.cls_head_name])

  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    max_predictions_per_seq = self.params.max_predictions_per_seq
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "masked_lm_positions":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field],
        masked_lm_positions=features["masked_lm_positions"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"),
            masked_lm_positions=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="masked_lm_positions"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
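

# Example (illustrative): requesting concrete serving signatures from the
# module. The pretrainer model is hypothetical.
#
#   module = MaskedLM(
#       params=MaskedLM.Params(
#           parse_sequence_length=128, max_predictions_per_seq=20),
#       model=pretrainer_model)
#   signatures = module.get_inference_signatures(
#       {"serve": "serving_default", "serve_examples": "serve_examples"})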


class QuestionAnswering(export_base.ExportModule):
  """The export module for the question answering task."""

  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  def serve(self,
            input_word_ids,
            input_mask=None,
            input_type_ids=None) -> Dict[str, tf.Tensor]:
    if input_type_ids is None:
      # Requires that the [CLS] token is the first token of the inputs.
      input_type_ids = tf.zeros_like(input_word_ids)
    if input_mask is None:
      # The mask has 1 for real tokens and 0 for padding tokens.
      input_mask = tf.where(
          tf.equal(input_word_ids, 0), tf.zeros_like(input_word_ids),
          tf.ones_like(input_word_ids))
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    return dict(start_logits=outputs[0], end_logits=outputs[1])

  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_word_ids"),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_type_ids"))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
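

# Example (illustrative): calling the exported signature from a reloaded
# SavedModel. The path, tensors, and signature key are hypothetical.
#
#   imported = tf.saved_model.load("/tmp/qa_model")
#   serve_fn = imported.signatures["serving_default"]
#   outputs = serve_fn(
#       input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)
#   start_logits = outputs["start_logits"]
#   end_logits = outputs["end_logits"]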


class Tagging(export_base.ExportModule):
  """The export module for the tagging task."""

  class Params(base_config.Config):
    parse_sequence_length: Optional[int] = None
    use_v2_feature_names: bool = True
    output_encoder_outputs: bool = False

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    if params.use_v2_feature_names:
      self.input_word_ids_field = "input_word_ids"
      self.input_type_ids_field = "input_type_ids"
    else:
      self.input_word_ids_field = "input_ids"
      self.input_type_ids_field = "segment_ids"

  def serve(self, input_word_ids, input_mask,
            input_type_ids) -> Dict[str, tf.Tensor]:
    inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    outputs = self.inference_step(inputs)
    if self.params.output_encoder_outputs:
      return dict(
          logits=outputs["logits"], encoder_outputs=outputs["encoder_outputs"])
    else:
      return dict(logits=outputs["logits"])

  def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
    sequence_length = self.params.parse_sequence_length
    name_to_features = {
        self.input_word_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        "input_mask":
            tf.io.FixedLenFeature([sequence_length], tf.int64),
        self.input_type_ids_field:
            tf.io.FixedLenFeature([sequence_length], tf.int64)
    }
    features = tf.io.parse_example(inputs, name_to_features)
    features = features_to_int32(features)
    return self.serve(
        input_word_ids=features[self.input_word_ids_field],
        input_mask=features["input_mask"],
        input_type_ids=features[self.input_type_ids_field])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve", "serve_examples")
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve":
        signatures[signature_key] = self.serve.get_concrete_function(
            input_word_ids=tf.TensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                name=self.input_word_ids_field),
            input_mask=tf.TensorSpec(
                shape=[None, None], dtype=tf.int32, name="input_mask"),
            input_type_ids=tf.TensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                name=self.input_type_ids_field))
      if func_key == "serve_examples":
        signatures[signature_key] = self.serve_examples.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
    return signatures
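

# Note (illustrative): with output_encoder_outputs=True, the tagging signature
# returns the encoder's sequence output alongside the per-token logits; the
# shape comment below is an assumption about a typical tagging head.
#
#   outputs = module.serve(word_ids, mask, type_ids)
#   logits = outputs["logits"]                # e.g. [batch, seq_len, num_tags]
#   encoder = outputs.get("encoder_outputs")  # present only if enabled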


class Translation(export_base.ExportModule):
  """The export module for the translation task."""

  class Params(base_config.Config):
    sentencepiece_model_path: str = ""
    # Needs to be specified if padded_decode is True (e.g., on TPUs).
    batch_size: Optional[int] = None

  def __init__(self, params, model: tf_keras.Model, inference_step=None):
    super().__init__(params, model, inference_step)
    self._sp_tokenizer = tf_text.SentencepieceTokenizer(
        model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
        add_eos=True)
    try:
      empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
    except tf.errors.InternalError:
      raise ValueError(
          "EOS token not in tokenizer vocab. "
          "Please make sure the tokenizer generates a single token for an "
          "empty string.")
    self._eos_id = empty_str_tokenized.item()
    self._batch_size = params.batch_size
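
  # With add_eos=True, tokenizing the empty string is expected to yield exactly
  # one id, the EOS id, which is what .item() above recovers:
  #
  #   self._sp_tokenizer.tokenize("")  ->  [eos_id]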

  def serve(self, inputs) -> Dict[str, tf.Tensor]:
    return self.inference_step(inputs)

  def serve_text(self, text: tf.Tensor) -> tf.Tensor:
    tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
    return self._sp_tokenizer.detokenize(
        self.serve({"inputs": tokenized})["outputs"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    signatures = {}
    valid_keys = ("serve_text",)
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError("Invalid function key for the module: %s with key %s. "
                         "Valid keys are: %s" %
                         (self.__class__, func_key, valid_keys))
      if func_key == "serve_text":
        signatures[signature_key] = self.serve_text.get_concrete_function(
            tf.TensorSpec(
                shape=[self._batch_size], dtype=tf.string, name="text"))
    return signatures
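

# Example (illustrative): translating raw text with the module directly. The
# SentencePiece model path and the translation model are hypothetical.
#
#   module = Translation(
#       params=Translation.Params(
#           sentencepiece_model_path="/path/to/sp.model"),
#       model=translation_model)
#   translations = module.serve_text(tf.constant(["Hello world."]))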