"""Model-related code and constants."""

import base64
import dataclasses
import functools
import io
import logging
import os
import re
import threading

import PIL.Image

# pylint: disable=g-bad-import-order
import gradio_helpers
import llama_cpp

logger = logging.getLogger(__name__)

ORGANIZATION = 'abetlen'
# (repo suffix, model name) pairs; the repo lives under ORGANIZATION.
BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]
# model name -> (repo id, (text-model GGUF filename, mmproj GGUF filename)).
# NOTE(review): presumably consumed by gradio_helpers to download the files;
# `generate` unpacks gradio_helpers.get_paths()[name] as (model, projector).
MODELS = {
    **{
        model_name: (
            f'{ORGANIZATION}/{repo}',
            (f'{model_name}-text-model-q4_k_m.gguf', f'{model_name}-mmproj-f16.gguf'),
        )
        for repo, model_name in BASE_MODELS
    },
}

# NOTE(review): this description refers to the upstream JAX/FLAX float
# weights, while MODELS points at a q4_k_m-quantized GGUF; confirm wording.
MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
        'bfloat16 and float16 format for research purposes only.'
    ),
}

# model name -> (input image resolution, text sequence length).
MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
}

# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
# Below value should be smaller than "available RAM - one model".
# A single bf16 is about 5860 MB.
# Read from the RAM_CACHE_GB environment variable (in GB, default 0), in bytes.
MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)

# How many loaded llama.cpp models to keep in memory at once.
_MAX_LOADED_MODELS = 1

# Serializes model loading and inference: the cached Llama instance is shared
# between callers, so concurrent requests must not drive it at the same time.
_MODEL_LOCK = threading.Lock()


def get_cached_model(
    model_name: str,
):
    """Legacy JAX/big_vision model loader, not available in this build.

    The previous implementation relied on `paligemma_bv` and a module-level
    `config` (`paligemma_bv.PaligemmaConfig`), neither of which exists here,
    so every call failed with a NameError. Inference now goes through
    llama.cpp; use `generate()` instead, which loads and caches the model.

    Args:
      model_name: Name of the model (a key of `MODELS`).

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        f'get_cached_model({model_name!r}): the JAX/big_vision backend was '
        'replaced by llama.cpp; call generate() instead.'
    )


@functools.lru_cache(maxsize=_MAX_LOADED_MODELS)
def _load_model(model_name: str) -> llama_cpp.Llama:
    """Loads the llama.cpp text model plus PaliGemma projector, memoized."""
    model_path, clip_path = gradio_helpers.get_paths()[model_name]
    logger.debug(
        'Loading %s from %s (projector %s)', model_name, model_path, clip_path
    )
    return llama_cpp.Llama(
        model_path,
        chat_handler=llama_cpp.llama_chat_format.PaliGemmaChatHandler(
            clip_path
        ),
        n_ctx=1024,
        n_ubatch=512,
        n_batch=512,
    )


def _image_to_data_uri(image: PIL.Image.Image) -> str:
    """Encodes `image` as a base64 PNG `data:` URI for the chat handler."""
    buffer = io.BytesIO()
    # Convert to RGB so any mode (CMYK, P, RGBA, ...) can be written as PNG.
    image.convert('RGB').save(buffer, format='PNG')
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    return f'data:image/png;base64,{encoded}'


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
    """Generates output with specified `model_name` for `image` and `prompt`.

    The llama.cpp model for `model_name` is loaded on first use and reused
    by later calls.

    Args:
      model_name: Key into `gradio_helpers.get_paths()`, whose value is
        unpacked as (text-model path, projector path).
      sampler: Accepted for interface compatibility; currently ignored
        (llama.cpp's default sampling settings are used).
      image: Input image. If None, only the text prompt is sent.
      prompt: Text prompt for the model.

    Returns:
      The content of the first chat-completion choice.
    """
    del sampler  # Not wired into the llama.cpp backend.

    if image is None:
        content = prompt
    else:
        # Multimodal message: image part first, then the text prompt.
        content = [
            {'type': 'image_url', 'image_url': {'url': _image_to_data_uri(image)}},
            {'type': 'text', 'text': prompt},
        ]
    logger.debug('Prompt for %s: %r', model_name, prompt)

    with _MODEL_LOCK:
        model = _load_model(model_name)
        response = model.create_chat_completion(
            messages=[{'role': 'user', 'content': content}]
        )
    return response['choices'][0]['message']['content']