File size: 6,124 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoders both for text and for images."""
import importlib
from typing import Any, Optional, Tuple, Union
from absl import logging
from big_vision import utils
import flax.linen as nn
import jax.numpy as jnp
ConfigDict = Any
class Model(nn.Module):
"""Two towers transformer."""
image: Optional[ConfigDict] = None
text: Optional[ConfigDict] = None
text_model: str = "proj.image_text.text_transformer"
image_model: str = "vit"
out_dim: Union[int, Tuple[int, int]] = 128
temperature_init: float = 1.0
bias_init: Optional[float] = None
@nn.compact
def __call__(self, image, text=None, **kw):
"""Returns (B,C) image and (B,C) text representations."""
# Support calling without text or without image, for example for few-shot.
ztxt, zimg = None, None
out = {}
out_dims = self.out_dim
if isinstance(out_dims, int):
out_dims = (out_dims, out_dims)
# Embed the text:
if text is not None:
text_model = importlib.import_module(
f"big_vision.models.{self.text_model}"
).Model(**{"num_classes": out_dims[1], **(self.text or {})}, name="txt")
ztxt, out_txt = text_model(text, **kw)
for k, v in out_txt.items():
out[f"txt/{k}"] = v
# Normalize the embeddings the models give us.
out["txt/norm"] = jnp.linalg.norm(ztxt, axis=1, keepdims=True)
out["txt/normalized"] = ztxt = ztxt / (out["txt/norm"] + 1e-8)
if image is not None:
image_model = importlib.import_module(
f"big_vision.models.{self.image_model}"
).Model(**{"num_classes": out_dims[0], **(self.image or {})}, name="img") # pylint: disable=not-a-mapping
zimg, out_img = image_model(image, **kw)
for k, v in out_img.items():
out[f"img/{k}"] = v
# Normalize the embeddings the models give us.
out["img/norm"] = jnp.linalg.norm(zimg, axis=1, keepdims=True)
out["img/normalized"] = zimg = zimg / (out["img/norm"] + 1e-8)
temp_init = jnp.log(self.temperature_init)
t = self.param("t",
lambda key, shape, dtype: temp_init * jnp.ones(shape, dtype),
(1,), jnp.float32)
out["t"] = jnp.exp(t)
out["t/parameter"] = t
if (b_init := self.bias_init) is not None:
out["b"] = self.param("b", lambda k, s, d: b_init * jnp.ones(s, d),
(1,), jnp.float32)
# We could actually play with pre-multiplying by temperature here, such
# that out["t"] is nothing special to the trainer anymore.
return zimg, ztxt, out
def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): # pylint: disable=dangerous-default-value
"""Loads both towers, `init_files` is now a dict with `img` and `txt` keys."""
if isinstance(init_files, str):
init_files = VANITY_NAMES.get(init_files, init_files)
if isinstance(init_files, str):
# A shortcut for a single file checkpoint of a two_towers model.
if "bias_init" in model_cfg.keys():
logging.info("loading img, txt, t, and b from a single checkpoint.")
init_files = {k: f"{init_files}:{k}" for k in ("img", "txt", "t", "b")}
else:
logging.info("loading img, txt, and t from a single checkpoint.")
init_files = {k: f"{init_files}:{k}" for k in ("img", "txt", "t")}
else:
init_files = {**init_files} # Shallow copy because we'll pop stuff off.
if not init_params: # Convenience to skip checks in colab.
init_params = {"img": None, "txt": None}
restored_params = {**init_params}
img_init = init_files.pop("image", init_files.pop("img", None))
if img_init:
restored_params["img"] = importlib.import_module(
f"big_vision.models.{model_cfg.get('image_model', 'vit')}"
).load(init_params["img"], img_init, model_cfg.image, **img_load_kw)
txt_init = init_files.pop("text", init_files.pop("txt", None))
if txt_init:
restored_params["txt"] = importlib.import_module(
f"big_vision.models.{model_cfg.get('text_model', 'proj.image_text.text_transformer')}" # pylint: disable=line-too-long
).load(init_params["txt"], txt_init, model_cfg.text, **txt_load_kw)
t_init = init_files.pop("temperature", init_files.pop("t", None))
if t_init:
restored_params["t"] = utils.load_params(t_init)
b_init = init_files.pop("bias", init_files.pop("b", None))
if b_init:
restored_params["b"] = utils.load_params(b_init)
assert not init_files, (
f"There's something unused left in `config.model_init`. You probably got "
f"a typo. Here it is: {init_files}")
return restored_params
# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
# pylint: disable=line-too-long
# SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343
"SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz",
"SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz",
"SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz",
"SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz",
"SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz",
"SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz",
"SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz",
"SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz",
"SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz",
# pylint: enable=line-too-long
}
|