Spaces:
Running
Running
from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig | |
from starvector.data.base import ImageTrainProcessor | |
from starvector.util import dtype_mapping | |
from transformers import AutoConfig | |
def load_pretrained_model(model_path, device="cuda", **kwargs): | |
model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device) | |
tokenizer = model.model.svg_transformer.tokenizer | |
image_processor = ImageTrainProcessor() | |
context_len = model.model.query_length + model.model.max_length | |
return tokenizer, model, image_processor, context_len | |
def model_builder(config): | |
model_name = config.model.get("model_name", False) | |
args = { | |
"task": config.model.task, | |
"train_image_encoder": config.training.train_image_encoder, | |
"ignore_mismatched_sizes": True, | |
"starcoder_model_name": config.model.starcoder_model_name, | |
"train_LLM": config.training.train_LLM, | |
"torch_dtype": dtype_mapping[config.training.model_precision], | |
"transformer_layer_cls": config.model.get("transformer_layer_cls", False), | |
"use_cache": config.model.use_cache, | |
} | |
if model_name: | |
model = StarVectorForCausalLM.from_pretrained(model_name, **args) | |
else: | |
starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name) | |
starvector_config = StarVectorConfig( | |
max_length_train=config.model.max_length, | |
image_encoder_type=config.model.image_encoder_type, | |
use_flash_attn=config.model.use_flash_attn, | |
adapter_norm=config.model.adapter_norm, | |
starcoder_model_name=config.model.starcoder_model_name, | |
torch_dtype=dtype_mapping[config.training.model_precision], | |
num_attention_heads=starcoder_model_config.num_attention_heads, | |
num_hidden_layers=starcoder_model_config.num_hidden_layers, | |
vocab_size=starcoder_model_config.vocab_size, | |
hidden_size=starcoder_model_config.hidden_size, | |
num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None), | |
) | |
model = StarVectorForCausalLM(starvector_config, **args) | |
return model | |