LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
6.7 kB
from typing import Any, Dict
from schema import Schema
from data import Scenario, MergedDataset
from methods.base.alg import BaseAlg
from data import build_dataloader
from ..model import ElasticDNN_OfflineFMModel
from ...model.base import ElasticDNNUtil
import torch.optim
import tqdm
from torch import nn
from torchvision.transforms import Compose
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import get_module
from utils.common.log import logger
class ElasticDNN_FMLoRAAlg(BaseAlg):
    """Offline fine-tuning of a foundation model (FM) with LoRA A/B adapters.

    ``run`` does the following:

    - Injects LoRA A/B parameters into ``self.models['fm'].models_dict['main']``
      through the model's LoRA util.
    - Optionally restores ``qkv.abs`` LoRA weights from a checkpoint.
    - Trains only the LoRA parameters plus the task-head parameters on the
      merged offline training datasets of the scenario.
    - Periodically evaluates on the merged validation datasets and saves
      ``models/fm_last.pt`` / ``models/fm_best.pt`` under ``self.res_save_dir``.
    """

    def get_required_models_schema(self) -> Schema:
        """Require exactly one model, ``'fm'``, of type ``ElasticDNN_OfflineFMModel``."""
        return Schema({
            'fm': ElasticDNN_OfflineFMModel
        })

    def get_required_hyp_schema(self) -> Schema:
        """Return the schema of the hyperparameter dict consumed by :meth:`run`.

        Keys, as used by ``run``:

        - ``launch_tbboard``: forwarded to ``create_tbwriter``.
        - ``samples_size``: either a shape (list/tuple of ints), from which a
          ``torch.rand`` sample is built, or any other object, which is passed
          through unchanged as the sample given to ``add_lora_ab_to_fm``.
        - ``ab_r``: forwarded to ``add_lora_ab_to_fm`` (presumably the LoRA
          rank — TODO confirm).
        - ``train_batch_size`` / ``val_batch_size`` / ``num_workers``: dataloader
          settings.
        - ``optimizer`` / ``optimizer_args``: class name in ``torch.optim`` and
          its kwargs.
        - ``scheduler`` / ``scheduler_args``: class name in
          ``torch.optim.lr_scheduler`` and its kwargs.
        - ``num_iters``: number of training iterations.
        - ``val_freq``: validate/save every ``val_freq`` iterations.
        - ``fm_lora_ckpt_path`` (optional): checkpoint to take ``qkv.abs``
          params from; empty string means none.
        - ``transform`` (optional): forwarded to
          ``scenario.get_offline_datasets``.
        """
        from schema import Optional
        return Schema({
            'launch_tbboard': bool,

            'samples_size': object,
            'ab_r': int,
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,

            Optional('fm_lora_ckpt_path'): str,
            Optional('transform'): Compose,
        })

    def _load_lora_ab_from_ckpt(self, ckpt_path: str, device) -> None:
        """Copy every parameter whose name contains ``'qkv.abs'`` from a checkpoint into the main model.

        The checkpoint's ``'main'`` entry is used as a module (``named_parameters``
        is called on it). Other parameters of the current model are kept as-is.
        """
        # map_location lets a checkpoint saved on another device (e.g. a GPU not
        # present on this machine) be loaded onto the model's device.
        ckpt_main = torch.load(ckpt_path, map_location=device)['main']
        main_model = self.models['fm'].models_dict['main']
        # state_dict() already returns a new dict of tensor references; replacing
        # entries in it does not touch the model, and load_state_dict copies the
        # values in place, so deep-copying the whole FM state is unnecessary.
        new_state_dict = main_model.state_dict()
        for n, p in ckpt_main.named_parameters():
            if 'qkv.abs' not in n:
                continue
            new_state_dict[n] = p
            logger.info(f'use {n} from ckpt')
        main_model.load_state_dict(new_state_dict)

    @staticmethod
    def _move_batch_to_device(x, y, device):
        """Move a training batch to ``device`` and return ``(x, y)``.

        If ``x`` is a dict, its ``torch.Tensor`` values are moved in place and
        other values are left untouched. Otherwise ``x`` itself is moved with
        ``.to``. ``y`` is always moved with ``.to``.
        """
        if isinstance(x, dict):
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            return x, y.to(device)
        return x.to(device), y.to(device)

    def run(self, scenario: Scenario, hyps: Dict, collate_fn=None) -> Dict[str, Any]:
        """Add LoRA to the FM and train it (LoRA + task head) on the scenario's offline data.

        Args:
            scenario: provides the offline datasets via ``get_offline_datasets``.
            hyps: hyperparameters matching :meth:`get_required_hyp_schema`.
            collate_fn: optional collate function forwarded to both dataloaders.

        Returns:
            Nothing is returned explicitly, despite the annotation. Results are
            written as checkpoints and TensorBoard logs under
            ``self.res_save_dir``.
        """
        super().run(scenario, hyps)
        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)  # for auto completion
        fm = self.models['fm']
        device = fm.device

        # 1. add LoRA
        lora_util = fm.get_lora_util()
        sample = hyps['samples_size']
        # a non-empty list/tuple of ints is treated as a shape for a random dummy input
        if isinstance(sample, (tuple, list)) and len(sample) > 0 and isinstance(sample[0], int):
            sample = torch.rand(sample).to(device)
        lora_util.add_lora_ab_to_fm(fm.models_dict['main'], hyps['ab_r'], sample)

        ckpt_path = hyps.get('fm_lora_ckpt_path')
        if ckpt_path:  # absent, None and '' all mean "no checkpoint"
            self._load_lora_ab_from_ckpt(ckpt_path, device)

        # 2. train (knowledge distillation, index relationship)
        if 'transform' in hyps:
            offline_datasets = scenario.get_offline_datasets(transform=hyps['transform'])
        else:
            offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])

        # the train loader is consumed with next() once per iteration for num_iters
        # iterations (presumably build_dataloader returns an endless loader when
        # the 4th arg is True — TODO confirm)
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None, collate_fn=collate_fn))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False, collate_fn=collate_fn)

        # only LoRA params (as returned by train_only_lora) and the task head are optimized
        lora_params = lora_util.train_only_lora(fm.models_dict['main'])
        head_params = fm.get_task_head_params()

        num_lora_params = sum(p.numel() for p in lora_params)
        total_params = sum(p.numel() for p in fm.models_dict['main'].parameters())
        logger.info(f'num lora params: {num_lora_params}, total params: {total_params}, ratio: {num_lora_params / total_params}')

        optimizer = torch.optim.__dict__[hyps['optimizer']](lora_params + head_params, **hyps['optimizer_args'])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)

        best_val_acc = 0
        val_acc = 0
        try:
            for iter_index in pbar:
                fm.to_train_mode()

                x, y = next(train_loader)
                x, y = self._move_batch_to_device(x, y, device)

                task_loss = fm.forward_to_get_task_loss(x, y)
                optimizer.zero_grad()
                task_loss.backward()
                optimizer.step()
                scheduler.step()

                if (iter_index + 1) % hyps['val_freq'] == 0:
                    fm.to_eval_mode()
                    val_acc = fm.get_accuracy(val_loader)

                    fm.save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        fm.save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

                # read the scalar once instead of handing the graph-attached tensor
                # to the writer and formatting it again for the progress bar
                loss_value = task_loss.item()
                tb_writer.add_scalar('losses/task_loss', loss_value, iter_index)
                tb_writer.add_scalar('accs/val_acc', val_acc, iter_index)
                tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], iter_index)
                pbar.set_description(f'loss: {loss_value:.6f}, val_acc: {val_acc:.4f}')
        finally:
            # release the progress bar and flush/close the TensorBoard event file
            # even if training is interrupted by an exception
            pbar.close()
            tb_writer.close()