from transformers import PretrainedConfig
from transformers import CONFIG_MAPPING
from transformers import AutoConfig

from tinyllava.utils.constants import *


class TinyLlavaConfig(PretrainedConfig):
    model_type = "tinyllava"

    def __init__(
        self,
        llm_model_name_or_path='',
        tokenizer_name_or_path=None,
        vision_model_name_or_path='',
        vision_model_name_or_path2='',
        connector_type=None,
        text_config=None,
        hidden_size=2048,
        vocab_size=32000,
        ignore_index=-100,
        image_token_index=32000,
        pad_token=None,
        pad_token_id=None,
        tokenizer_padding_side='right',
        tokenizer_model_max_length=2048,
        vision_config=None,
        vision_hidden_size=None,
        vision_feature_layer=-2,
        vision_feature_select_strategy='patch',
        image_aspect_ratio='square',
        resampler_hidden_size=None,
        num_queries=None,
        num_resampler_layers=None,
        use_cache=False,
        cache_dir=None,
        tokenizer_use_fast=False,
        tune_type_llm='frozen',
        tune_type_connector='frozen',
        tune_type_vision_tower='frozen',
        tune_vision_tower_from_layer=-1,
        **kwargs
    ):
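        # Aggregate multimodal config: records paths and settings for the
        # LLM, vision tower(s), connector, and tokenizer, then derives the
        # nested text/vision sub-configs via the _load_*_config helpers.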
        self.llm_model_name_or_path = llm_model_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path
        self.vision_model_name_or_path = vision_model_name_or_path
        self.vision_model_name_or_path2 = vision_model_name_or_path2
        self.connector_type = connector_type
        self.tune_type_llm = tune_type_llm
        self.tune_type_connector = tune_type_connector
        self.tune_type_vision_tower = tune_type_vision_tower
        self.tune_vision_tower_from_layer = tune_vision_tower_from_layer
        # Note: the shared constants from tinyllava.utils.constants take
        # precedence here, so the `ignore_index` and `image_token_index`
        # constructor arguments above are effectively unused.
        self.ignore_index = IGNORE_INDEX
        self.image_token_index = IMAGE_TOKEN_INDEX
        self.pad_token = pad_token
        self.pad_token_id = pad_token_id
        self.tokenizer_padding_side = tokenizer_padding_side
        self.tokenizer_model_max_length = tokenizer_model_max_length
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.image_aspect_ratio = image_aspect_ratio
        self.resampler_hidden_size = resampler_hidden_size
        self.num_queries = num_queries
        self.num_resampler_layers = num_resampler_layers
        self.use_cache = use_cache
        self.cache_dir = cache_dir
        self.tokenizer_use_fast = tokenizer_use_fast
        self._load_text_config(text_config)
        self._load_vision_config(vision_config)
        super().__init__(**kwargs)
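
    # Populate this config from a namespace of parsed arguments (e.g. the
    # training arguments); the getattr defaults mirror the constructor's.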
    def load_from_config(self, config):
        self.llm_model_name_or_path = getattr(config, 'model_name_or_path', '')
        self.tokenizer_name_or_path = getattr(config, 'tokenizer_name_or_path', None) or self.llm_model_name_or_path
        self.vision_model_name_or_path = getattr(config, 'vision_tower', '')
        self.vision_model_name_or_path2 = getattr(config, 'vision_tower2', '')
        self.connector_type = getattr(config, 'connector_type', None)
        self.vision_feature_layer = getattr(config, 'mm_vision_select_layer', -2)
        self.vision_feature_select_strategy = getattr(config, 'mm_vision_select_feature', 'patch')
        self.image_aspect_ratio = getattr(config, 'image_aspect_ratio', 'pad')
        self.resampler_hidden_size = getattr(config, 'resampler_hidden_size', None)
        self.num_queries = getattr(config, 'num_queries', None)
        self.num_resampler_layers = getattr(config, 'num_resampler_layers', None)
        self.cache_dir = getattr(config, 'cache_dir', None)
        self.tokenizer_use_fast = getattr(config, 'tokenizer_use_fast', False)
        self.tokenizer_model_max_length = getattr(config, 'model_max_length', 2048)
        self.tokenizer_padding_side = getattr(config, 'tokenizer_padding_side', 'right')
        self._load_text_config()
        self._load_vision_config()
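
    # Resolve the nested text config: fall back to a default llama config
    # when no LLM path is set, otherwise load it with AutoConfig.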
    def _load_text_config(self, text_config=None):
        if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
            self.text_config = CONFIG_MAPPING['llama']()
        else:
            self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
            if text_config is not None:
                self.text_config = self.text_config.from_dict(text_config)
        self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None))
        self.vocab_size = getattr(self.text_config, 'vocab_size', None)
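
    # Resolve the nested vision config: fall back to CLIP ViT-L/14-336
    # defaults when no vision tower is named, otherwise load it with AutoConfig.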
    def _load_vision_config(self, vision_config=None):
        if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '':
            self.vision_config = CONFIG_MAPPING['clip_vision_model'](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        else:
            # Tower names may carry a prefix separated by ':'; only the part
            # after the last colon is treated as the pretrained model path.
            self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1], trust_remote_code=True)
            self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config)
            if vision_config is not None:
                self.vision_config = self.vision_config.from_dict(vision_config)
        self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
        self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
        self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
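

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream class).
# Assumes transformers is installed and tinyllava.utils.constants is
# importable; the SimpleNamespace below merely stands in for the
# parsed-arguments object that load_from_config expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # With no model paths given, _load_text_config falls back to a default
    # llama config and _load_vision_config to CLIP ViT-L/14-336 defaults.
    cfg = TinyLlavaConfig()
    print(cfg.hidden_size, cfg.vocab_size, cfg.vision_hidden_size)

    from types import SimpleNamespace

    # Attribute names mirror the getattr calls in load_from_config.
    args = SimpleNamespace(model_name_or_path='', vision_tower='',
                           image_aspect_ratio='pad')
    cfg.load_from_config(args)
    print(cfg.image_aspect_ratio)  # -> 'pad'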