"""BERT encoder, optionally loading pre-trained checkpoints.""" |
|
|
|
import dataclasses |
|
from typing import Optional |
|
|
|
from absl import logging |
|
from big_vision import utils |
|
from big_vision.models import common |
|
import flax |
|
import flax.linen as nn |
|
import jax.numpy as jnp |
|
from tensorflow.io import gfile |
|
|
|
from flaxformer.architectures.bert import bert |
|
from flaxformer.architectures.bert import bert_checkpoint_converter |
|
from flaxformer.architectures.bert import configs |
|
|
|
|
|
class Model(nn.Module):
  """BERT encoder with linear projection on last layer CLS token."""

  config: str
  num_classes: Optional[int] = None
  head_zeroinit: bool = True

  @nn.compact
  def __call__(self, text, *, train=False):
    out = {}

    batch_size, max_len = text.shape
    bert_model = bert.BertEncoder(**dataclasses.asdict({
        "base": configs.BertBaseConfig(),
        "large": configs.BertLargeConfig(),
    }[self.config]))
    x = out["transformed"] = bert_model(
        token_ids=text,
        position_ids=jnp.tile(
            jnp.arange(0, max_len, dtype=jnp.int32), [batch_size, 1]),
        segment_ids=jnp.zeros([batch_size, max_len], dtype=jnp.int32),
        # Nonzero token ids are treated as real tokens; this assumes the
        # tokenizer reserves id 0 for padding.
        input_mask=text.astype(jnp.bool_).astype(jnp.int32),
        enable_dropout=train,
    )

    # Use the first ([CLS]) token's embedding as the sequence representation.
    x = out["pre_logits"] = x[:, 0]

    if self.num_classes:
      # Optionally zero-init the head so fine-tuning starts from zero logits.
      kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
      x = out["logits"] = nn.Dense(self.num_classes, name="head", **kw)(x)

    return x, out
|
|
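# Example usage (a minimal sketch, not part of the original module): the batch
# shape, sequence length, and projection width below are illustrative
# assumptions, not values mandated by this model.
#
#   import jax
#
#   model = Model(config="base", num_classes=768)
#   text = jnp.ones([2, 128], jnp.int32)      # [batch, seq_len] token ids
#   params = model.init(jax.random.PRNGKey(0), text)
#   logits, out = model.apply(params, text)   # out: transformed / pre_logits / logits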
|
|
|
def load(params, path, model_cfg=None, dont_load=()):
  """Returns `params` with BERT weights replaced from checkpoint at `path`."""
  del model_cfg

  # First try to interpret `path` as a directory holding an original TF BERT
  # checkpoint (bert_model.ckpt.index etc.); otherwise fall back to loading a
  # big_vision checkpoint.
  checkpoint_path = f"{path}/bert_model.ckpt"
  if gfile.exists(f"{checkpoint_path}.index"):
    logging.info("Loading original BERT checkpoint from '%s'", checkpoint_path)
    params = flax.core.FrozenDict(params).unfreeze()
    max_len = (
        params["BertEncoder_0"]["embedder"]["embedders_position_ids"]
        ["embedding"].shape[0])
    bert_params, pooler_params = (
        bert_checkpoint_converter.load_params_from_tf_checkpoint(
            checkpoint_path=checkpoint_path))
    del pooler_params
    if isinstance(bert_params, flax.core.FrozenDict):
      bert_params = bert_params.unfreeze()
    # The checkpoint may carry more position embeddings than the model's
    # maximum sequence length; keep only the first `max_len` rows.
    bert_params["embedder"]["embedders_position_ids"]["embedding"] = (
        bert_params["embedder"]["embedders_position_ids"]["embedding"]
        [:max_len])
    return common.merge_params(
        {"BertEncoder_0": bert_params}, params, dont_load)

  logging.info(
      "Could not find original BERT checkpoint path '%s', "
      "loading big_vision checkpoint '%s'", checkpoint_path, path)
  restored_params = utils.load_params(path)
  return common.merge_params(restored_params, params, dont_load)
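

# Example usage (sketch; the checkpoint directory and its name below are
# hypothetical): point `path` at a folder containing bert_model.ckpt.* to load
# an original TF BERT checkpoint, or at a big_vision checkpoint otherwise.
#
#   model = Model(config="base")
#   text = jnp.ones([2, 128], jnp.int32)
#   params = model.init(jax.random.PRNGKey(0), text)["params"]
#   params = load(params, "/checkpoints/bert_base_uncased")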