|
""" |
|
Testing Script ver: Oct 23rd 17:30 |
|
""" |
|
|
|
from __future__ import print_function, division |
|
|
|
import argparse |
|
import json |
|
import time |
|
|
|
import torchvision |
|
from tensorboardX import SummaryWriter |
|
|
|
from Backbone.getmodel import get_model |
|
from Backbone.GetPromptModel import build_promptmodel |
|
|
|
from utils.data_augmentation import * |
|
from utils.visual_usage import * |
|
|
|
|
|
def _percentage(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or 0 when the denominator is 0."""
    if denominator == 0:
        return 0
    return float(numerator) / denominator * 100


def test_model(model, test_dataloader, criterion, class_names, test_dataset_size, model_idx, test_model_idx, edge_size,
               check_minibatch=100, device=None, draw_path='../imaging_results', enable_attention_check=None,
               enable_visualize_check=True, writer=None):
    """
    Testing iteration: run the model once over the test dataloader, print the overall loss / accuracy and
    the per-class confusion statistics, and dump the per-class TP/TN/FP/FN counts to a json file.

    :param model: model object
    :param test_dataloader: the test_dataloader obj, iterated as (inputs, labels) minibatches
    :param criterion: loss func obj, called as criterion(outputs, labels)
    :param class_names: The names of the classes for printing, the position in the list is the class index
    :param test_dataset_size: size of the test dataset, used to average the loss and the accuracy

    :param model_idx: model idx, passed to the attention check and printed when that check fails
    :param test_model_idx: name of this test run, the json log is saved as '<test_model_idx>_log.json'
    :param edge_size: image size for the input image
    :param check_minibatch: number of minibatches between two status checks (print / attention / visualize),
        values below 1 (or None) are treated as 1

    :param device: cpu/gpu object, default is cuda:0 when available, otherwise cpu
    :param draw_path: path folder for output pic and the json log, created if it does not exist
    :param enable_attention_check: use attention_check (check_SAA) to show the pics of models' attention areas
    :param enable_visualize_check: use visualize_check to show the pics

    :param writer: attach the records to the tensorboard backend, it is closed at the end of the test

    :return: the model object that was passed in
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # A non-positive interval (e.g. computed from a very small dataset) would make the modulo below
    # raise ZeroDivisionError, so fall back to checking every minibatch.
    if check_minibatch is None or check_minibatch < 1:
        check_minibatch = 1

    since = time.time()

    print('Epoch: Test')
    print('-' * 10)

    phase = 'test'
    epoch_idx = 'test'
    index = 0
    model_time = time.time()

    json_log = {'test': {}}

    # per-class confusion counters, accumulated over all minibatches
    log_dict = {}
    for cls_name in class_names:
        log_dict[cls_name] = {'tp': 0.0, 'tn': 0.0, 'fp': 0.0, 'fn': 0.0}

    model.eval()

    running_loss = 0.0  # loss weighted by minibatch size, for the dataset average
    log_running_loss = 0.0  # loss summed since the last status check
    running_corrects = 0

    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Inference only: skip building the autograd graph. The attention / visualize checks below are
        # deliberately kept outside of this context so they can still use gradients if they need to.
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

        batch_loss = loss.item()
        batch_corrects = int(torch.sum(preds == labels).item())

        log_running_loss += batch_loss
        running_loss += batch_loss * inputs.size(0)
        running_corrects += batch_corrects

        # move to numpy once per minibatch instead of once per class
        labels_np = labels.cpu().numpy()
        preds_np = preds.cpu().numpy()

        for cls_idx, cls_name in enumerate(class_names):
            is_label = labels_np == cls_idx
            is_pred = preds_np == cls_idx

            tp = np.sum(is_label & is_pred)
            tn = np.sum(~is_label & ~is_pred)
            fp = np.sum(is_pred) - tp
            fn = np.sum(is_label) - tp

            # plain python floats keep the log json-serializable
            counts = log_dict[cls_name]
            counts['tp'] += float(tp)
            counts['tn'] += float(tn)
            counts['fp'] += float(fp)
            counts['fn'] += float(fn)

        if writer is not None:
            writer.add_scalar(phase + ' minibatch loss', float(batch_loss), index)
            writer.add_scalar(phase + ' minibatch ACC', float(batch_corrects) / inputs.size(0), index)

        # status check on the last minibatch of every group of check_minibatch minibatches
        if index % check_minibatch == check_minibatch - 1:
            model_time = time.time() - model_time

            check_index = index // check_minibatch + 1

            print('Epoch:', epoch_idx, '   ', phase, 'index of ' + str(check_minibatch) + ' minibatch:',
                  check_index, '     time used:', model_time)

            print('minibatch AVG loss:', float(log_running_loss) / check_minibatch)

            if enable_attention_check:
                # best effort: a failing attention check must not abort the whole test
                try:
                    check_SAA(inputs, labels, model, model_idx, edge_size, class_names, num_images=1,
                              pic_name='GradCAM_' + str(epoch_idx) + '_I_' + str(index + 1),
                              draw_path=draw_path, writer=writer)
                except Exception as err:
                    print('model:', model_idx, ' with edge_size', edge_size, 'is not supported yet')
                    print('attention check error:', repr(err))

            if enable_visualize_check:
                visualize_check(inputs, labels, model, class_names, num_images=-1,
                                pic_name='Visual_' + str(epoch_idx) + '_I_' + str(index + 1),
                                draw_path=draw_path, writer=writer)

            model_time = time.time()
            log_running_loss = 0.0

        index += 1

    json_log['test'][phase] = log_dict

    if test_dataset_size:
        epoch_loss = running_loss / test_dataset_size
        epoch_acc = float(running_corrects) / test_dataset_size * 100
    else:
        # nothing was tested, avoid dividing by zero
        epoch_loss = 0.0
        epoch_acc = 0.0
    print('\nEpoch:  {} \nLoss: {:.4f}  Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    for cls_name in class_names:
        tp = log_dict[cls_name]['tp']
        tn = log_dict[cls_name]['tn']
        fp = log_dict[cls_name]['fp']
        fn = log_dict[cls_name]['fn']

        precision = _percentage(tp, tp + fp)
        recall = _percentage(tp, tp + fn)

        TPR = recall  # sensitivity is the recall of the class

        TNR = _percentage(tn, fp + tn)  # specificity
        FPR = _percentage(fp, fp + tn)
        NPV = _percentage(tn, fn + tn)

        print('{} precision: {:.4f}  recall: {:.4f}'.format(cls_name, precision, recall))
        print('{} sensitivity: {:.4f}  specificity: {:.4f}'.format(cls_name, TPR, TNR))
        print('{} FPR: {:.4f}  NPV: {:.4f}'.format(cls_name, FPR, NPV))
        print('{} TP: {}'.format(cls_name, tp))
        print('{} TN: {}'.format(cls_name, tn))
        print('{} FP: {}'.format(cls_name, fp))
        print('{} FN: {}'.format(cls_name, fn))

    print('\n')

    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    if writer is not None:
        writer.close()

    # make sure the output folder exists so the results of a finished run are not lost
    os.makedirs(draw_path, exist_ok=True)
    with open(os.path.join(draw_path, test_model_idx + '_log.json'), 'w', encoding='utf-8') as log_file:
        json.dump(json_log, log_file, ensure_ascii=False, indent=2)

    return model


# The 'test_' prefix is only a naming coincidence: keep pytest from collecting this function as a test case.
test_model.__test__ = False
|
|
|
|
|
def _apply_checkpoint(model, checkpoint, PromptTuning, PromptUnFreeze):
    """Copy an already deserialized checkpoint into the model (full state-dict or prompt-only weights)."""
    if PromptTuning is None or PromptUnFreeze:
        model.load_state_dict(checkpoint)
    else:
        # nn.DataParallel does not expose custom methods of the wrapped model, so reach through .module
        prompt_model = model.module if isinstance(model, nn.DataParallel) else model
        prompt_model.load_prompt(checkpoint)


def main(args):
    """
    Build the model described by the command line arguments, load its trained weights and test it on
    the 'test' sub-folder of the dataset.

    :param args: the parsed arguments of get_args_parser()

    :return: -1 when the class number mismatches the dataset or the weights cannot be loaded, otherwise None
    """
    if args.paint:
        # use the Agg backend
        import matplotlib
        matplotlib.use('Agg')

    gpu_idx = args.gpu_idx

    enable_tensorboard = args.enable_tensorboard

    enable_attention_check = args.enable_attention_check
    enable_visualize_check = args.enable_visualize_check

    data_augmentation_mode = args.data_augmentation_mode

    # Prompt tuning settings
    PromptTuning = args.PromptTuning
    Prompt_Token_num = args.Prompt_Token_num
    PromptUnFreeze = args.PromptUnFreeze

    model_idx = args.model_idx

    # structural settings of the model
    drop_rate = args.drop_rate
    attn_drop_rate = args.attn_drop_rate
    drop_path_rate = args.drop_path_rate
    use_cls_token = not args.cls_token_off
    use_pos_embedding = not args.pos_embedding_off
    use_att_module = None if args.att_module == 'None' else args.att_module

    # paths
    draw_root = args.draw_root
    model_path = args.model_path
    dataroot = args.dataroot
    model_path_by_hand = args.model_path_by_hand

    Pre_Trained_model_path = args.Pre_Trained_model_path

    test_model_idx = 'CLS_' + model_idx + '_test'

    draw_path = os.path.join(draw_root, test_model_idx)

    if model_path_by_hand is None:
        # find the trained weights by the model name
        save_model_path = os.path.join(model_path, 'CLS_' + model_idx + '.pth')
    else:
        save_model_path = model_path_by_hand

    os.makedirs(draw_path, exist_ok=True)

    test_dataroot = os.path.join(dataroot, 'test')

    num_classes = args.num_classes
    edge_size = args.edge_size

    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    data_transforms = data_augmentation(data_augmentation_mode, edge_size=edge_size)

    test_datasets = torchvision.datasets.ImageFolder(test_dataroot, data_transforms['val'])
    test_dataset_size = len(test_datasets)

    if args.check_minibatch is not None:
        check_minibatch = args.check_minibatch
    else:
        check_minibatch = test_dataset_size // (20 * batch_size)
    # a small dataset makes the default 0, which would break the modulo based status check of the test loop
    check_minibatch = max(1, check_minibatch)

    test_dataloader = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False, num_workers=1)

    class_names = sorted(d.name for d in os.scandir(test_dataroot) if d.is_dir())

    if num_classes == 0:
        # auto-fit the class number to the dataset
        print("class_names:", class_names)
        num_classes = len(class_names)
    elif len(class_names) == num_classes:
        print("class_names:", class_names)
    else:
        print('classfication number of the model mismatch the dataset requirement of:', len(class_names))
        return -1

    # the trained weights are loaded below, so no pretrained backbone is requested
    pretrained_backbone = False

    if PromptTuning is None:
        model = get_model(num_classes, edge_size, model_idx, drop_rate, attn_drop_rate, drop_path_rate,
                          pretrained_backbone, use_cls_token, use_pos_embedding, use_att_module)
    else:
        if Pre_Trained_model_path is not None and os.path.exists(Pre_Trained_model_path):
            # load to CPU so that a state-dict saved on a GPU is usable on any host
            base_state_dict = torch.load(Pre_Trained_model_path, map_location='cpu')
        else:
            base_state_dict = 'timm'
            print('base_state_dict of timm')

        print('Test the PromptTuning of ', model_idx)
        print('Prompt VPT type:', PromptTuning)
        model = build_promptmodel(num_classes, edge_size, model_idx, Prompt_Token_num=Prompt_Token_num,
                                  VPT_type=PromptTuning, base_state_dict=base_state_dict)

    # Deserialize once, on CPU: a checkpoint saved on a GPU would otherwise fail to load on a CPU-only
    # host. The model is moved to the target device after the weights are in place.
    try:
        checkpoint = torch.load(save_model_path, map_location='cpu')
    except Exception as err:
        print("model loading erro!!")
        print('can not read', save_model_path, ':', repr(err))
        return -1

    try:
        _apply_checkpoint(model, checkpoint, PromptTuning, PromptUnFreeze)
        print("model loaded")
        print("model :", model_idx)
    except Exception as plain_err:
        # the weights may have been saved from a DataParallel model, retry with the wrapped model
        try:
            model = nn.DataParallel(model)
            _apply_checkpoint(model, checkpoint, PromptTuning, PromptUnFreeze)
            print("DataParallel model loaded")
        except Exception as parallel_err:
            print("model loading erro!!")
            print('plain model:', repr(plain_err))
            print('DataParallel model:', repr(parallel_err))
            return -1

    # Assigning to os.environ can not fail for these values, so no fallback handling is needed here.
    # NOTE(review): CUDA_VISIBLE_DEVICES is presumably only honoured when set before CUDA is first
    # initialized in this process - confirm that nothing above touches CUDA.
    if gpu_idx == -1:
        if torch.cuda.device_count() > 1:
            print("Use", torch.cuda.device_count(), "GPUs!")
            # the loading fallback above may already have wrapped the model
            if not isinstance(model, nn.DataParallel):
                model = nn.DataParallel(model)
        else:
            print('we dont have more GPU idx here, try to use gpu_idx=0')
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_idx)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    writer = SummaryWriter(draw_path) if enable_tensorboard else None

    print("*********************************{}*************************************".format('setting'))
    print(args)

    test_model(model, test_dataloader, criterion, class_names, test_dataset_size, model_idx=model_idx,
               test_model_idx=test_model_idx, edge_size=edge_size, check_minibatch=check_minibatch,
               device=device, draw_path=draw_path, enable_attention_check=enable_attention_check,
               enable_visualize_check=enable_visualize_check, writer=writer)
