"""Wraps `big_vision` PaliGemma model for easy use in demo."""

from collections.abc import Callable
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import PIL.Image

from big_vision import sharding
from big_vision import utils
from big_vision.models.proj.paligemma import paligemma
from big_vision.pp import builder as pp_builder
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
from big_vision.pp import ops_text  # pylint: disable=unused-import
from big_vision.pp import tokenizer
from big_vision.pp.proj.paligemma import ops as ops_paligemma  # pylint: disable=unused-import
from big_vision.trainers.proj.paligemma import predict_fns


# Device mesh built at import time over everything `jax.devices()` returns,
# with a single axis named 'data'. The sharding helpers below reference it.
mesh = jax.sharding.Mesh(jax.devices(), 'data')


def _recover_bf16(x):
  """Views 2-byte void (`V2`) arrays as bfloat16; returns others unchanged."""
  is_raw_two_byte = x.dtype == np.dtype('V2')
  return x.view('bfloat16') if is_raw_two_byte else x


def _load(
    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
):
  """Loads model, params, decode functions and tokenizer.

  Args:
    path: Checkpoint location, forwarded to `paligemma.load`.
    tokenizer_spec: Spec string forwarded to `tokenizer.get_tokenizer`.
    vocab_size: Vocabulary size set on the `llm` sub-config.

  Returns:
    A `(model, params_cpu, decode, beam_decode, tok)` tuple, where `decode`
    and `beam_decode` are the entries of `predict_fns.get_all(model)` with
    those names and `tok` is the tokenizer built from `tokenizer_spec`.
  """
  tok = tokenizer.get_tokenizer(tokenizer_spec)

  config = ml_collections.FrozenConfigDict(dict(
      llm_model='proj.paligemma.gemma_bv',
      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
      img=dict(variant='So400m/14', pool_type='none', scan=True),
  ))
  model = paligemma.Model(**config)
  # Build the predict-function mapping once and take both decoders from it,
  # rather than constructing the whole mapping a second time.
  predict_fn_by_name = predict_fns.get_all(model)
  decode = predict_fn_by_name['decode']
  beam_decode = predict_fn_by_name['beam_decode']

  params_cpu = paligemma.load(None, path, config)
  # Some numpy versions don't load bfloat16 correctly:
  params_cpu = jax.tree.map(_recover_bf16, params_cpu)

  return model, params_cpu, decode, beam_decode, tok


def _shard_params(params_cpu):
  """Reshards `params_cpu` onto the module mesh using an fsdp strategy."""
  # One catch-all rule: every parameter gets fsdp sharding on the 'data' axis.
  fsdp_everything = ('.*', 'fsdp(axis="data")')
  shardings = sharding.infer_sharding(
      params_cpu, strategy=[fsdp_everything], mesh=mesh
  )
  return jax.tree.map(utils.reshard, params_cpu, shardings)


def _pil2np(img):
  """Accepts `PIL.Image` or `np.ndarray` and returns `np.ndarray`.

  Args:
    img: A `PIL.Image.Image`, or any other object (e.g. an `np.ndarray`).

  Returns:
    For a `PIL.Image.Image`: its array with exactly 3 channels in the last
    axis. Single-band images are replicated to 3 channels, 2-band images keep
    only their first band (replicated to 3), and bands beyond the third (e.g.
    alpha) are dropped. Any other input is returned unchanged.
  """
  if not isinstance(img, PIL.Image.Image):
    return img

  arr = np.array(img)
  # Add the channel axis *before* slicing channels: on a 2-D grayscale array
  # `arr[..., :3]` would crop the width to 3 columns instead.
  if arr.ndim == 2:
    arr = arr[..., None]
  # Two bands (e.g. luminance + alpha): keep the first band only, in line with
  # how the alpha band of a 4-band image is dropped below.
  if arr.shape[-1] == 2:
    arr = arr[..., :1]
  arr = arr[..., :3]
  if arr.shape[-1] == 1:
    arr = np.repeat(arr, 3, axis=-1)
  return arr


