File size: 2,772 Bytes
dea4744 f78095a dea4744 7731867 dea4744 1ab8bf1 dea4744 12eb8ec dea4744 7731867 dea4744 b206e22 dea4744 f78095a 12eb8ec 3dd5a0f f78095a d0b9ede f78095a 3d2ba8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""Model-related code and constants."""
import base64
import dataclasses
import io
import os
import re

import PIL.Image

# pylint: disable=g-bad-import-order
import gradio_helpers
import llama_cpp
# Owner prefix used to build the `<organization>/<repo>` id of every model.
ORGANIZATION = 'abetlen'

# (repo, model_name) pairs; each pair produces exactly one entry in MODELS.
BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]

# model_name -> (repo id, (text model filename, multimodal projector filename)).
MODELS = {
    name: (
        f'{ORGANIZATION}/{repo_name}',
        (f'{name}-text-model-q4_k_m.gguf', f'{name}-mmproj-f16.gguf'),
    )
    for repo_name, name in BASE_MODELS
}
# Human-readable description per model name.
# NOTE(review): the text below describes JAX/FLAX float checkpoints, while the
# files listed in this module are GGUF — confirm whether it should be updated.
MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images '
        'and 256 token input/output text sequences on a mixture of downstream '
        'academic datasets. The models are available in float32, bfloat16 and '
        'float16 format for research purposes only.'
    ),
}

# model_name -> (image resolution, text sequence length).
MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
}
# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
# The value below should be smaller than "available RAM - one model".
# A single bf16 model is about 5860 MB.
# Read from the RAM_CACHE_GB environment variable (gigabytes, may be
# fractional) and converted to bytes. An unset, empty or whitespace-only
# variable yields 0 instead of failing at import time on `float('')`.
MAX_RAM_CACHE = int(
    float(os.environ.get('RAM_CACHE_GB', '').strip() or '0') * 1e9
)
# config = paligemma_bv.PaligemmaConfig(
# ckpt='', # will be set below
# res=224,
# text_len=64,
# tokenizer='gemma(tokensets=("loc", "seg"))',
# vocab_size=256_000 + 1024 + 128,
# )
def get_cached_model(
    model_name: str,
):  # -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
    """Returns the `(model, params_cpu)` pair for `model_name` via the RAM cache.

    The per-model resolution and sequence length come from `MODELS_RES_SEQ`,
    the checkpoint location from `gradio_helpers.get_paths()`. Loading goes
    through `gradio_helpers.get_memory_cache`, keyed by the derived config and
    bounded by `MAX_RAM_CACHE`.

    NOTE(review): `config` and `paligemma_bv` are neither defined nor imported
    in the visible part of this module (the only `config = ...` assignment is
    commented out), so a call is expected to fail with `NameError` — confirm,
    then either restore those definitions or retire this function.

    Args:
      model_name: Key into `MODELS_RES_SEQ` and `gradio_helpers.get_paths()`.

    Returns:
      A `(model, params_cpu)` tuple as produced by the cached loader.
    """
    resolution, text_len = MODELS_RES_SEQ[model_name]
    checkpoint = gradio_helpers.get_paths()[model_name]
    model_config = dataclasses.replace(
        config, ckpt=checkpoint, res=resolution, text_len=text_len
    )

    def _load():
        # Only invoked by the cache helper when it needs a fresh load.
        return paligemma_bv.load_model(model_config)

    cached = gradio_helpers.get_memory_cache(
        model_config,
        _load,
        max_cache_size_bytes=MAX_RAM_CACHE,
    )
    # Unpack to enforce the two-element shape before handing it back.
    model, params_cpu = cached
    return model, params_cpu
def _image_to_data_uri(image: PIL.Image.Image) -> str:
    """Encodes `image` as a base64 PNG `data:` URI."""
    png_buffer = io.BytesIO()
    # Normalise to RGB first so modes PNG cannot store (e.g. CMYK) still encode.
    image.convert('RGB').save(png_buffer, format='PNG')
    encoded = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    return f'data:image/png;base64,{encoded}'


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
    """Generates output with specified `model_name`, `sampler`.

    Args:
      model_name: Key into `gradio_helpers.get_paths()`; the entry is unpacked
        as `(text_model_path, clip_path)`.
      sampler: Kept for interface compatibility; it is not forwarded to
        llama.cpp by this implementation.
      image: Input image. It is sent to the model as an inline PNG data URI
        inside the user message.
      prompt: Text prompt that accompanies the image.

    Returns:
      The message content of the first completion choice.
    """
    del sampler  # Unused: no sampler selection is passed to llama.cpp here.

    model_path, clip_path = gradio_helpers.get_paths()[model_name]
    print(model_path)

    # NOTE: the model is constructed on every call; nothing is cached here.
    model = llama_cpp.Llama(
        model_path,
        chat_handler=llama_cpp.llama_chat_format.PaligemmaChatHandler(
            clip_path
        ),
        n_ctx=1024,
        n_ubatch=512,
        n_batch=512,
    )

    # `create_chat_completion` takes a list of chat messages, not a bare
    # string. The image has to travel inside the message as an `image_url`
    # part, otherwise the model never sees it.
    # NOTE(review): image-before-text ordering assumes the chat handler renders
    # content parts in order — confirm against PaligemmaChatHandler.
    messages = [
        {
            'role': 'user',
            'content': [
                {
                    'type': 'image_url',
                    'image_url': {'url': _image_to_data_uri(image)},
                },
                {'type': 'text', 'text': prompt},
            ],
        }
    ]
    completion = model.create_chat_completion(messages=messages)
    return completion['choices'][0]['message']['content']
|