|
|
|
|
|
def get_args_parser():
    """
    Build the command line parser of the testing script.

    :return: an argparse.ArgumentParser, call parse_args() on it to get the settings consumed by main()
    """
    parser = argparse.ArgumentParser(description='PyTorch image classification testing')

    # Model Name or index
    parser.add_argument('--model_idx', default='ViT_base', type=str, help='Model Name or index')

    # drop settings of the model structure
    parser.add_argument('--drop_rate', default=0.0, type=float, help='dropout rate , default 0.0')
    parser.add_argument('--attn_drop_rate', default=0.0, type=float, help='dropout rate after Attention, default 0.0')
    parser.add_argument('--drop_path_rate', default=0.0, type=float, help='drop path for stochastic depth, default 0.0')

    # the two flags below switch a component OFF, the component is used when the flag is absent
    parser.add_argument('--cls_token_off', action='store_true', help='disable the cls_token in model structure')
    parser.add_argument('--pos_embedding_off', action='store_true',
                        help='disable the pos_embedding in model structure')

    parser.add_argument('--att_module', default='SimAM', type=str,
                        help="use which att_module in model structure, 'None' for no att_module")

    # GPU
    parser.add_argument('--gpu_idx', default=0, type=int,
                        help='use a single GPU with its index, -1 to use multiple GPU')

    # paths
    parser.add_argument('--dataroot', default=r'/data/k5_dataset', type=str,
                        help="path to dataset, its 'test' sub-folder is used")
    parser.add_argument('--model_path', default=r'/home/saved_models', type=str,
                        help='root path of the saved model state-dict, the model is found by name '
                             '(CLS_<model_idx>.pth)')
    parser.add_argument('--draw_root', default=r'/home/runs', type=str,
                        help='path to draw and save tensorboard output')

    parser.add_argument('--model_path_by_hand', default=None, type=str, help='specified path to a model state-dict')

    # store_false: without this flag args.paint is True and main() selects the matplotlib Agg backend
    parser.add_argument('--paint', action='store_false',
                        help='paint in front desk (do not switch matplotlib to the Agg backend)')

    # logging and checking
    parser.add_argument('--enable_tensorboard', action='store_true', help='enable tensorboard to save status')

    parser.add_argument('--enable_attention_check', action='store_true', help='check and save attention map')
    parser.add_argument('--enable_visualize_check', action='store_true', help='check and save pics')

    parser.add_argument('--data_augmentation_mode', default=0, type=int, help='data_augmentation_mode')

    # Prompt Tuning
    parser.add_argument('--PromptTuning', default=None, type=str,
                        help='use Prompt Tuning strategy instead of Finetuning')

    parser.add_argument('--Prompt_Token_num', default=20, type=int, help='Prompt_Token_num')

    parser.add_argument('--PromptUnFreeze', action='store_true', help='prompt tuning with all parameters un-freezed')

    parser.add_argument('--Pre_Trained_model_path', default=None, type=str,
                        help='Finetuning a trained model in this dataset')

    # dataset and input
    parser.add_argument('--num_classes', default=0, type=int, help='classification number, default 0 for auto-fit')
    parser.add_argument('--edge_size', default=384, type=int, help='edge size of input image')

    parser.add_argument('--batch_size', default=1, type=int, help='testing batch_size default 1')

    parser.add_argument('--check_minibatch', default=None, type=int,
                        help='number of minibatches between two status checks, '
                             'default None for dataset size // (20 * batch_size)')

    return parser
|
|
|
|
|
if __name__ == '__main__':
    # script entry: read the settings from the command line, then run the test
    cli_args = get_args_parser().parse_args()
    main(cli_args)
|
|