from typing import Any, Dict
from schema import Schema
from data import Scenario, MergedDataset
from methods.base.alg import BaseAlg
from data import build_dataloader
from ..model import ElasticDNN_OfflineFMModel
from ...model.base import ElasticDNNUtil
import torch.optim
import tqdm
from torch import nn
from torchvision.transforms import Compose
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import get_module
from utils.common.log import logger
class ElasticDNN_FMLoRAAlg(BaseAlg):
    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel
        })

    def get_required_hyp_schema(self) -> Schema:
        from schema import Optional
        return Schema({
            'launch_tbboard': bool,
            'samples_size': object,
            'ab_r': int,
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,
            Optional('fm_lora_ckpt_path'): str,
            Optional('transform'): Compose,
        })
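
    # Illustrative only: a hypothetical `hyps` dict that satisfies the schema above.
    # The concrete values are placeholders, not defaults shipped with this repo:
    #
    # hyps = {
    #     'launch_tbboard': False,
    #     'samples_size': (1, 3, 224, 224),                      # passed to torch.rand() below to build a probe input
    #     'ab_r': 8,                                              # LoRA reduction factor forwarded to add_lora_ab_to_fm()
    #     'train_batch_size': 64,
    #     'val_batch_size': 128,
    #     'num_workers': 8,
    #     'optimizer': 'AdamW',                                   # any name in torch.optim.__dict__
    #     'optimizer_args': {'lr': 1e-4, 'weight_decay': 0.01},
    #     'scheduler': 'StepLR',                                  # any name in torch.optim.lr_scheduler.__dict__
    #     'scheduler_args': {'step_size': 5000, 'gamma': 0.1},
    #     'num_iters': 10000,
    #     'val_freq': 500,
    # }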
    def run(self, scenario: Scenario, hyps: Dict, collate_fn=None) -> Dict[str, Any]:
        super().run(scenario, hyps)

        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)  # for auto completion

        # 1. add LoRA
        lora_util = self.models['fm'].get_lora_util()
        device = self.models['fm'].device

        sample = hyps['samples_size']
        if isinstance(sample, (tuple, list)) and isinstance(sample[0], int):
            # samples_size is a plain shape: build a random probe input of that shape
            sample = torch.rand(hyps['samples_size']).to(device)
        lora_util.add_lora_ab_to_fm(self.models['fm'].models_dict['main'], hyps['ab_r'], sample)

        # optionally warm-start the LoRA weights from a previously trained checkpoint
        if hyps.get('fm_lora_ckpt_path'):
            _ckpt = torch.load(hyps['fm_lora_ckpt_path'])['main']
            new_state_dict = deepcopy(self.models['fm'].models_dict['main'].state_dict())

            for n, p in _ckpt.named_parameters():
                if 'qkv.abs' not in n:
                    continue
                new_state_dict[n] = p
                logger.info(f'use {n} from ckpt')

            self.models['fm'].models_dict['main'].load_state_dict(new_state_dict)
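
        # A rough sketch of what happens above (the exact module layout is defined by the
        # model-specific lora_util, so treat this as an assumption): LoRA augments a frozen
        # weight W with a low-rank update, y = (W + B @ A) x, where A and B are small trainable
        # matrices; in this codebase they appear to live under `qkv.abs` in each attention
        # block, which is why only those entries are copied from the checkpoint.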
        # 2. train the LoRA parameters and the task head on the merged offline datasets
        if 'transform' in hyps.keys():
            offline_datasets = scenario.get_offline_datasets(transform=hyps['transform'])
        else:
            offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])

        # debug
        # from data.visualize import visualize_classes_in_object_detection
        # d = offline_datasets['GTA5Det']['val']
        # class_to_idx_map = {c: d.idx_map[i] for i, c in enumerate(d.classes)}
        # print(class_to_idx_map)
        # visualize_classes_in_object_detection(d, class_to_idx_map,
        #                                       {}, os.path.join(self.res_save_dir, 'debug.png'))
        # exit()

        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None, collate_fn=collate_fn))
        # if hyps['use_train_loader_for_val']:
        #     val_loader = build_dataloader(train_dataset, hyps['val_batch_size'], hyps['num_workers'],
        #                                   False, False)
        #     logger.warn('use train loader for val!!!')
        # else:
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False, collate_fn=collate_fn)

        # only the LoRA matrices and the task head are trainable; the rest of the FM stays frozen
        lora_params = lora_util.train_only_lora(self.models['fm'].models_dict['main'])
        head_params = self.models['fm'].get_task_head_params()

        num_lora_params = sum([np.prod(p.size()) for p in lora_params])
        total_params = sum([np.prod(p.size()) for p in self.models['fm'].models_dict['main'].parameters()])
        logger.info(f'num lora params: {num_lora_params}, total params: {total_params}, '
                    f'ratio: {num_lora_params / total_params}')

        optimizer = torch.optim.__dict__[hyps['optimizer']](lora_params + head_params, **hyps['optimizer_args'])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])

        fbs_tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
        best_val_acc = 0
        val_acc = 0

        for iter_index in pbar:
            self.models['fm'].to_train_mode()

            x, y = next(train_loader)
            if isinstance(x, dict):
                # dict-style batches (e.g. tokenized inputs): move each tensor field to the device
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(device)
                y = y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            task_loss = self.models['fm'].forward_to_get_task_loss(x, y)
            optimizer.zero_grad()
            task_loss.backward()
            optimizer.step()
            scheduler.step()

            if (iter_index + 1) % hyps['val_freq'] == 0:
                # logger.warn('use train loader for val!!!')
                self.models['fm'].to_eval_mode()
                val_acc = self.models['fm'].get_accuracy(val_loader)

                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

            fbs_tb_writer.add_scalar('losses/task_loss', task_loss, iter_index)
            fbs_tb_writer.add_scalar('accs/val_acc', val_acc, iter_index)
            fbs_tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], iter_index)
            pbar.set_description(f'loss: {task_loss:.6f}, val_acc: {val_acc:.4f}')
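
# Minimal usage sketch (illustrative only; the result directory, the way `fm_model` and
# `scenario` are built, and the exact BaseAlg constructor signature are assumptions here,
# not guarantees from this file):
#
# fm_model = SomeOfflineFMModelSubclass(...)        # an ElasticDNN_OfflineFMModel subclass
# scenario = build_scenario(...)                    # a prepared Scenario
# alg = ElasticDNN_FMLoRAAlg({'fm': fm_model}, 'results/fm_lora_example')
# alg.run(scenario, hyps)                           # hyps as in the example dict above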