|
from __future__ import print_function, division |
|
import os |
|
import sys |
|
import time |
|
import argparse |
|
import warnings |
|
import torch |
|
import pickle |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
from torch.utils.data import Dataset, DataLoader, TensorDataset |
|
from torchvision import transforms, utils |
|
from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS |
|
|
|
# Default to the CPU; switch to CUDA when torch reports an available GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
|
def load_weights(model, weight_path, map_location=None):
    """Load checkpoint weights from *weight_path* into *model* and return it.

    The checkpoint may be either a dict that wraps the weights under the
    ``"model_state_dict"`` key (e.g. a training checkpoint that also stores
    optimizer state) or a bare state dict.

    Args:
        model: the ``nn.Module`` to populate; it is modified in place.
        weight_path: path of the file saved with ``torch.save``.
        map_location: forwarded to ``torch.load``. Defaults to the
            module-level ``device`` when None.

    Returns:
        The same *model* object, with weights loaded (``strict=True``, so
        missing or unexpected keys raise).
    """
    print("Loading PATHOLOGICAL_CLASSFIER...", weight_path)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version; only load checkpoints from trusted sources.
    loadnet = torch.load(
        weight_path, map_location=device if map_location is None else map_location
    )
    # The original code picked "model_state_dict" in both branches, so bare
    # state dicts crashed with KeyError. Fall back to the object itself.
    if isinstance(loadnet, dict) and "model_state_dict" in loadnet:
        state_dict = loadnet["model_state_dict"]
    else:
        state_dict = loadnet
    model.load_state_dict(state_dict, strict=True)
    return model
|
|
|
class MyDataset(Dataset):
    """Eagerly-loaded dataset of pre-extracted image/text features and targets.

    ``root_path`` must contain three sub-directories, ``img_feature``,
    ``txt_feature`` and ``target``. Each holds one pickle file per sample,
    and the three files for a sample share the same file name. Each sample is
    the tuple ``(img_feature, txt_feature, target_os, target_dfs, file_name)``.
    """

    def __init__(self, root_path):
        img_dir = os.path.join(root_path, "img_feature")
        txt_dir = os.path.join(root_path, "txt_feature")
        target_dir = os.path.join(root_path, "target")

        m_data = []
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting makes sample indices reproducible across machines and runs.
        for fname in sorted(os.listdir(img_dir)):
            img_load_dict = self._load_pickle(os.path.join(img_dir, fname))
            txt_load_dict = self._load_pickle(os.path.join(txt_dir, fname))
            target_load_dict = self._load_pickle(os.path.join(target_dir, fname))
            m_data.append((
                img_load_dict["img_feature"],
                txt_load_dict["txt_feature"],
                target_load_dict["target_os"],
                target_load_dict["target_dfs"],
                fname,
            ))
        self.m_data = m_data

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored in the file at *path*."""
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # dataset at trusted feature files.
        with open(path, "rb") as f:
            return pickle.load(f)

    def __getitem__(self, idx):
        """Return the ``(img, txt, os, dfs, file_name)`` tuple for sample *idx*."""
        return self.m_data[idx]

    def __len__(self):
        """Return the number of samples loaded from disk."""
        return len(self.m_data)
|
|
|
# Placeholder locations; override with --model_path / --data_path.
_DEFAULT_MODEL_PATH = '/your/trained/model/path/'
_DEFAULT_DATA_PATH = "/your/dataset/path/"


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def valid(args):
    """Evaluate a trained PATHOLOGICAL_CLASSFIER on the test set and print accuracies.

    The model returns two outputs per sample. The first is scored against the
    dataset's ``target_os`` and the second against ``target_dfs``. A joint
    accuracy counts the samples where both predictions are correct.

    Args:
        args: an ``argparse.Namespace`` (or any object). Its optional
            attributes ``model_path`` and ``data_path`` override the
            placeholder checkpoint and dataset paths. Missing or empty values
            fall back to the placeholders.

    Prints one line with: head-1 accuracy, head-1 sample count, head-2
    accuracy, head-2 sample count, joint accuracy, joint correct count.
    """
    torch.manual_seed(0)

    num_classes = 2
    config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
    model = PATHOLOGICAL_CLASSFIER(config, num_classes=num_classes, vis=True, mm=True)

    model_path = getattr(args, "model_path", None) or _DEFAULT_MODEL_PATH
    data_path = getattr(args, "data_path", None) or _DEFAULT_DATA_PATH

    p_c_model = load_weights(model, model_path)
    p_c_model.to(device)

    test_dataset = MyDataset(data_path)
    test_loader = DataLoader(test_dataset, batch_size=1)

    print("--------Start testing-------")
    p_c_model.eval()

    valid_1_total = 0
    valid_1_cnt = 0
    valid_2_total = 0
    valid_2_cnt = 0
    valid_total_cnt = 0

    with torch.no_grad():
        for imgs, txt, target_1, target_2, _file_name in test_loader:
            output_1, output_2 = p_c_model(imgs.to(device), txt.to(device))

            # Predicted class = argmax over the last axis. squeeze(1) drops a
            # singleton dim the model presumably emits -- TODO confirm shape.
            out_1_list = torch.argmax(output_1.squeeze(1), dim=-1).cpu().tolist()
            out_2_list = torch.argmax(output_2.squeeze(1), dim=-1).cpu().tolist()
            target_1_list = target_1.tolist()
            target_2_list = target_2.tolist()

            valid_1_total += len(out_1_list)
            valid_2_total += len(out_2_list)

            for pred_1, true_1, pred_2, true_2 in zip(
                out_1_list, target_1_list, out_2_list, target_2_list
            ):
                hit_1 = pred_1 == true_1
                hit_2 = pred_2 == true_2
                if hit_1:
                    valid_1_cnt += 1
                if hit_2:
                    valid_2_cnt += 1
                if hit_1 and hit_2:
                    valid_total_cnt += 1

    if valid_1_total == 0:
        # Previously this path crashed with ZeroDivisionError.
        warnings.warn(
            "No test samples found in %r; accuracies reported as 0." % (data_path,)
        )

    valid_1_acc = _ratio(valid_1_cnt, valid_1_total)
    valid_2_acc = _ratio(valid_2_cnt, valid_2_total)
    # The joint accuracy shares head 1's denominator (one count per sample).
    valid_total_acc = _ratio(valid_total_cnt, valid_1_total)

    print(valid_1_acc, valid_1_total, valid_2_acc, valid_2_total, valid_total_acc, valid_total_cnt)
    print("=" * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--model_path",
        default=_DEFAULT_MODEL_PATH,
        help="checkpoint file saved with torch.save",
    )
    parser.add_argument(
        "--data_path",
        default=_DEFAULT_DATA_PATH,
        help="dataset root containing img_feature/, txt_feature/ and target/",
    )
    args = parser.parse_args()
    valid(args)
|
|
|
|
|
|
|
|
|
|
|
|