from __future__ import print_function, division

import os
import sys
import time
import argparse
import warnings
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, utils

from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder locations used when no path is given on the command line.
DEFAULT_MODEL_PATH = '/your/trained/model/path/'
DEFAULT_DATA_PATH = "/your/dataset/path/"


def load_weights(model, weight_path):
    """Load saved weights from ``weight_path`` into ``model`` and return it.

    The checkpoint may either be a dict that stores the weights under the
    ``"model_state_dict"`` key, or a bare state dict. Loading is strict, so
    any missing or unexpected parameter names raise an error from
    ``load_state_dict``.

    Args:
        model: the module to populate (modified in place).
        weight_path: path of the file passed to ``torch.load``.

    Returns:
        The same ``model`` object, with its weights loaded.
    """
    print("Loading PATHOLOGICAL_CLASSFIER...", weight_path)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    loadnet = torch.load(weight_path, map_location=device)
    if "model_state_dict" in loadnet:
        state_dict = loadnet["model_state_dict"]
    else:
        # Both branches previously selected "model_state_dict", so a bare
        # state dict raised KeyError. Treat the loaded object as the weights.
        state_dict = loadnet
    model.load_state_dict(state_dict, strict=True)
    return model


def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only read trusted feature files.
    with open(path, "rb") as fh:
        return pickle.load(fh)


class MyDataset(Dataset):
    """In-memory dataset of precomputed image/text features and two targets.

    ``root_path`` must contain three directories, ``img_feature``,
    ``txt_feature`` and ``target``. Each directory holds pickle files with
    identical file names. For every file name:

    - the image pickle supplies the ``"img_feature"`` key;
    - the text pickle supplies ``"txt_feature"``;
    - the target pickle supplies ``"target_os"`` and ``"target_dfs"``.

    Each item is ``(img_feature, txt_feature, target_os, target_dfs, file_name)``.
    """

    def __init__(self, root_path):
        img_pkl_file_path = os.path.join(root_path, "img_feature")
        txt_pkl_file_path = os.path.join(root_path, "txt_feature")
        target_pkl_file_path = os.path.join(root_path, "target")

        m_data = []
        # os.listdir order is arbitrary. Sort so sample indices, and therefore
        # evaluation order, are reproducible between runs.
        for file in sorted(os.listdir(img_pkl_file_path)):
            img_load_dict = _load_pickle(os.path.join(img_pkl_file_path, file))
            txt_load_dict = _load_pickle(os.path.join(txt_pkl_file_path, file))
            target_load_dict = _load_pickle(os.path.join(target_pkl_file_path, file))
            m_data.append((
                img_load_dict["img_feature"],
                txt_load_dict["txt_feature"],
                target_load_dict["target_os"],
                target_load_dict["target_dfs"],
                file,
            ))
        self.m_data = m_data

    def __getitem__(self, idx):
        inp_i, inp_txt, oup_os, oup_dfs, f_name = self.m_data[idx]
        return inp_i, inp_txt, oup_os, oup_dfs, f_name

    def __len__(self):
        return len(self.m_data)


def _predictions(logits):
    """Return the arg-max class index of each sample in ``logits`` as a list."""
    # squeeze(1) drops dim 1 only if it has size 1 (otherwise it is a no-op),
    # then argmax runs over the last dimension.
    return torch.argmax(logits.squeeze(1), dim=-1).cpu().tolist()


def valid(args):
    """Evaluate a trained PATHOLOGICAL_CLASSFIER on the test set and print accuracy.

    The model returns two outputs. The first is scored against the dataset's
    ``target_os`` and the second against ``target_dfs``. The function prints,
    in order:

    - head-1 accuracy and sample count;
    - head-2 accuracy and sample count;
    - the fraction of samples where both heads are correct;
    - the count of samples where both heads are correct.

    Args:
        args: namespace that may carry ``model_path`` and ``data_path``.
            Missing attributes fall back to the module-level defaults.
    """
    torch.manual_seed(0)
    num_classes = 2
    config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
    model = PATHOLOGICAL_CLASSFIER(config, num_classes=num_classes, vis=True, mm=True)
    model_path = getattr(args, "model_path", DEFAULT_MODEL_PATH)
    data_path = getattr(args, "data_path", DEFAULT_DATA_PATH)
    p_c_model = load_weights(model, model_path)
    p_c_model.to(device)

    test_dataset = MyDataset(data_path)
    test_loader = DataLoader(test_dataset, batch_size=1)

    # ----- Test -----
    print("--------Start testing-------")
    p_c_model.eval()
    valid_1_total = 0
    valid_2_total = 0
    valid_1_cnt = 0
    valid_2_cnt = 0
    valid_total_cnt = 0
    with torch.no_grad():
        for imgs, txt, target_1, target_2, _file_name in test_loader:
            output_1, output_2 = p_c_model(imgs.to(device), txt.to(device))
            out_1_list = _predictions(output_1)
            out_2_list = _predictions(output_2)
            target_1_list = target_1.tolist()
            target_2_list = target_2.tolist()

            valid_1_total += len(out_1_list)
            valid_2_total += len(out_2_list)
            for pred_1, tgt_1, pred_2, tgt_2 in zip(
                out_1_list, target_1_list, out_2_list, target_2_list
            ):
                hit_1 = pred_1 == tgt_1
                hit_2 = pred_2 == tgt_2
                if hit_1:
                    valid_1_cnt += 1
                if hit_2:
                    valid_2_cnt += 1
                if hit_1 and hit_2:
                    valid_total_cnt += 1

    if valid_1_total == 0 or valid_2_total == 0:
        # Previously this was a ZeroDivisionError with no hint of the cause.
        warnings.warn("No test samples were evaluated; check the dataset path: %s" % data_path)
        return

    valid_1_acc = valid_1_cnt / valid_1_total
    valid_2_acc = valid_2_cnt / valid_2_total
    valid_total_acc = valid_total_cnt / valid_1_total
    print(valid_1_acc, valid_1_total, valid_2_acc, valid_2_total, valid_total_acc, valid_total_cnt)
    print("=" * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--model_path", default=DEFAULT_MODEL_PATH,
                        help="checkpoint file passed to torch.load")
    parser.add_argument("--data_path", default=DEFAULT_DATA_PATH,
                        help="root directory containing img_feature/, txt_feature/ and target/")
    args = parser.parse_args()
    valid(args)