# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer encoders for text, similar to CLIP."""

from typing import Any

from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import flax.linen as nn
import flax.training.checkpoints
import numpy as np

ConfigDict = Any


class _Model(nn.Module):
"""Text transformer similar to CLIP."""
  # Differences from the CLIP text encoder (gpt-2) that I am aware of:
# 1. https://imgur.com/HNi3jix (gpt-1)
# 2. https://imgur.com/qKGZgBR (gpt-2)
# 3. https://imgur.com/a/xrpYHF0 (clip)
# - LayerNorm is on res-path (like pre-activation resnet)
# - dropout 0.1 everywhere
# - init as var=0.02, scaled by depth
# - BOS and EOS tokens, take repr from EOS.
# - self-attention is autoregressively masked.
# - scaled in width only, with the image model.
num_classes: int
width: int = 512
depth: int = 12
mlp_dim: int = 2048
num_heads: int = 8
dropout: float = 0.0
vocab_size: int = 32_000
pool_type: str = "last"
scan: bool = False
remat_policy: str = "nothing_saveable"
@nn.compact
def __call__(self, text, *, train=False):
out = {}
# We can't use where/argwhere since the output shape is not fixed.
# Here we use the fact that sequences are padded with EOS tokens, that the
# EOS token has value 1, and that argmin returns the first index.
# eos_indices = jnp.argmin(text, axis=1)
embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.width)
x = out["embedded"] = embedding(text)
# Add posemb
n, l, d = x.shape # pylint: disable=unused-variable
x = x + self.param("pos_embedding",
nn.initializers.normal(stddev=1/np.sqrt(d)),
(1, l, d), x.dtype)
x, encoder_out = vit.Encoder(
depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads,
scan=self.scan, remat_policy=self.remat_policy, dropout=self.dropout)(
x, deterministic=not train)
out.update({"transformed": x, **encoder_out})
# Share weights between embeddings and logit transformation.
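    # (nn.Embed.attend multiplies by the transpose of the embedding table,
    # i.e. the input and output embeddings are tied.)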
out["vocab_logits"] = embedding.attend(x)
if self.pool_type == "last":
# Assuming "sticky" EOS tokenization, last token is always EOS.
x = out["pre_logits"] = x[:, -1, :]
elif self.pool_type == "first":
x = out["pre_logits"] = x[:, 0, :]
elif self.pool_type in ("mean", "gap"):
x = out["pre_logits"] = x.mean(axis=1)
elif self.pool_type in ("max", "gmp"):
x = out["pre_logits"] = x.max(axis=1)
elif self.pool_type == "map":
x = out["pre_logits"] = vit.MAPHead(
num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
else:
raise NotImplementedError(f"Cannot do pooling '{self.pool_type}'")
if self.num_classes:
x = out["logits"] = nn.Dense(self.num_classes, name="head")(x)
return x, out


def Model(num_classes, *, variant=None, **kw):  # pylint: disable=invalid-name
  """Factory function, because linen really doesn't like what I'm doing!"""
  return _Model(num_classes, **{**vit.decode_variant(variant), **kw})


def load(init_params, init_file, model_cfg, dont_load=()):  # pylint: disable=invalid-name
"""Load init from checkpoint, both old model and this one. +Hi-res posemb."""
del model_cfg # unused
params = utils.load_params(init_file)
params = flax.core.unfreeze(
flax.training.checkpoints.convert_pre_linen(params))
# Some older (but expensive to train) checkpoints had the posemb added twice
# by mistake. We detect this here and merge them.
extra_posemb = params["Encoder_0"].pop("pos_embedding", 0)
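  # (`pop` with a default of 0 makes the addition below a no-op for
  # checkpoints that do not contain the duplicated posemb.)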
params["pos_embedding"] += extra_posemb
return common.merge_params(params, init_params, dont_load)
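

if __name__ == "__main__":
  # A minimal usage sketch, not part of the original module: it assumes
  # big_vision and its JAX/Flax dependencies are importable, and the "B"
  # variant, batch/sequence shapes, and num_classes below are illustrative
  # choices only. It shows how the factory, init, and apply compose.
  import jax
  import jax.numpy as jnp

  model = Model(num_classes=768, variant="B")  # width/depth/etc. from vit.decode_variant
  tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # (batch, seq_len) of token ids
  variables = model.init(jax.random.PRNGKey(0), tokens)
  z, out = model.apply(variables, tokens)  # z: pooled (and projected) text features
  print(z.shape, sorted(out.keys()))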