|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Transformer encoders both for text and for images.""" |
|
|
|
import importlib |
|
from typing import Any, Optional, Tuple, Union |
|
from absl import logging |
|
|
|
from big_vision import utils |
|
import flax.linen as nn |
|
import jax.numpy as jnp |
|
|
|
# Alias used to annotate the per-tower config fields of `Model`. Those configs
# are unpacked as keyword arguments into the tower constructors, so no stricter
# type is enforced here.
ConfigDict = Any
|
|
|
|
|
class Model(nn.Module):
    """Two towers transformer: one image encoder and one text encoder."""

    # Extra keyword arguments for the image tower's `Model` (None: no extras).
    image: Optional[ConfigDict] = None
    # Extra keyword arguments for the text tower's `Model` (None: no extras).
    text: Optional[ConfigDict] = None
    # Module names, resolved relative to `big_vision.models`.
    text_model: str = "proj.image_text.text_transformer"
    image_model: str = "vit"
    # Embedding size; a single int applies to both towers, a pair is
    # (image_dim, text_dim).
    out_dim: Union[int, Tuple[int, int]] = 128
    # Initial temperature; the learned parameter `t` stores its logarithm.
    temperature_init: float = 1.0
    # Initial value of the optional bias `b`; None means no bias parameter.
    bias_init: Optional[float] = None

    @nn.compact
    def __call__(self, image, text=None, **kw):
        """Returns (B,C) image and (B,C) text representations.

        Either input may be None, in which case the matching tower is not
        instantiated and its representation is returned as None.

        Args:
          image: input for the image tower, or None to skip it.
          text: input for the text tower, or None to skip it.
          **kw: forwarded unchanged to both tower calls.

        Returns:
          A `(zimg, ztxt, out)` tuple. `zimg` and `ztxt` are the embeddings
          divided by their L2 norm along axis 1 (plus 1e-8). `out` collects
          each tower's auxiliary outputs under `img/` and `txt/` prefixes, the
          temperature under `t` and `t/parameter`, and the bias under `b` when
          `bias_init` is set.
        """
        out = {}

        if isinstance(self.out_dim, int):
            dims = (self.out_dim,) * 2
        else:
            dims = self.out_dim

        def embed(tower, module_name, tower_cfg, num_classes, inputs):
            # Builds the named tower, runs it, and records its outputs.
            encoder_cls = importlib.import_module(
                f"big_vision.models.{module_name}"
            ).Model
            ctor_kw = {"num_classes": num_classes, **(tower_cfg or {})}
            encoder = encoder_cls(**ctor_kw, name=tower)

            z, aux = encoder(inputs, **kw)
            out.update({f"{tower}/{k}": v for k, v in aux.items()})

            # Normalize the embedding the tower gives us.
            norm = jnp.linalg.norm(z, axis=1, keepdims=True)
            out[f"{tower}/norm"] = norm
            z = z / (norm + 1e-8)
            out[f"{tower}/normalized"] = z
            return z

        ztxt = None
        if text is not None:
            ztxt = embed("txt", self.text_model, self.text, dims[1], text)

        zimg = None
        if image is not None:
            zimg = embed("img", self.image_model, self.image, dims[0], image)

        # The temperature is learned in log-space and exposed exponentiated.
        log_temperature_init = jnp.log(self.temperature_init)

        def init_temperature(key, shape, dtype):
            return log_temperature_init * jnp.ones(shape, dtype)

        t = self.param("t", init_temperature, (1,), jnp.float32)
        out["t"] = jnp.exp(t)
        out["t/parameter"] = t

        bias_init = self.bias_init
        if bias_init is not None:

            def init_bias(key, shape, dtype):
                return bias_init * jnp.ones(shape, dtype)

            out["b"] = self.param("b", init_bias, (1,), jnp.float32)

        return zimg, ztxt, out
|
|
|
|
|
def load(init_params, init_files, model_cfg, img_load_kw=None,
         txt_load_kw=None):
    """Loads parameters for both towers, plus temperature and optional bias.

    Args:
      init_params: existing parameters to start from. When falsy, a dict with
        `img` and `txt` set to None is used instead. Entries for which no init
        file is given are returned unchanged.
      init_files: either a string or a dict. A string is first looked up in
        `VANITY_NAMES` and otherwise used as-is; it is then expanded to one
        `"<string>:<key>"` entry per component (`img`, `txt`, `t`, and `b` when
        the model has a bias). A dict may contain the keys `image`/`img`,
        `text`/`txt`, `temperature`/`t` and `bias`/`b`; when both spellings of
        a component are present the long one is used. The caller's dict is
        not modified.
      model_cfg: model config. It is read for `image_model`, `text_model`,
        `image`, `text` and `bias_init`, and only when the corresponding
        component is actually loaded.
      img_load_kw: extra keyword arguments for the image tower's `load`.
      txt_load_kw: extra keyword arguments for the text tower's `load`.

    Returns:
      A new dict: a shallow copy of `init_params` with every component that
      had an init file replaced by its loaded value.

    Raises:
      AssertionError: if `init_files` contains keys that were not consumed.
    """
    # None sentinels avoid sharing one mutable default dict across calls.
    img_load_kw = {} if img_load_kw is None else img_load_kw
    txt_load_kw = {} if txt_load_kw is None else txt_load_kw

    if isinstance(init_files, str):
        init_files = VANITY_NAMES.get(init_files, init_files)

    if isinstance(init_files, str):
        # A single checkpoint holds every component. The `:<key>` suffix
        # presumably selects that component's sub-tree in the tower loaders
        # and `utils.load_params` -- confirm against their implementations.
        # The bias is only requested when the model really has one: `Model`
        # creates `b` only for a non-None `bias_init`, so a config that merely
        # contains `bias_init=None` must not ask for it.
        if model_cfg.get("bias_init") is not None:
            logging.info("loading img, txt, t, and b from a single checkpoint.")
            components = ("img", "txt", "t", "b")
        else:
            logging.info("loading img, txt, and t from a single checkpoint.")
            components = ("img", "txt", "t")
        init_files = {k: f"{init_files}:{k}" for k in components}
    else:
        # Copy so that the pops below do not mutate the caller's dict.
        init_files = {**init_files}

    if not init_params:
        init_params = {"img": None, "txt": None}
    restored_params = {**init_params}

    img_init = init_files.pop("image", init_files.pop("img", None))
    if img_init:
        img_module = importlib.import_module(
            f"big_vision.models.{model_cfg.get('image_model', 'vit')}"
        )
        restored_params["img"] = img_module.load(
            init_params["img"], img_init, model_cfg.image, **img_load_kw
        )

    txt_init = init_files.pop("text", init_files.pop("txt", None))
    if txt_init:
        txt_module_name = model_cfg.get(
            "text_model", "proj.image_text.text_transformer"
        )
        txt_module = importlib.import_module(
            f"big_vision.models.{txt_module_name}"
        )
        restored_params["txt"] = txt_module.load(
            init_params["txt"], txt_init, model_cfg.text, **txt_load_kw
        )

    t_init = init_files.pop("temperature", init_files.pop("t", None))
    if t_init:
        restored_params["t"] = utils.load_params(t_init)

    b_init = init_files.pop("bias", init_files.pop("b", None))
    if b_init:
        restored_params["b"] = utils.load_params(b_init)

    # Anything still left was not recognised above, most likely a typo.
    assert not init_files, (
        f"There's something unused left in `config.model_init`. You probably got "
        f"a typo. Here it is: {init_files}")

    return restored_params
|
|
|
|
|
|
|
# Shortcut names for SigLIP checkpoints. `load` looks up a string `init_files`
# in this table before falling back to using the string itself.
_SIGLIP_CKPT_DIR = "gs://big_vision/siglip"

VANITY_NAMES = {
    vanity: f"{_SIGLIP_CKPT_DIR}/{stem}.npz"
    for vanity, stem in (
        ("SigLIP B/16 224", "webli_en_b16_224_63724782"),
        ("SigLIP B/16 256", "webli_en_b16_256_60500360"),
        ("SigLIP B/16 384", "webli_en_b16_384_68578854"),
        ("SigLIP B/16 512", "webli_en_b16_512_68580893"),
        ("SigLIP L/16 256", "webli_en_l16_256_60552751"),
        ("SigLIP L/16 384", "webli_en_l16_384_63634585"),
        ("SigLIP So400m/14 224", "webli_en_so400m_224_57633886"),
        ("SigLIP So400m/14 384", "webli_en_so400m_384_58765454"),
        ("SigLIP B/16-i18n 256", "webli_i18n_b16_256_66117334"),
    )
}
|
|