|
"""Model-related code and constants.""" |
|
|
|
import dataclasses
import functools
import os
import re
import threading

import PIL.Image
import llama_cpp

import gradio_helpers
|
|
|
|
|
# Owner prefix of the model repositories listed below.
ORGANIZATION = 'abetlen'

# Pairs of (repository name under ORGANIZATION, model name).
BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]

# model name -> (repository id, (text-model GGUF file, mmproj GGUF file)).
MODELS = {
    model_name: (
        f'{ORGANIZATION}/{repo}',
        (
            f'{model_name}-text-model-q4_k_m.gguf',
            f'{model_name}-mmproj-f16.gguf',
        ),
    )
    for repo, model_name in BASE_MODELS
}

# Human-readable description per model name.
# NOTE(review): this text describes the upstream JAX/FLAX checkpoints, while
# MODELS points at q4_k_m / f16 GGUF files; consider updating the wording.
MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
        'bfloat16 and float16 format for research purposes only.'
    ),
}

# model name -> (input image resolution, text sequence length).
MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
}
|
|
|
|
|
|
|
|
|
# Size limit in bytes handed to gradio_helpers.get_memory_cache. It is read
# from the RAM_CACHE_GB environment variable (decimal gigabytes, default '0').
MAX_RAM_CACHE = int(1e9 * float(os.getenv('RAM_CACHE_GB', '0')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cached_model(
    model_name: str,
):
    """Loads `model_name` through the gradio_helpers RAM cache.

    Args:
      model_name: key into MODELS_RES_SEQ and `gradio_helpers.get_paths()`.

    Returns:
      The `(model, params_cpu)` pair produced by `paligemma_bv.load_model`,
      served from `gradio_helpers.get_memory_cache` (bounded by MAX_RAM_CACHE).

    NOTE(review): as far as this file shows, neither `config` nor
    `paligemma_bv` is defined or imported, so calling this raises NameError.
    It also treats `gradio_helpers.get_paths()[model_name]` as a single
    checkpoint path, whereas `generate` unpacks the same entry as
    `(model_path, clip_path)`. This looks like a leftover from a JAX backend;
    confirm, then either restore the missing names or remove the function.
    """
    resolution, text_len = MODELS_RES_SEQ[model_name]
    checkpoint = gradio_helpers.get_paths()[model_name]
    model_config = dataclasses.replace(
        config, ckpt=checkpoint, res=resolution, text_len=text_len
    )

    def _load():
        return paligemma_bv.load_model(model_config)

    cached = gradio_helpers.get_memory_cache(
        model_config,
        _load,
        max_cache_size_bytes=MAX_RAM_CACHE,
    )
    model, params_cpu = cached
    return model, params_cpu
|
|
|
def pil_image_to_base64(image: PIL.Image.Image) -> str:
    """Encodes `image` as a base64 JPEG string.

    JPEG has no alpha channel or palette, and Pillow's JPEG writer rejects
    modes such as RGBA, LA and P. Such images (e.g. transparent PNG uploads)
    are therefore converted to RGB before encoding instead of failing.

    Args:
      image: the image to encode. The caller's object is not modified;
        `convert` returns a new image.

    Returns:
      The base64 encoding of the JPEG bytes as a `str`, without any
      `data:` URI prefix.
    """
    import base64
    import io

    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    buffered = io.BytesIO()
    image.save(buffered, format='JPEG')
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
|
# Guards the shared llama.cpp model. NOTE: assumes a single `Llama` instance
# must not be driven from several threads at once; the lock also stops two
# concurrent first requests from loading the weights twice.
_LLAMA_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_llama(model_name: str) -> 'llama_cpp.Llama':
    """Builds the llama.cpp model and PaliGemma chat handler for `model_name`.

    Cached so the GGUF weights are loaded once rather than on every request.
    `maxsize=1` keeps at most one model cached; requesting a different
    `model_name` evicts the previous entry.
    """
    model_path, clip_path = gradio_helpers.get_paths()[model_name]
    return llama_cpp.Llama(
        model_path,
        chat_handler=llama_cpp.llama_chat_format.PaliGemmaChatHandler(clip_path),
        n_ctx=1024,
        n_ubatch=512,
        n_batch=512,
        n_gpu_layers=-1,  # -1: offload all layers to the GPU if available.
    )


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
    """Generates a text answer for `image` and `prompt` with `model_name`.

    Args:
      model_name: key into `gradio_helpers.get_paths()`; its entry is unpacked
        as `(model_path, clip_path)`.
      sampler: currently unused. No sampling arguments are passed to
        `create_chat_completion`. The parameter is kept for interface
        compatibility.
      image: input image, sent as a base64 JPEG `data:` URI.
      prompt: text sent alongside the image in the same user message.

    Returns:
      The `content` of the first choice of the chat completion.
    """
    messages = [{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {
                'type': 'image_url',
                'image_url': 'data:image/jpeg;base64,' + pil_image_to_base64(image),
            },
        ],
    }]
    with _LLAMA_LOCK:
        model = _load_llama(model_name)
        completion = model.create_chat_completion(messages=messages)
    return completion['choices'][0]['message']['content']
|
|