Spaces:
Running
Running
File size: 2,220 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig
from starvector.data.base import ImageTrainProcessor
from starvector.util import dtype_mapping
from transformers import AutoConfig
def load_pretrained_model(model_path, device="cuda", **kwargs):
    """Load a pretrained StarVector checkpoint plus the objects needed to run it.

    Args:
        model_path: Checkpoint identifier or local path passed to
            ``StarVectorForCausalLM.from_pretrained``.
        device: Device the loaded model is moved to via ``.to(device)``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``from_pretrained``.

    Returns:
        A 4-tuple ``(tokenizer, model, image_processor, context_len)``:
        the tokenizer owned by the model's SVG transformer, the model itself,
        a freshly constructed ``ImageTrainProcessor``, and the context length
        computed as ``query_length + max_length`` of the inner model.
    """
    model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs)
    model = model.to(device)

    inner = model.model
    tokenizer = inner.svg_transformer.tokenizer
    image_processor = ImageTrainProcessor()
    # Total budget = image query tokens followed by the SVG text tokens.
    context_len = inner.query_length + inner.max_length

    return tokenizer, model, image_processor, context_len
def model_builder(config):
    """Construct a StarVector model from a training configuration.

    If ``config.model`` provides a truthy ``model_name``, that checkpoint is
    loaded with ``StarVectorForCausalLM.from_pretrained``. Otherwise a new
    ``StarVectorConfig`` is assembled from ``config.model`` settings plus the
    architecture sizes (attention heads, hidden layers, vocab size, hidden
    size, key/value heads) read from the StarCoder config named by
    ``config.model.starcoder_model_name``, and the model is instantiated
    from it.

    In both paths the same keyword arguments (task, training flags,
    ``ignore_mismatched_sizes=True``, dtype, ``transformer_layer_cls``,
    ``use_cache``) are forwarded to the model.

    Args:
        config: Configuration object exposing ``config.model`` (supporting
            attribute access and ``.get``) and ``config.training``.

    Returns:
        The constructed ``StarVectorForCausalLM`` instance.
    """
    model_cfg = config.model
    pretrained_name = model_cfg.get("model_name", False)

    shared_kwargs = {
        "task": model_cfg.task,
        "train_image_encoder": config.training.train_image_encoder,
        "ignore_mismatched_sizes": True,
        "starcoder_model_name": model_cfg.starcoder_model_name,
        "train_LLM": config.training.train_LLM,
        "torch_dtype": dtype_mapping[config.training.model_precision],
        "transformer_layer_cls": model_cfg.get("transformer_layer_cls", False),
        "use_cache": model_cfg.use_cache,
    }

    if pretrained_name:
        return StarVectorForCausalLM.from_pretrained(pretrained_name, **shared_kwargs)

    # No checkpoint given: mirror the base StarCoder architecture sizes.
    base_config = AutoConfig.from_pretrained(model_cfg.starcoder_model_name)
    starvector_config = StarVectorConfig(
        max_length_train=model_cfg.max_length,
        image_encoder_type=model_cfg.image_encoder_type,
        use_flash_attn=model_cfg.use_flash_attn,
        adapter_norm=model_cfg.adapter_norm,
        starcoder_model_name=model_cfg.starcoder_model_name,
        torch_dtype=shared_kwargs["torch_dtype"],
        num_attention_heads=base_config.num_attention_heads,
        num_hidden_layers=base_config.num_hidden_layers,
        vocab_size=base_config.vocab_size,
        hidden_size=base_config.hidden_size,
        # Not every base config defines grouped-query KV heads; fall back to None.
        num_kv_heads=getattr(base_config, "num_key_value_heads", None),
    )
    return StarVectorForCausalLM(starvector_config, **shared_kwargs)
|