def _prepare_batch(
    images,
    prefixes,
    *,
    res=224,
    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
    suffixes=None,
    text_len=64,
):
  """Returns non-sharded batch.

  Args:
    images: A list/tuple of images, or a 4-D array. Each item is passed
      through `_pil2np`, so `PIL.Image` items are accepted.
    prefixes: A batch (not a single `str`) of prefix texts, one per image.
    res: Target size given to the `resize` preprocessing op.
    tokenizer_spec: Tokenizer spec string given to the `tok` ops.
    suffixes: Optional batch of suffix texts, one per image. Defaults to an
      empty string for every example.
    text_len: Length that "text", "mask_ar" and "mask_input" are brought to
      by the `tolen` ops.

  Returns:
    A tree with the per-example preprocessing outputs (the preprocessing
    string keeps "image", "text", "mask_ar" and "mask_input") plus a `_mask`
    entry set to `True`, each stacked over examples along a new leading axis.

  Raises:
    AssertionError: If `prefixes` is a single string, `images` is not a
      batch, or the numbers of images, prefixes and suffixes differ.
  """

  pp_fn = pp_builder.get_preprocess_fn('|'.join([
      f'resize({res}, antialias=True)|value_range(-1, 1)',
      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
      f"tok(key='suffix', model='{tokenizer_spec}')",
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
      f'tolen({text_len}, pad_value=0, key="text")',
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)
  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
  assert (
      isinstance(images, (list, tuple)) or images.ndim == 4
  ), f'expected batch: {images.shape}'
  if suffixes is None:
    suffixes = [''] * len(prefixes)
  assert len(prefixes) == len(suffixes) == len(images)
  examples = [{'_mask': True, **pp_fn({
      'image': np.asarray(_pil2np(image)),
      'prefix': np.array(prefix),
      'suffix': np.array(suffix),
  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
  # Use `jax.tree.map`, as the rest of this file does; the top-level
  # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
  return batch


def _shard_batch(batch, n=None):
  """Pads every entry's leading axis to a multiple of `n` and reshards it.

  Args:
    batch: Mapping from name to array.
    n: Multiple to pad the leading axis to; defaults to
      `jax.local_device_count()`.

  Returns:
    The padded batch, resharded along the 'data' axis of the module mesh.
  """
  multiple = jax.local_device_count() if n is None else n

  def _pad_leading(arr):
    # `-len % multiple` is the number of rows missing to the next multiple.
    missing = -len(arr) % multiple
    widths = [(0, missing)] + [(0, 0)] * (arr.ndim - 1)
    return jnp.pad(arr, widths)

  padded = {name: _pad_leading(value) for name, value in batch.items()}
  spec = jax.sharding.PartitionSpec('data')
  return utils.reshard(padded, jax.sharding.NamedSharding(mesh, spec))


@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class PaligemmaConfig:
  """Desribes a `big_vision` PaliGemma model."""

  ckpt: str
  res: int
  text_len: int
  tokenizer: str
  vocab_size: int


@dataclasses.dataclass(kw_only=True, frozen=True)
class PaliGemmaModel:
  """Bundles a PaliGemma config, tokenizer and decode functions.

  Attributes:
    config: The `PaligemmaConfig` whose `res`, `tokenizer` and `text_len` are
      used when preparing batches.
    tokenizer: Tokenizer whose `eos_token` is handed to the decode functions.
    decode: Decode function used for every sampler other than 'beam'.
    beam_decode: Decode function used when the sampler is 'beam'.
  """

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    """Forwards `batch` to the module-level `_shard_batch`."""
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    """Forwards `params_cpu` to the module-level `_shard_params`."""
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    """Builds a non-sharded batch using this model's config settings."""
    cfg = self.config
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=cfg.res,
        tokenizer_spec=cfg.tokenizer,
        text_len=cfg.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens.

    Args:
      params: Model parameters, passed to the decoder as `{'params': params}`.
      batch: Batch to decode.
      devices: Devices passed to the decoder; defaults to `jax.devices()`.
      max_decode_len: Passed through to the decoder.
      sampler: 'beam' selects `beam_decode`. Any other value selects `decode`
        and is forwarded to it as the `sampler` keyword.
      **kw: Extra keyword arguments passed through to the decoder.
    """
    use_beam = sampler == 'beam'
    if not use_beam:
      # Only the non-beam decoder receives the sampler name.
      kw['sampler'] = sampler
    decode_fn = self.beam_decode if use_beam else self.decode
    return decode_fn(
        {'params': params},
        batch=batch,
        devices=jax.devices() if devices is None else devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )


# Alias used in `load_model`'s return annotation for the loaded `params_cpu`
# value; deliberately untyped (`Any`).
ParamsCpu = Any


def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Builds a `PaliGemmaModel` wrapper and loads its parameters.

  Args:
    config: Supplies the checkpoint path, tokenizer spec and vocab size.

  Returns:
    A `(wrapper, params_cpu)` pair: the `PaliGemmaModel` holding the config,
    tokenizer and decode functions, and the parameters returned by `_load`.
  """
  # The raw model object returned by `_load` is not needed by the wrapper.
  _, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  wrapper = PaliGemmaModel(
      config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode,
  )
  return wrapper, params_cpu