File size: 3,282 Bytes
dea4744
 
 
 
 
 
 
 
 
 
 
f78095a
 
dea4744
7731867
dea4744
1ab8bf1
dea4744
 
 
 
 
12eb8ec
dea4744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7731867
 
 
 
 
 
 
dea4744
 
 
 
b206e22
dea4744
 
 
 
 
 
 
 
 
 
 
355f06e
 
 
 
 
 
 
dea4744
 
 
 
 
f78095a
 
 
 
 
 
 
12eb8ec
3dd5a0f
 
f78095a
 
4fe13a9
d0b9ede
 
 
 
 
8d337c8
f78095a
8939e76
 
355f06e
 
 
 
 
 
 
 
 
 
8939e76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Model-related code and constants."""

import dataclasses
import os
import re

import PIL.Image

# pylint: disable=g-bad-import-order
import gradio_helpers

import llama_cpp


# Owner prefix used to build the `<owner>/<repo>` ids stored in MODELS.
ORGANIZATION = 'abetlen'

# (repository name, model name) pairs.
BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]

# Maps model name -> (repo id, (text-model GGUF filename, mmproj GGUF filename)).
MODELS = {
    model_name: (
        f'{ORGANIZATION}/{repo}',
        (
            f'{model_name}-text-model-q4_k_m.gguf',
            f'{model_name}-mmproj-f16.gguf',
        ),
    )
    for repo, model_name in BASE_MODELS
}

# Human-readable description for each model name.
# NOTE(review): the text below describes JAX/FLAX float checkpoints, while
# MODELS lists q4_k_m / f16 GGUF files -- confirm it is still accurate.
MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images '
        'and 256 token input/output text sequences on a mixture of downstream '
        'academic datasets. The models are available in float32, bfloat16 and '
        'float16 format for research purposes only.'
    ),
}

# Maps model name -> (res, seq); get_cached_model() passes these on as the
# `res` and `text_len` config fields.
MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
}

# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
# Below value should be smaller than "available RAM - one model".
# A single bf16 model is about 5860 MB.
# Cache budget in bytes, read from the RAM_CACHE_GB environment variable
# (interpreted as gigabytes; defaults to 0 when the variable is unset).
MAX_RAM_CACHE = int(1e9 * float(os.getenv('RAM_CACHE_GB', '0')))

# config = paligemma_bv.PaligemmaConfig(
#     ckpt='',  # will be set below
#     res=224,
#     text_len=64,
#     tokenizer='gemma(tokensets=("loc", "seg"))',
#     vocab_size=256_000 + 1024 + 128,
# )


def get_cached_model(model_name: str):
  """Returns the `(model, params_cpu)` pair for `model_name` via the RAM cache.

  The (res, seq) pair from `MODELS_RES_SEQ` and the entry from
  `gradio_helpers.get_paths()` are written into a copy of `config`; that copy
  is both the cache key and the argument given to the loader.

  NOTE(review): neither `config` nor `paligemma_bv` is defined or imported in
  the visible part of this file (the `config` definition above is commented
  out), so calling this function as-is presumably raises NameError. Also,
  `generate()` unpacks the same `get_paths()` entry as a
  `(model_path, clip_path)` pair, whereas here it is used as a single `ckpt`
  value. Confirm whether this function is still needed.

  Args:
    model_name: Key into `MODELS_RES_SEQ` and `gradio_helpers.get_paths()`.

  Returns:
    The two values produced by `gradio_helpers.get_memory_cache`, as a tuple.
  """
  resolution, text_len = MODELS_RES_SEQ[model_name]
  ckpt_path = gradio_helpers.get_paths()[model_name]
  model_config = dataclasses.replace(
      config, ckpt=ckpt_path, res=resolution, text_len=text_len
  )

  def _load():
    # Passed to the cache helper as the `getter` callable.
    return paligemma_bv.load_model(model_config)

  model, params_cpu = gradio_helpers.get_memory_cache(
      model_config,
      _load,
      max_cache_size_bytes=MAX_RAM_CACHE,
  )
  return model, params_cpu

def pil_image_to_base64(image: PIL.Image.Image) -> str:
  """Encodes a PIL image as a base64 JPEG string.

  JPEG cannot store an alpha channel or a palette, and Pillow raises `OSError`
  ("cannot write mode RGBA as JPEG") when asked to save such an image. Images
  whose mode is neither 'RGB' nor 'L' are therefore converted to 'RGB' first;
  any transparency information is dropped by that conversion. The caller's
  image object is not modified (`convert` returns a new image).

  Args:
    image: Image to encode.

  Returns:
    The JPEG bytes of `image`, base64-encoded and decoded to a `str`.
  """
  import base64
  import io

  if image.mode not in ('RGB', 'L'):
    # E.g. RGBA / P / LA images from PNG or GIF uploads.
    image = image.convert('RGB')

  buffer = io.BytesIO()
  image.save(buffer, format='JPEG')
  return base64.b64encode(buffer.getvalue()).decode('utf-8')

def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
  """Answers `prompt` about `image` using the llama.cpp model `model_name`.

  A fresh `llama_cpp.Llama` instance (with a `PaliGemmaChatHandler` built from
  the clip/mmproj path) is constructed on every call. The prompt and the
  JPEG-encoded image are sent as one user message.

  Args:
    model_name: Key into `gradio_helpers.get_paths()`; the entry is unpacked as
      `(model_path, clip_path)`.
    sampler: Accepted for interface compatibility; this implementation does
      not use it.
    image: Image to send along with the prompt.
    prompt: Text part of the user message.

  Returns:
    The `content` of the first choice's message in the chat completion.
  """
  model_path, clip_path = gradio_helpers.get_paths()[model_name]

  # Diagnostic output of the resolved path and the full path table.
  print(model_path)
  print(gradio_helpers.get_paths())

  chat_handler = llama_cpp.llama_chat_format.PaliGemmaChatHandler(clip_path)
  llm = llama_cpp.Llama(
      model_path,
      chat_handler=chat_handler,
      n_ctx=1024,
      n_ubatch=512,
      n_batch=512,
      n_gpu_layers=-1,
  )

  # The image is only encoded after the model has been constructed.
  text_part = {"type": "text", "text": prompt}
  image_part = {
      "type": "image_url",
      "image_url": "data:image/jpeg;base64," + pil_image_to_base64(image),
  }
  user_message = {"role": "user", "content": [text_part, image_part]}

  response = llm.create_chat_completion(messages=[user_message])
  first_choice = response["choices"][0]
  return first_choice["message"]["content"]