import os
import torch

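# the star imports below supply helpers used in this file
# (e.g. log, get_state_maybe_zero_3, PreTrainedModel)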
from ..utils import *
from ..model import *

class BaseTrainingRecipe:
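    """Base training recipe: freezes or unfreezes the language model, vision tower
    and connector according to the tune-type arguments, and handles saving and
    loading of the three sub-modules. Subclasses (e.g. LoRA recipes) override
    hooks such as `training_model_converse`."""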

    def __init__(self, training_arguments):
        self.training_arguments = training_arguments

    
    def __call__(self, model):
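        """Converts the model for training and applies the tune-type settings,
        recording the chosen tune types on the model config."""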
        model = self.training_model_converse(model)
        model = self.tune_type_setting(model)
        model.config.tune_type_connector = self.training_arguments.tune_type_connector
        model.config.tune_type_vision_tower = self.training_arguments.tune_type_vision_tower
        model.config.tune_type_llm = self.training_arguments.tune_type_llm
        model.config.tune_vision_tower_from_layer = self.training_arguments.tune_vision_tower_from_layer
        return model
    
 
    def add_args(self, model_args):
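        """Injects the training dtype into the LLM kwargs and, when resuming from
        a pretrained checkpoint, points each sub-module at its saved weights."""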
        if self.training_arguments.fp16:
            llm_dtype = torch.float16
        elif self.training_arguments.bf16:
            llm_dtype = torch.bfloat16
        else:
            llm_dtype = torch.float32
        model_args['llm'].update(dict(torch_dtype=llm_dtype))
        if self.training_arguments.pretrained_model_path is not None:
            model_args['llm'].update(dict(pretrained_llm_path=os.path.join(self.training_arguments.pretrained_model_path, 'language_model')))
            model_args['vision_tower'].update(dict(pretrained_vision_tower_path=os.path.join(self.training_arguments.pretrained_model_path, 'vision_tower')))
            model_args['connector'].update(dict(pretrained_connector_path=os.path.join(self.training_arguments.pretrained_model_path, 'connector')))
        return model_args
            
    def tune_type_setting(self, model):
        model = self._llm_tune_type_setting(model)
        model = self._vision_tower_tune_type_setting(model)
        model = self._connector_tune_type_setting(model)
        return model

    def _llm_tune_type_setting(self, model):
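        """Freezes or fully unfreezes the language model and enables gradient
        checkpointing support when requested."""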
        tune_type = self.training_arguments.tune_type_llm.lower()
        assert tune_type in ('frozen', 'full', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!'
        if tune_type == 'full':
            model.language_model.requires_grad_(True)
        elif tune_type == 'frozen':
            model.language_model.requires_grad_(False)
        self.support_gradient_checkpoint(model.language_model, self.training_arguments.gradient_checkpointing)
        return model
        
    def _vision_tower_tune_type_setting(self, model):
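        """Freezes or unfreezes the vision tower; 'partially-tune' unfreezes only
        the encoder layers from `tune_vision_tower_from_layer` onwards."""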
        tune_type = self.training_arguments.tune_type_vision_tower.lower()
        assert tune_type in ('frozen', 'full', 'partially-tune', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!'
        if tune_type == 'full':
            model.vision_tower.requires_grad_(True)
        elif tune_type == 'frozen':
            model.vision_tower.requires_grad_(False)         
        elif tune_type == 'partially-tune':
            # TODO: handle gradient checkpointing for the partially tuned vision tower
            from_layer = self.training_arguments.tune_vision_tower_from_layer
            if from_layer > -1:
                log(f'Tune the vision tower from layer {from_layer}!')
                for n, p in model.vision_tower.named_parameters():
                    # TODO: other visual encoders may not follow the
                    # 'vision_model.encoder.layers.<idx>.' naming scheme
                    if 'vision_model.encoder.layers.' in n:
                        layer_id = int(n.split('vision_model.encoder.layers.')[-1].split('.')[0])
                        p.requires_grad = layer_id >= from_layer
                    else:
                        p.requires_grad = False
        #self.support_gradient_checkpoint(model.vision_tower._vision_tower, self.training_arguments.gradient_checkpointing)
        return model
        
    def _connector_tune_type_setting(self, model):
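        """Freezes or fully unfreezes the connector."""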
        tune_type = self.training_arguments.tune_type_connector.lower()
        assert tune_type in ('frozen', 'full', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!'   
        if tune_type == 'full':
            model.connector.requires_grad_(True)
        elif tune_type == 'frozen':
            model.connector.requires_grad_(False)
        return model

    def training_model_converse(self, model):
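        """Hook for subclasses to convert the model before training (e.g. wrap it
        with LoRA adapters); the base recipe returns the model unchanged."""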
        return model
        
    
    def save(self, model, trainer):
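        """Saves the tokenizer, config and trainer state. In the finetune stage the
        full model is saved through the trainer; in the pretrain stage the language
        model, vision tower and connector weights are written out separately."""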
        model.config.use_cache = True
        # save tokenizer
        model.tokenizer.save_pretrained(self.training_arguments.output_dir)
        # save the full model config
        model.config.save_pretrained(self.training_arguments.output_dir, from_pt=True)
        # save trainer state
        trainer.save_state()

        if 'finetune' in self.training_arguments.output_dir and self.training_arguments.pretrained_model_path is not None: # for finetune stage
            if trainer.deepspeed:
                torch.cuda.synchronize()
            trainer.save_model(self.training_arguments.output_dir)
            return
        
        # the following blocks are for the pretrain stage: save each sub-module separately
        #save language model
        language_model_state_dict = get_state_maybe_zero_3(model.language_model.named_parameters(), [''], False)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            language_model_output_dir = os.path.join(self.training_arguments.output_dir, 'language_model')
            os.makedirs(language_model_output_dir, exist_ok=True)
            language_model_output_path = os.path.join(language_model_output_dir, 'pytorch_model.bin')
            torch.save(language_model_state_dict, language_model_output_path)
            model.config.text_config.save_pretrained(language_model_output_dir, from_pt=True)
        #save vision tower
        vision_tower_state_dict = get_state_maybe_zero_3(model.vision_tower._vision_tower.named_parameters(), [''], False)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            vision_tower_output_dir = os.path.join(self.training_arguments.output_dir, 'vision_tower')
            os.makedirs(vision_tower_output_dir, exist_ok=True)
            vision_tower_output_path = os.path.join(vision_tower_output_dir, 'pytorch_model.bin')
            torch.save(vision_tower_state_dict, vision_tower_output_path)
            if isinstance(model.vision_tower._vision_tower, PreTrainedModel):
                model.vision_tower._vision_tower.config.save_pretrained(vision_tower_output_dir, from_pt=True)
        #save connector
        connector_state_dict = get_state_maybe_zero_3(model.connector.named_parameters(), [''], False)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            connector_output_dir = os.path.join(self.training_arguments.output_dir, 'connector')
            os.makedirs(connector_output_dir, exist_ok=True)
            connector_output_path = os.path.join(connector_output_dir, 'pytorch_model.bin')
            torch.save(connector_state_dict, connector_output_path)
    

    def load(self, model, model_args=None):
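        """Loads the sub-module weights. If the pretrained checkpoint is a LoRA
        adapter, the base LLM is loaded first, then the adapter is applied and
        merged via PEFT."""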
        model_args = model_args or {}
        pretrained_model_path = self.training_arguments.pretrained_model_path
        if pretrained_model_path is None or not ('lora' in pretrained_model_path and os.path.exists(os.path.join(pretrained_model_path, 'adapter_config.json'))):  # non-lora/non-qlora loading
            model.load_llm(**model_args['llm'])
            model.load_vision_tower(**model_args['vision_tower'])
            model.load_connector(**model_args['connector'])
        else:
            model.language_model = model.language_model.from_pretrained(
                model_args['llm']['model_name_or_path'],
                attn_implementation='flash_attention_2',
                torch_dtype=model_args['llm']['torch_dtype'],
            )
            model.load_vision_tower(**model_args['vision_tower'])
            model.load_connector(**model_args['connector'])
            model.to(model_args['llm']['torch_dtype'])
            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, self.training_arguments.pretrained_model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model loaded.')

        return model
        
    
    def support_gradient_checkpoint(self, model, gradient_checkpointing=False):
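        """Makes gradient checkpointing work when the model's inputs come from a
        frozen embedding layer: the forward hook forces the embedding output to
        require grad so the checkpointed blocks have a differentiable input."""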
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        if gradient_checkpointing:
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:
                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
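

# A minimal usage sketch (illustrative only: assumes a populated training-arguments
# object with the fields referenced above, a model exposing load_llm /
# load_vision_tower / load_connector, and a Hugging Face Trainer):
#
#   recipe = BaseTrainingRecipe(training_arguments)
#   model_args = recipe.add_args(model_args)  # inject dtype / pretrained paths
#   model = recipe.load(model, model_args)    # load sub-module weights
#   model = recipe(model)                     # convert + apply tune-type settings
#   trainer.train()
#   recipe.save(model, trainer)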