File size: 3,327 Bytes
dea4744
 
fb03109
dea4744
 
 
 
 
 
 
 
 
f78095a
 
dea4744
7731867
dea4744
1ab8bf1
dea4744
 
 
 
 
a1467e1
dea4744
 
 
 
 
 
 
a1467e1
dea4744
 
 
 
 
 
 
 
 
 
 
 
 
 
7731867
 
 
 
 
 
 
dea4744
 
 
 
b206e22
dea4744
 
 
 
 
 
 
 
 
 
 
355f06e
 
 
 
 
 
 
dea4744
fb03109
dea4744
 
 
 
f78095a
 
 
 
 
 
 
12eb8ec
3dd5a0f
 
f78095a
 
4fe13a9
d0b9ede
 
 
 
 
8d337c8
f78095a
8939e76
 
355f06e
 
 
 
 
 
 
 
 
 
8939e76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Model-related code and constants."""

import spaces
import dataclasses
import os
import re

import PIL.Image

# pylint: disable=g-bad-import-order
import gradio_helpers

import llama_cpp


# Hub account that hosts the GGUF repositories.
ORGANIZATION = 'abetlen'

# (repository name, model name) pairs; the model name is the key used by the
# lookup tables in this module.
BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]

# model name -> (hub repo id, (text-model GGUF filename, mmproj GGUF filename)).
MODELS = {
    name: (
        f'{ORGANIZATION}/{repo_name}',
        (f'{name}-text-model-q8_0.gguf', f'{name}-mmproj-f16.gguf'),
    )
    for repo_name, name in BASE_MODELS
}

# model name -> human-readable description of the checkpoint.
MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'GGUF PaliGemma 3B weights quantized in Q8_0 Format, '
        'finetuned with 224x224 input images and 256 token input/output '
        'text sequences on a mixture of downstream academic datasets. '
        'The models are available in float32, '
        'bfloat16 and float16 format for research purposes only.'
    ),
}

# model name -> (res, seq), unpacked in that order by `get_cached_model`.
MODELS_RES_SEQ = {'paligemma-3b-mix-224': (224, 256)}

# Byte budget handed to the in-memory model cache. It is read from the
# `RAM_CACHE_GB` environment variable (decimal gigabytes, fractions allowed)
# and is 0 when the variable is unset.
# Sizing notes: "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. The value
# should be smaller than "available RAM - one model"; a single bf16 is about
# 5860 MB.
MAX_RAM_CACHE = int(1e9 * float(os.getenv('RAM_CACHE_GB', '0')))

# config = paligemma_bv.PaligemmaConfig(
#     ckpt='',  # will be set below
#     res=224,
#     text_len=64,
#     tokenizer='gemma(tokensets=("loc", "seg"))',
#     vocab_size=256_000 + 1024 + 128,
# )


def get_cached_model(
    model_name: str,
):# -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
  """Returns model and params, using RAM cache.

  NOTE(review): this function references `config` and `paligemma_bv`, neither
  of which is defined or imported in the visible part of this module (the
  `config` assignment above is commented out). Unless they are provided
  elsewhere, calling this function raises `NameError`. The only call site
  visible here (in `generate`) is commented out, so this looks like leftover
  code from a previous backend -- confirm before relying on it.

  Args:
    model_name: Key into `MODELS_RES_SEQ` and into the mapping returned by
      `gradio_helpers.get_paths()`.

  Returns:
    The `(model, params_cpu)` pair returned by
    `gradio_helpers.get_memory_cache`.

  Raises:
    KeyError: If `model_name` is missing from `MODELS_RES_SEQ` or from the
      paths mapping.
  """
  res, seq = MODELS_RES_SEQ[model_name]
  # NOTE(review): `generate` unpacks this same entry as a
  # (model_path, clip_path) pair, whereas here it is used as a single
  # checkpoint path -- verify which shape `get_paths()` really returns.
  model_path = gradio_helpers.get_paths()[model_name]
  # Specialise the base config for this checkpoint; the resulting config is
  # also what is passed to the cache as its first argument.
  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
  model, params_cpu = gradio_helpers.get_memory_cache(
      config_,
      # Loader callback; presumably only invoked on a cache miss -- TODO
      # confirm against `gradio_helpers.get_memory_cache`.
      lambda: paligemma_bv.load_model(config_),
      max_cache_size_bytes=MAX_RAM_CACHE,
  )
  return model, params_cpu

def pil_image_to_base64(image: PIL.Image.Image) -> str:
  """Encodes a PIL image as a base64 JPEG string.

  JPEG cannot store an alpha channel or a colour palette, so saving an image
  in a mode such as RGBA, LA or P (e.g. an uploaded PNG with transparency)
  as JPEG raises an error. Images that are not already RGB or greyscale are
  therefore converted to RGB before encoding.

  Args:
    image: The image to encode. It is not modified; any mode conversion
      works on a copy.

  Returns:
    The JPEG bytes of `image`, base64-encoded and returned as text (without
    any `data:` URI prefix).
  """
  import base64
  import io

  # RGB and L can be written as JPEG directly; everything else is converted.
  if image.mode not in ('RGB', 'L'):
    image = image.convert('RGB')

  with io.BytesIO() as buffer:
    image.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

@spaces.GPU
def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
  """Runs one single-turn image+text chat completion and returns its text.

  A new `llama_cpp.Llama` instance is constructed on every call.

  Args:
    model_name: Key into `gradio_helpers.get_paths()`; the entry is unpacked
      as (language-model path, CLIP/projector path).
    sampler: Not used by this implementation; kept in the signature.
    image: Input image, sent to the model as a base64 JPEG data URI.
    prompt: Text part of the user message.

  Returns:
    The `content` field of the first choice's message.
  """
  model_path, clip_path = gradio_helpers.get_paths()[model_name]
  # Debug output: the selected model file and the full path table.
  print(model_path)
  print(gradio_helpers.get_paths())

  vision_handler = llama_cpp.llama_chat_format.PaliGemmaChatHandler(clip_path)
  llm = llama_cpp.Llama(
      model_path,
      chat_handler=vision_handler,
      n_ctx=1024,
      n_ubatch=512,
      n_batch=512,
      n_gpu_layers=-1,
  )

  # The image is encoded only after the model has been constructed.
  text_part = {"type": "text", "text": prompt}
  image_part = {
      "type": "image_url",
      "image_url": "data:image/jpeg;base64," + pil_image_to_base64(image),
  }
  user_turn = {"role": "user", "content": [text_part, image_part]}

  completion = llm.create_chat_completion(messages=[user_turn])
  first_choice = completion["choices"][0]
  return first_choice["message"]["content"]