|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Model definition to train a single ViT model with the contrastive trainer.""" |
|
|
|
import importlib |
|
from typing import Optional, Any |
|
|
|
from big_vision import utils |
|
import flax.linen as nn |
|
import jax.numpy as jnp |
|
|
|
ConfigDict = Any |
|
|
|
|
|
class Model(nn.Module):
    """Single ViT to encode regular images and text images."""

    # Optional extra constructor kwargs for the tower; its entries are merged
    # over (and may override) the `num_classes` default built from `out_dim`.
    image: Optional[ConfigDict] = None
    # Name of the module under `big_vision.models` that provides `Model`.
    image_model: str = "vit"
    # Forwarded to the tower as `num_classes`.
    out_dim: int = 768
    # The learnable `t` starts at log(temperature_init), so exp(t) starts here.
    temperature_init: float = 10.0

    @nn.compact
    def __call__(self, image, text=None, **kw):
        """Returns (B, C) image and (B, C) text representations, and some extras.

        Both inputs go through the very same tower instance (named "img"), so
        they share parameters.

        Args:
          image: Batch fed to the tower, or None to skip the image branch.
          text: Batch fed to the same tower, or None to skip the text branch.
          **kw: Extra keyword arguments forwarded to every tower call.

        Returns:
          A tuple `(zimg, ztxt, out)`. `zimg` / `ztxt` are the embeddings scaled
          to unit norm along axis 1 (None for a skipped branch). `out` holds the
          tower's own outputs under an "img/" or "txt/" prefix, plus
          "<prefix>/norm", "<prefix>/normalized", "t" (= exp of the parameter)
          and "t/parameter" (the raw parameter).
        """
        tower_module = importlib.import_module(
            f"big_vision.models.{self.image_model}")
        tower_kwargs = {"num_classes": self.out_dim, **(self.image or {})}
        tower = tower_module.Model(**tower_kwargs, name="img")

        def _embed(batch, prefix):
            """Runs the shared tower on `batch` and normalizes the embedding."""
            raw, tower_out = tower(batch, **kw)
            extras = {f"{prefix}/{k}": v for k, v in tower_out.items()}

            # Normalize the embeddings; the epsilon guards against zero norm.
            norm = jnp.linalg.norm(raw, axis=1, keepdims=True)
            unit = raw / (norm + 1e-8)
            extras[f"{prefix}/norm"] = norm
            extras[f"{prefix}/normalized"] = unit
            return unit, extras

        out = {}
        zimg = ztxt = None

        if image is not None:
            zimg, img_extras = _embed(image, "img")
            out.update(img_extras)

        if text is not None:
            ztxt, txt_extras = _embed(text, "txt")
            out.update(txt_extras)

        # The temperature is stored in log-space and exposed through exp().
        log_temp_init = jnp.log(self.temperature_init)

        def _init_log_temperature(unused_rng, shape, dtype):
            return log_temp_init * jnp.ones(shape, dtype)

        log_temp = self.param("t", _init_log_temperature, (1,), jnp.float32)
        out["t"] = jnp.exp(log_temp)
        out["t/parameter"] = log_temp

        return zimg, ztxt, out
|
|
|
|
|
def load(init_params, init_files, model_cfg, img_load_kw=None):
    """Loads the ViT parameters - adapted from proj/image_text/two_towers.py.

    Args:
      init_params: Mapping of initial parameters. It is shallow-copied, and its
        "img" entry is handed to the image model's `load` when an image
        checkpoint is given.
      init_files: Either a single string, which is expanded to
        `"<init_files>:img"` and `"<init_files>:t"`, or a mapping that may
        contain "image"/"img" and "temperature"/"t". When both spellings are
        present, the long one ("image", "temperature") wins and the short one
        is discarded. The caller's mapping is never modified.
      model_cfg: Config object exposing `image_model` and `image`.
      img_load_kw: Optional extra keyword arguments forwarded to the image
        model's `load`. Defaults to no extra arguments.

    Returns:
      A new dict holding `init_params` with "img" and/or "t" replaced by the
      restored values for every entry that had a truthy checkpoint spec.

    Raises:
      AssertionError: If `init_files` contains keys other than the ones above.
    """
    # `None` instead of a shared `{}` default avoids the mutable-default pitfall.
    img_load_kw = {} if img_load_kw is None else img_load_kw

    if isinstance(init_files, str):
        # Shortcut: one checkpoint path holding both sub-trees.
        init_files = {k: f"{init_files}:{k}" for k in ("img", "t")}
    else:
        # Shallow copy because entries are popped off below.
        init_files = {**init_files}

    restored_params = {**init_params}

    # Pop both spellings so neither is reported as unused; "image" wins.
    img_fallback = init_files.pop("img", None)
    img_init = init_files.pop("image", img_fallback)
    if img_init:
        image_module = importlib.import_module(
            f"big_vision.models.{model_cfg.image_model}")
        restored_params["img"] = image_module.load(
            init_params["img"], img_init, model_cfg.image, **img_load_kw)

    # Same aliasing for the temperature; "temperature" wins over "t".
    t_fallback = init_files.pop("t", None)
    t_init = init_files.pop("temperature", t_fallback)
    if t_init:
        restored_params["t"] = utils.load_params(None, t_init)

    # Kept as an assert so callers see the same AssertionError as before.
    assert not init_files, (
        f"There's something unused left in `config.model_init`. You probably got "
        f"a typo. Here it is: {init_files}")

    return restored_params
|
|