# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
from absl import logging
import dataclasses
import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.modeling import losses as loss_lib
from official.nlp.tasks import utils


@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can
# be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
metric_type: str = 'accuracy'
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0, # No masked language modeling head.
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
num_classes=3,
dropout_rate=0.1,
name='sentence_prediction')
])
train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
"""Task object for sentence_prediction."""
def __init__(self, params=cfg.TaskConfig):
super(SentencePredictionTask, self).__init__(params)
if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
    self.metric_type = params.metric_type

  def build_model(self):
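    """Builds the classification model, reusing a TF-Hub encoder if given."""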
if self._hub_module:
encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
return bert.instantiate_bertpretrainer_from_cfg(
self.task_config.network, encoder_network=encoder_from_hub)
else:
      return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
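    """Computes sparse categorical cross-entropy over the classifier logits."""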
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels,
predictions=tf.nn.log_softmax(
tf.cast(model_outputs['sentence_prediction'], tf.float32), axis=-1))
if aux_losses:
loss += tf.add_n(aux_losses)
    return loss

  def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path == 'dummy':
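      # A 'dummy' input_path yields constant features and labels so the task
      # can be smoke-tested without real data on disk.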
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = tf.ones((1, 1), dtype=tf.int32)
        return (x, y)

      dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return sentence_prediction_dataloader.SentencePredictionDataLoader(
        params).load(input_context)

  def build_metrics(self, training=None):
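    """Returns in-graph metrics (correlation metrics are computed offline)."""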
del training
metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
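    """Updates the given metrics with this batch's labels and logits."""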
for metric in metrics:
      metric.update_state(labels, model_outputs['sentence_prediction'])

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
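    """Updates Keras compiled metrics with this batch's labels and logits."""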
    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
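    """Runs the validation step.

    Accuracy is handled by the base class; for the correlation metric types,
    raw predictions and labels are returned so the metric can be computed
    over the whole evaluation set in `reduce_aggregated_logs`.
    """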
if self.metric_type == 'accuracy':
return super(SentencePredictionTask,
self).validation_step(inputs, model, metrics)
features, labels = inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
if self.metric_type == 'matthews_corrcoef':
return {
self.loss:
loss,
'sentence_prediction':
tf.expand_dims(
tf.math.argmax(outputs['sentence_prediction'], axis=1),
axis=0),
'labels':
labels,
}
if self.metric_type == 'pearson_spearman_corr':
return {
self.loss: loss,
'sentence_prediction': outputs['sentence_prediction'],
'labels': labels,
      }

  def aggregate_logs(self, state=None, step_outputs=None):
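    """Accumulates per-step predictions and labels across evaluation steps."""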
if state is None:
state = {'sentence_prediction': [], 'labels': []}
state['sentence_prediction'].append(
np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
axis=0))
state['labels'].append(
np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
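    """Computes the dataset-level metric from the accumulated outputs."""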
if self.metric_type == 'matthews_corrcoef':
preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
labels = np.concatenate(aggregated_logs['labels'], axis=0)
return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
}
if self.metric_type == 'pearson_spearman_corr':
      preds = np.squeeze(
          np.concatenate(aggregated_logs['sentence_prediction'], axis=0))
      labels = np.squeeze(np.concatenate(aggregated_logs['labels'], axis=0))
      # scipy's pearsonr/spearmanr expect 1-D arrays, hence the squeeze above.
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Loads a pretrained checkpoint (if one exists) and trains from step 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
      return

    pretrain2finetune_mapping = {
'encoder':
model.checkpoint_items['encoder'],
'next_sentence.pooler_dense':
model.checkpoint_items['sentence_prediction.pooler_dense'],
}
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
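

# A minimal usage sketch (hypothetical checkpoint path; the Model Garden
# training loop that normally drives these calls is not shown):
#
#   config = SentencePredictionConfig(init_checkpoint='/path/to/ckpt')
#   task = SentencePredictionTask(config)
#   model = task.build_model()
#   metrics = task.build_metrics()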