File size: 2,220 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig
from starvector.data.base import ImageTrainProcessor
from starvector.util import dtype_mapping
from transformers import AutoConfig

def load_pretrained_model(model_path, device="cuda", **kwargs):
    """Load a StarVector checkpoint together with the helpers needed to run it.

    Args:
        model_path: Identifier or path forwarded to
            ``StarVectorForCausalLM.from_pretrained``.
        device: Target handed to ``.to()`` once the model is loaded
            (defaults to ``"cuda"``).
        **kwargs: Extra keyword arguments passed through unchanged to
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, image_processor, context_len)`` tuple, where
        ``tokenizer`` comes from the model's ``svg_transformer`` and
        ``context_len`` is the inner model's ``query_length + max_length``.
    """
    loaded = StarVectorForCausalLM.from_pretrained(model_path, **kwargs)
    model = loaded.to(device)

    # The wrapped inner model carries the tokenizer and the length settings.
    inner = model.model
    tokenizer = inner.svg_transformer.tokenizer
    image_processor = ImageTrainProcessor()
    context_len = inner.query_length + inner.max_length

    return tokenizer, model, image_processor, context_len

def model_builder(config):
    """Build a ``StarVectorForCausalLM`` from a configuration object.

    When ``config.model`` provides a truthy ``model_name``, the model is
    loaded with ``from_pretrained`` using that name. Otherwise a new
    ``StarVectorConfig`` is assembled, taking the transformer dimensions from
    the config loaded for ``config.model.starcoder_model_name``, and the model
    is instantiated directly from it.

    Args:
        config: Configuration exposing ``model`` and ``training`` sections.
            ``config.model`` must support ``.get()`` in addition to attribute
            access.

    Returns:
        The constructed ``StarVectorForCausalLM`` instance.
    """
    model_cfg = config.model
    pretrained_name = model_cfg.get("model_name", False)

    # Keyword arguments shared by both construction paths.
    shared_kwargs = dict(
        task=model_cfg.task,
        train_image_encoder=config.training.train_image_encoder,
        ignore_mismatched_sizes=True,
        starcoder_model_name=model_cfg.starcoder_model_name,
        train_LLM=config.training.train_LLM,
        torch_dtype=dtype_mapping[config.training.model_precision],
        transformer_layer_cls=model_cfg.get("transformer_layer_cls", False),
        use_cache=model_cfg.use_cache,
    )

    if pretrained_name:
        return StarVectorForCausalLM.from_pretrained(pretrained_name, **shared_kwargs)

    # No checkpoint name given: size the model after the StarCoder backbone.
    backbone_cfg = AutoConfig.from_pretrained(model_cfg.starcoder_model_name)
    sv_config = StarVectorConfig(
        max_length_train=model_cfg.max_length,
        image_encoder_type=model_cfg.image_encoder_type,
        use_flash_attn=model_cfg.use_flash_attn,
        adapter_norm=model_cfg.adapter_norm,
        starcoder_model_name=model_cfg.starcoder_model_name,
        torch_dtype=dtype_mapping[config.training.model_precision],
        num_attention_heads=backbone_cfg.num_attention_heads,
        num_hidden_layers=backbone_cfg.num_hidden_layers,
        vocab_size=backbone_cfg.vocab_size,
        hidden_size=backbone_cfg.hidden_size,
        # Not every backbone config defines this attribute; fall back to None.
        num_kv_heads=getattr(backbone_cfg, "num_key_value_heads", None),
    )
    return StarVectorForCausalLM(sv_config, **shared_kwargs)