import logging

import torch

from wenet.llm_asr.llmasr_model import LLMASR_Model
from wenet.transformer.cmvn import GlobalCMVN
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
from wenet.utils.cmvn import load_cmvn

from gxl_ai_utils.utils import utils_file

def init_llmasr(args, configs, is_inference=False):
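    """Build an LLMASR_Model from ``configs``, optionally load initial weights
    per ``args`` (``checkpoint`` or ``enc_init``), and freeze every parameter
    outside the module selected by ``configs['fire_module']``.

    Returns the model together with ``configs`` updated with 'init_infos'.
    """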
    llm_path = configs["llm_path"]
    lora = configs["use_lora"]
    lora_alpha = configs["lora_alpha"]
    lora_rank = configs["lora_rank"]
    lora_dropout = configs["lora_dropout"]
    # prompt_pattern = configs['prompt_pattern']

    encoder_output_dim = -1
    if configs['encoder'] == 'transformer':
        if configs.get('cmvn', None) == 'global_cmvn':
            mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                                   configs['cmvn_conf']['is_json_cmvn'])
            global_cmvn = GlobalCMVN(
                torch.from_numpy(mean).float(),
                torch.from_numpy(istd).float())
        else:
            global_cmvn = None
        encoder_type = configs.get('encoder', 'conformer')
        input_dim = configs['input_dim']
        from wenet.utils.init_model import WENET_ENCODER_CLASSES
        encoder = WENET_ENCODER_CLASSES[encoder_type](
            input_dim,
            global_cmvn=global_cmvn,
            **configs['encoder_conf'],
            **configs['encoder_conf']['efficient_conf']
            if 'efficient_conf' in configs['encoder_conf'] else {})
        encoder_output_dim = configs['encoder_conf']['output_size']
    elif configs['encoder'] == 'whisper':
        raise NotImplementedError('the whisper encoder is not implemented yet')
    elif configs['encoder'] == 'hubert':
        raise NotImplementedError('the hubert encoder is not implemented yet')
    else:
        encoder = None
    logging.info(f'encoder output dim: {encoder_output_dim}')


    # encoder = encoder.to(torch.float16)
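    # A non-zero speech_token_num implies the model also learns to emit
    # discrete speech tokens, so train_speech_out is switched on below.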
    speech_token_num = configs.get('speech_token_num', 0)
    train_speech_out = speech_token_num != 0

    model = LLMASR_Model(
        encoder=encoder,
        encoder_output_dim=encoder_output_dim,
        llm_path=llm_path,
        lora=lora,
        lora_alpha=lora_alpha,
        lora_rank=lora_rank,
        lora_dropout=lora_dropout,
        is_inference=is_inference,
        downsample_rate=configs.get('downsample_rate', 1),
        adapter_type=configs.get('adapter_type', 'lyz'),
        speech_token_num=speech_token_num,
        train_speech_out=train_speech_out,
    )

    utils_file.print_model_size(model.encoder)
    utils_file.print_model_size(model.llama_model)
    # utils_file.print_model_size(model.speech_transformer)
    # utils_file.print_model_size(model.speech_llama_proj)

    logging.info('Geng Xuelong: init_llmasr(): start loading the initial checkpoint')
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        logging.info('Geng Xuelong: an initial checkpoint was configured, '
                     f'loading weights from: {args.checkpoint}')
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}

    if configs.get('init_step', False):
        infos = {}
    configs["init_infos"] = infos
    print(configs)
    logging.info('Geng Xuelong: finished loading the initial checkpoint')

    if not is_inference:
        logging.info('Geng Xuelong: not replacing the LLM parameters')
        # logging.info('Geng Xuelong: start replacing the LLM parameters')
        # checkpoint4llm_wrapper = "/home/work_nfs8/xlgeng/new_workspace/wenet_gxl_salmonn4ft_LLM/examples/aishell/ft_LLM/exp/ft_2B_v1/1_epoch/step_34272.pt"
        # load_checkpoint(model, checkpoint4llm_wrapper)
        # logging.info('Geng Xuelong: finished replacing the LLM parameters')
    else:
        logging.info('Geng Xuelong: not replacing the LLM parameters')

    logging.info('Geng Xuelong: start selectively freezing modules')
    fire_module = configs.get("fire_module", None)
    if fire_module is None:
        logging.info('Geng Xuelong: no module selected to unfreeze, so there '
                     'would be no trainable parameters; raising an error')
        raise ValueError('No module selected to unfreeze (fire_module is None), '
                         'so the model would have no trainable parameters')
    for k, p in model.named_parameters():
        # if k.startswith("llama_model") or k.startswith("speech_encoder"):
        # if k.startswith("llama_model") or k.startswith("speech_transformer"):
        if fire_module == 'link':
            # 'link' covers the downsampling block, the transformer block,
            # and the linear layers before and after it
            if k.startswith("llama_model") or k.startswith("encoder"):
                p.requires_grad = False
        elif fire_module == 'encoder':
            if not k.startswith("encoder"):
                p.requires_grad = False
        elif fire_module == 'llm':
            if not k.startswith("llama_model"):
                p.requires_grad = False
        elif fire_module == 'link_and_encoder':
            # layers related to speech tokens are deliberately left unfrozen here
            if k.startswith("llama_model"):
                p.requires_grad = False
        elif fire_module == "link_and_encoder_and_lora":
            # train everything: leave all parameters trainable and skip the
            # per-parameter logging below
            break
        logging.info(f"{k} {p.requires_grad}")
    logging.info('Geng Xuelong: freezing finished')

    return model, configs
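
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original recipe): a minimal way to
# drive init_llmasr() from a script. Every path and value below is a
# placeholder; a real run needs an actual LLM directory in 'llm_path' and a
# complete encoder_conf for the chosen wenet encoder.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(checkpoint=None, enc_init=None)  # no pretrained weights
    configs = {
        'llm_path': '/path/to/llm',  # placeholder: directory of the base LLM
        'use_lora': True,
        'lora_alpha': 32,
        'lora_rank': 8,
        'lora_dropout': 0.1,
        'encoder': 'transformer',
        'input_dim': 80,  # e.g. 80-dim fbank features
        'encoder_conf': {'output_size': 256},  # plus the remaining wenet encoder args
        'fire_module': 'link',  # freeze LLM and encoder; train only the adapter
    }
    model, configs = init_llmasr(args, configs, is_inference=False)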