"""Training entry point for TinyLLaVA.

Parses model/data/training arguments, builds the model (LLM, vision tower,
and connector), and runs supervised fine-tuning with LLaVATrainer.
"""
from packaging import version
import pathlib

import tokenizers
import transformers

from tinyllava.train.tinyllava_trainer import LLaVATrainer
from tinyllava.training_recipe import TrainingRecipeFactory
from tinyllava.utils import *
from tinyllava.model import *
from tinyllava.data.dataset import make_supervised_data_module

# True when the installed tokenizers library is at least version 0.14.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


def load_settings(model_arguments, data_arguments, training_arguments):
    # Propagate the tuning switches and image settings from the training/data
    # arguments onto the model arguments so model construction sees them all.
    model_arguments.tune_type_connector = training_arguments.tune_type_connector
    model_arguments.tune_type_llm = training_arguments.tune_type_llm
    model_arguments.tune_type_vision_tower = training_arguments.tune_type_vision_tower
    model_arguments.image_aspect_ratio = data_arguments.image_aspect_ratio

    # Collect per-component loading arguments for the LLM, vision tower, and connector.
    model_args = {}
    model_args['llm'] = _load_llm_settings(model_arguments)
    model_args['vision_tower'] = _load_vision_settings(model_arguments)
    model_args['connector'] = _load_connector_settings(model_arguments)
    return model_args
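
# The helpers below return per-component keyword arguments. Roughly (an
# illustration; actual values come from the parsed arguments):
#   {'llm':          {'model_name_or_path': ..., 'cache_dir': ..., 'attn_implementation': ...},
#    'vision_tower': {'model_name_or_path': ...},  # plus 'model_name_or_path2' if vision_tower2 is set
#    'connector':    {'connector_type': ...}}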


def _load_llm_settings(model_arguments):
    # Keyword arguments for loading the language model backbone.
    llm_args = {}
    llm_args['model_name_or_path'] = model_arguments.model_name_or_path
    llm_args['cache_dir'] = model_arguments.cache_dir
    llm_args['attn_implementation'] = model_arguments.attn_implementation
    return llm_args


def _load_vision_settings(model_arguments):
    # Keyword arguments for loading the vision tower(s); anything before an
    # optional ':' prefix is stripped to get the model name or path.
    vision_args = {}
    vision_args['model_name_or_path'] = model_arguments.vision_tower.split(':')[-1]
    if model_arguments.vision_tower2 != '':
        vision_args['model_name_or_path2'] = model_arguments.vision_tower2.split(':')[-1]
    return vision_args


def _load_connector_settings(model_arguments):
    # Keyword arguments for building the vision-language connector.
    connector_args = {}
    connector_args['connector_type'] = model_arguments.connector_type
    return connector_args


def train():
    # Parse command-line arguments into the model/data/training dataclasses.
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_arguments, data_arguments, training_arguments = parser.parse_args_into_dataclasses()

    logger_setting(getattr(training_arguments, 'output_dir', None))

    # The training recipe drives how the model is loaded, tuned, and saved.
    training_recipe = TrainingRecipeFactory(training_arguments.training_recipe)(training_arguments)

    # Assemble per-component loading arguments and instantiate the model skeleton.
    model_args = load_settings(model_arguments, data_arguments, training_arguments)
    model_args = training_recipe.add_args(model_args)
    model_config = TinyLlavaConfig()
    model_config.load_from_config(model_arguments)
    model = TinyLlavaForConditionalGeneration(model_config)

    if training_arguments.pretrained_model_path is not None:
        # Load weights from an existing pretrained model through the recipe.
        model = training_recipe.load(model, model_args)
    else:
        # Otherwise load each component (LLM, vision tower, connector) directly.
        model.load_llm(**model_args['llm'])
        model.load_vision_tower(**model_args['vision_tower'])
        model.load_connector(**model_args['connector'])

    # Apply the recipe to the assembled model (e.g. which components are trainable).
    model = training_recipe(model)
    model.config.use_cache = False
    model.config.image_aspect_ratio = data_arguments.image_aspect_ratio
    tokenizer = model.tokenizer
    data_arguments.image_processor = model.vision_tower._image_processor
    data_arguments.is_multimodal = True

    # Build the supervised dataset and collator, then train.
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_arguments)
    log_trainable_params(model)
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_arguments,
                           **data_module)

    trainer.train()

    training_recipe.save(model, trainer)
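
# Illustrative invocation (a sketch, not the project's documented CLI): with
# transformers' HfArgumentParser, flag names mirror the dataclass fields used
# above, e.g.
#
#   python train.py \
#       --model_name_or_path <llm-path> \
#       --vision_tower <vision-model> \
#       --connector_type <connector> \
#       --training_recipe <recipe> \
#       --output_dir <output-dir> \
#       ...plus whatever DataArguments/TrainingArguments additionally require.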


if __name__ == "__main__":
    train()