"""Sentence prediction (classification) task.""" |
|
from absl import logging |
|
import dataclasses |
|
import numpy as np |
|
from scipy import stats |
|
from sklearn import metrics as sklearn_metrics |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
|
|
from official.core import base_task |
|
from official.modeling.hyperparams import config_definitions as cfg |
|
from official.nlp.configs import bert |
|
from official.nlp.data import sentence_prediction_dataloader |
|
from official.nlp.modeling import losses as loss_lib |
|
from official.nlp.tasks import utils |
|
|


@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The sentence prediction task config."""
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  # One of 'accuracy', 'matthews_corrcoef' or 'pearson_spearman_corr'.
  metric_type: str = 'accuracy'
  network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
      num_masked_tokens=0,
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=3,
              dropout_rate=0.1,
              name='sentence_prediction')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
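
# A minimal usage sketch, assuming the surrounding training library drives the
# task; the checkpoint path and metric choice below are hypothetical and shown
# for illustration only:
#
#   task_config = SentencePredictionConfig(
#       init_checkpoint='/path/to/pretrained/checkpoint',
#       metric_type='accuracy')
#   task = SentencePredictionTask(task_config)
#   model = task.build_model()
#   task.initialize(model)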


@base_task.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params: cfg.TaskConfig):
    super(SentencePredictionTask, self).__init__(params)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None
    self.metric_type = params.metric_type

  def build_model(self):
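    """Builds a BERT pre-trainer model, using a TF-Hub encoder if provided."""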
    if self._hub_module:
      encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
      return bert.instantiate_bertpretrainer_from_cfg(
          self.task_config.network, encoder_network=encoder_from_hub)
    else:
      return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
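    """Computes sparse categorical cross-entropy loss over class logits."""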
    loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
        labels=labels,
        predictions=tf.nn.log_softmax(
            tf.cast(model_outputs['sentence_prediction'], tf.float32),
            axis=-1))
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)
        y = tf.ones((1, 1), dtype=tf.int32)
        return (x, y)

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return sentence_prediction_dataloader.SentencePredictionDataLoader(
        params).load(input_context)

  def build_metrics(self, training=None):
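    """Returns the sparse categorical accuracy metric for this task."""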
    del training
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    for metric in metrics:
      metric.update_state(labels, model_outputs['sentence_prediction'])

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
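    """Runs a validation step, keying the outputs by `self.metric_type`."""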
    if self.metric_type == 'accuracy':
      return super(SentencePredictionTask,
                   self).validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    if self.metric_type == 'matthews_corrcoef':
      return {
          self.loss: loss,
          'sentence_prediction':
              tf.expand_dims(
                  tf.math.argmax(outputs['sentence_prediction'], axis=1),
                  axis=0),
          'labels': labels,
      }
    if self.metric_type == 'pearson_spearman_corr':
      return {
          self.loss: loss,
          'sentence_prediction': outputs['sentence_prediction'],
          'labels': labels,
      }

  def aggregate_logs(self, state=None, step_outputs=None):
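    """Accumulates per-step predictions and labels across the eval dataset."""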
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    state['sentence_prediction'].append(
        np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                       axis=0))
    state['labels'].append(
        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
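    """Computes the corpus-level metric from the aggregated step outputs."""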
    if self.metric_type == 'matthews_corrcoef':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      # `sklearn_metrics.matthews_corrcoef` expects (y_true, y_pred); the
      # metric is symmetric, but pass labels first to match its signature.
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
      }
    if self.metric_type == 'pearson_spearman_corr':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Loads a pretrained checkpoint (if it exists) and trains from step 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    # Restore the pretrained encoder, and map the pretraining
    # `next_sentence` pooler weights onto the fine-tuning head.
    pretrain2finetune_mapping = {
        'encoder':
            model.checkpoint_items['encoder'],
        'next_sentence.pooler_dense':
            model.checkpoint_items['sentence_prediction.pooler_dense'],
    }
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)