|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A script to export the ALBERT core model as a TF-Hub SavedModel.""" |
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Text
from typing import Tuple

from absl import app
from absl import flags
import tensorflow as tf

from official.nlp.albert import configs
from official.nlp.bert import bert_models
|
|
|
FLAGS = flags.FLAGS

# Command-line flags for the export. Every flag is a string and defaults to
# None.
flags.DEFINE_string(
    name="albert_config_file",
    default=None,
    help="Albert configuration file to define core albert layers.")
flags.DEFINE_string(
    name="model_checkpoint_path",
    default=None,
    help="File path to TF model checkpoint.")
flags.DEFINE_string(
    name="export_path",
    default=None,
    help="TF-Hub SavedModel destination path.")
flags.DEFINE_string(
    name="sp_model_file",
    default=None,
    help="The sentence piece model file that the ALBERT model was trained on.")
|
|
|
|
|
def create_albert_model(
    albert_config: configs.AlbertConfig
) -> Tuple[tf.keras.Model, tf.keras.Model]:
  """Creates an ALBERT keras core model from ALBERT configuration.

  Args:
    albert_config: An `AlbertConfig` to create the core model.

  Returns:
    A `(core_model, transformer_encoder)` tuple:
      core_model: A `tf.keras.Model` whose inputs are
        `[input_word_ids, input_mask, input_type_ids]` (int32, variable
        sequence length) and whose outputs are
        `[pooled_output, sequence_output]`.
      transformer_encoder: The encoder built by
        `bert_models.get_transformer_encoder` and wrapped by `core_model`.
        It is returned separately so callers can operate on it directly,
        e.g. to restore its weights from a checkpoint.
  """
  # Keras `Input` placeholders. The sequence dimension is None, so the
  # resulting model accepts sequences of any length.
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_type_ids")
  transformer_encoder = bert_models.get_transformer_encoder(
      albert_config, sequence_length=None)
  sequence_output, pooled_output = transformer_encoder(
      [input_word_ids, input_mask, input_type_ids])
  # The exported outputs are ordered (pooled_output, sequence_output), the
  # reverse of the order in which the encoder returns them.
  core_model = tf.keras.Model(
      inputs=[input_word_ids, input_mask, input_type_ids],
      outputs=[pooled_output, sequence_output])
  return core_model, transformer_encoder
|
|
|
|
|
def export_albert_tfhub(albert_config: configs.AlbertConfig,
                        model_checkpoint_path: Text, hub_destination: Text,
                        sp_model_file: Text):
  """Restores an ALBERT keras model from a checkpoint and saves it for TF-Hub.

  Args:
    albert_config: The `AlbertConfig` describing the model to build.
    model_checkpoint_path: Path of the TF checkpoint to restore the encoder
      weights from.
    hub_destination: Destination path for the exported SavedModel.
    sp_model_file: Path to the SentencePiece model file; it is attached to the
      exported model as a `tf.saved_model.Asset`.
  """
  hub_model, albert_encoder = create_albert_model(albert_config)

  # Restore into the encoder only. `assert_consumed` raises if any checkpointed
  # value is left unmatched or any encoder object is left unrestored.
  encoder_checkpoint = tf.train.Checkpoint(model=albert_encoder)
  restore_status = encoder_checkpoint.restore(model_checkpoint_path)
  restore_status.assert_consumed()

  # Track the SentencePiece file as an asset of the model so it is packaged
  # with the SavedModel.
  hub_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
  hub_model.save(hub_destination, include_optimizer=False, save_format="tf")
|
|
|
|
|
def main(_):
  """Exports the ALBERT TF-Hub SavedModel described by the command-line flags.

  Raises:
    app.UsageError: If any flag this script needs is unset or empty.
  """
  # Every flag defaults to None. Fail early with a usage message instead of
  # passing None on to the config loader and the exporter.
  required_flags = ("albert_config_file", "model_checkpoint_path",
                    "export_path", "sp_model_file")
  missing_flags = [
      name for name in required_flags if not getattr(FLAGS, name)
  ]
  if missing_flags:
    raise app.UsageError("Missing required flag(s): " +
                         ", ".join("--" + name for name in missing_flags))

  albert_config = configs.AlbertConfig.from_json_file(
      FLAGS.albert_config_file)
  export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
                      FLAGS.export_path, FLAGS.sp_model_file)
|
|
|
|
|
if __name__ == "__main__":
  # absl parses the command-line flags, then invokes `main`.
  app.run(main=main)
|
|