|
""" |
|
Self-supervised dataset AI inference script, ver: Aug 25th 22:00
|
|
|
""" |
|
import argparse |
|
import csv |
|
import os |
|
import shutil |
|
import sys |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
from tqdm import tqdm |
|
|
|
sys.path.append("..") |
|
from Backbone.getmodel import get_model |
|
from utils.tools import find_all_files |
|
from utils.data_augmentation import data_augmentation |
|
|
|
|
|
def trans_csv_folder_to_imagefoder(target_path=r'C:\Users\admin\Desktop\MRAS_SEED_dataset',
                                   original_path=r'C:\Users\admin\Desktop\dataset\MARS_SEED_Dataset\train\train_org_image',
                                   csv_path=r'C:\Users\admin\Desktop\dataset\MARS_SEED_Dataset\train\train_label.csv'):
    """
    Reorganise a "flat image folder + label csv" dataset into an ImageFolder-style layout.

    Original data format: a folder with images inside + a csv file with a header row, where every
    following row holds the name of an image (first column) and its category (second column).
    Every image is copied to ``<target_path>/<category>/``; category folders are created on demand.

    :param target_path: root of the target image folder; created if it does not exist
    :param original_path: the folder holding the images named in the csv; image names are resolved
        against it (an absolute path in the csv is used as-is by ``os.path.join``)
    :param csv_path: a csv file with a header row followed by one ``name, category`` row per image
    """
    with open(csv_path, "rt", encoding="utf-8", newline="") as csvfile:
        reader = csv.reader(csvfile)
        # The first row is the header, not an image: skip it instead of trying to copy it.
        next(reader, None)
        # Blank lines (e.g. a trailing newline) come back as [] and have no name/category columns.
        rows = [row for row in reader if row]

    os.makedirs(target_path, exist_ok=True)

    for row in tqdm(rows):
        item_path = os.path.join(original_path, row[0])
        class_dir = os.path.join(target_path, row[1])
        os.makedirs(class_dir, exist_ok=True)
        shutil.copy(item_path, class_dir)

    print('total num:', len(rows))
|
|
|
|
|
class PILImageTransform:
    """
    Callable converting an OpenCV-style BGR array into an RGB ``PIL.Image``.

    The input is split into exactly three channels, re-merged in RGB order, cast to ``uint8``
    and wrapped with ``Image.fromarray``.
    """

    def __init__(self):
        """Stateless transform: there is nothing to configure."""

    def __call__(self, image):
        blue, green, red = cv2.split(image)
        rgb_array = cv2.merge((red, green, blue))
        return Image.fromarray(rgb_array.astype(np.uint8))
|
|
|
|
|
class Front_Background_Dataset(torch.utils.data.Dataset):
    """
    Unlabelled image dataset for inference that yields ``(image_tensor, image_path)`` pairs.

    Every file returned by ``find_all_files(input_root, suffix=suffix)`` is one sample. Samples are
    ordered by sorted path, so the output order is reproducible.

    :param input_root: root folder searched for images
    :param data_transforms: callable applied to the RGB ``PIL.Image``. When None, the image is
        resized so its shorter side equals ``edge_size`` and converted to a float CHW tensor in [0, 1].
    :param edge_size: target length of the shorter image side for the default transform
    :param suffix: file-name suffix of the images to collect
    """

    def __init__(self, input_root, data_transforms=None, edge_size=384, suffix='.jpg'):
        super().__init__()

        self.data_root = input_root
        self.edge_size = edge_size
        self.input_ids = sorted(find_all_files(self.data_root, suffix=suffix))
        self.PIL_Transform = PILImageTransform()

        if data_transforms is not None:
            self.transform = data_transforms
        else:
            # The previous default referenced an unimported ``transforms`` module (NameError).
            # This bound method does the same resize-shorter-edge + to-tensor using only
            # PIL/numpy/torch. Unlike a lambda, it stays picklable for multi-worker DataLoaders.
            self.transform = self._default_transform

    def _default_transform(self, image):
        """Resize a PIL image so its shorter side is ``self.edge_size``; return a float CHW tensor in [0, 1]."""
        width, height = image.size
        if width <= height:
            new_size = (self.edge_size, int(self.edge_size * height / width))
        else:
            new_size = (int(self.edge_size * width / height), self.edge_size)
        if new_size != (width, height):
            image = image.resize(new_size, Image.BILINEAR)

        array = np.array(image, dtype=np.uint8)
        if array.ndim == 2:  # single-channel image: add a trailing channel axis
            array = array[:, :, None]
        # HWC uint8 -> CHW float32 scaled to [0, 1]
        return torch.from_numpy(array).permute(2, 0, 1).float().div(255)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        imageName = self.input_ids[idx]
        imageID = imageName

        # cv2.imread returns None (instead of raising) for missing or undecodable files.
        image = cv2.imread(imageName)
        if image is None:
            raise FileNotFoundError('could not read image: {}'.format(imageName))

        # The BGR uint8 array goes straight to the PIL conversion. The former float32 detour
        # was cast back to uint8 there, so the resulting image is unchanged.
        image = self.transform(self.PIL_Transform(image))

        return image, imageID
|
|
|
|
|
def inferance(model, dataloader, record_dir, class_names=('0', '1'), result_csv_name='inferance.csv', device='cuda'):
    """
    Run ``model`` over ``dataloader`` and write one ``<image id>, <predicted class>, `` line per sample.

    :param model: classifier. The predicted class is the index of the max over dim 1 of its output.
    :param dataloader: iterable of ``(images, image_ids)`` batches
    :param record_dir: folder for the result file; created if missing
    :param class_names: label written for each predicted class index
    :param result_csv_name: name of the result file inside ``record_dir`` (overwritten)
    :param device: device the image batches are moved to. The model is not moved here, so the
        caller must place it on the same device.
    """
    os.makedirs(record_dir, exist_ok=True)

    model.eval()
    print('Inferance')
    print('-' * 10)

    check_idx = 0

    # no_grad: inference needs no autograd graph; building one wastes memory per batch.
    with torch.no_grad(), open(os.path.join(record_dir, result_csv_name), 'w') as f_log:
        for images, imageIDs in dataloader:
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)
            pred_labels = preds.cpu().numpy()

            for image_id, pred_label in zip(imageIDs, pred_labels):
                f_log.write(str(image_id) + ', ' + str(class_names[pred_label]) + ', \n')
                check_idx += 1

    print(str(check_idx) + ' samples are all recorded')
|
|
|
|
|
def main(args):
    """
    Build the classifier described by ``args``, load its checkpoint and write predictions for ``args.dataroot``.

    The checkpoint is read from ``<model_path>/CLS_<model_idx>.pth``. Predictions are written by
    ``inferance`` to ``<record_dir>/inferance.csv``.

    :param args: namespace produced by ``get_args_parser().parse_args()``
    :return: -1 when the requested class count mismatches or the checkpoint cannot be loaded, else None
    """
    if args.paint:
        # --paint is a store_false flag, so this branch is the default: force matplotlib's
        # non-interactive Agg backend. Passing --paint skips this.
        import matplotlib
        matplotlib.use('Agg')

    model_idx = args.model_idx
    dataroot = args.dataroot
    save_model_path = os.path.join(args.model_path, 'CLS_' + model_idx + '.pth')
    record_dir = args.record_dir
    # makedirs also creates missing parent folders (os.mkdir failed on those).
    os.makedirs(record_dir, exist_ok=True)

    gpu_idx = args.gpu_idx

    drop_rate = args.drop_rate
    attn_drop_rate = args.attn_drop_rate
    drop_path_rate = args.drop_path_rate
    use_cls_token = not args.cls_token_off
    use_pos_embedding = not args.pos_embedding_off
    use_att_module = None if args.att_module == 'None' else args.att_module
    edge_size = args.edge_size
    batch_size = args.batch_size

    data_transforms = data_augmentation(data_augmentation_mode=args.data_augmentation_mode, edge_size=edge_size)

    inf_dataset = Front_Background_Dataset(dataroot, data_transforms=data_transforms['val'], edge_size=edge_size,
                                           suffix='.jpg')
    dataloader = torch.utils.data.DataLoader(inf_dataset, batch_size=batch_size, num_workers=2, shuffle=False)

    class_names = ['0', '1']

    pretrained_backbone = False
    if args.num_classes == 0:
        num_classes = len(class_names)  # 0 means "auto-fit to class_names"
    elif args.num_classes == len(class_names):
        # This branch previously left num_classes unset, causing a NameError in get_model.
        num_classes = args.num_classes
    else:
        print('classfication number of the model mismatch the dataset requirement of:', len(class_names))
        return -1
    print("class_names:", class_names)

    model = get_model(num_classes, edge_size, model_idx, drop_rate, attn_drop_rate, drop_path_rate,
                      pretrained_backbone, use_cls_token, use_pos_embedding, use_att_module)

    if gpu_idx == -1:
        if torch.cuda.device_count() > 1:
            print("Use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)
        else:
            print('we dont have more GPU idx here, try to use gpu_idx=0')
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        # Restrict this process to the requested GPU.
        # NOTE(review): only takes effect if CUDA has not been initialised in this process yet.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_idx)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # model.to(device) below moves the weights to the target device.
        state_dict = torch.load(save_model_path, map_location='cpu')
        load_result = model.load_state_dict(state_dict, False)
    except Exception as err:  # top-level boundary: report and stop rather than infer with random weights
        print('model loading erro:', err)
        return -1
    print('model loaded')

    # strict=False silently skips mismatched keys; surface them so a wrong checkpoint is noticed.
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys:
        print('missing keys in checkpoint:', missing_keys)
    if unexpected_keys:
        print('unexpected keys in checkpoint:', unexpected_keys)

    model.to(device)

    # Pass the device actually selected (was hard-coded 'cuda', which crashed on CPU-only hosts).
    inferance(model, dataloader, record_dir, class_names=class_names, result_csv_name='inferance.csv', device=device)
|
|
|
|
|
def get_args_parser():
    """
    Build the command-line parser for this inference script.

    All defaults and types are those consumed by ``main``. The help texts describe what each flag
    actually does; in particular the ``*_off`` flags disable features, and ``--paint`` is a
    store_false flag.

    :return: a configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet INF')

    # Model name
    parser.add_argument('--model_idx', default='Hybrid2_384_401_testsample', type=str, help='Model Name or index')

    # MIL stripe
    parser.add_argument('--MIL_Stripe', action='store_true', help='MIL_Stripe')

    # Regularisation settings
    parser.add_argument('--drop_rate', default=0.0, type=float, help='dropout rate, default 0.0')
    parser.add_argument('--attn_drop_rate', default=0.0, type=float, help='dropout rate after attention, default 0.0')
    parser.add_argument('--drop_path_rate', default=0.0, type=float, help='drop path for stochastic depth, default 0.0')

    # Model structure switches
    parser.add_argument('--cls_token_off', action='store_true', help='disable the cls_token in the model structure')
    parser.add_argument('--pos_embedding_off', action='store_true',
                        help='disable the pos_embedding in the model structure')
    parser.add_argument('--att_module', default='SimAM', type=str,
                        help="which att_module to use in the model structure ('None' for no att_module)")

    # GPU selection
    parser.add_argument('--gpu_idx', default=0, type=int,
                        help='use a single GPU with its index, -1 to use multiple GPU')

    # Paths
    parser.add_argument('--dataroot', default=r'/data/pancreatic-cancer-project/k5_dataset',
                        help='path to dataset')
    parser.add_argument('--model_path', default=r'/home/pancreatic-cancer-project/saved_models',
                        help='path to save model state-dict')
    parser.add_argument('--record_dir', default=r'/home/pancreatic-cancer-project/INF',
                        help='path to record INF csv')

    # Help / monitoring flags
    parser.add_argument('--paint', action='store_false',
                        help='paint in front desk: do not force the non-interactive matplotlib Agg backend '
                             '(Agg is forced when this flag is absent)')
    parser.add_argument('--enable_notify', action='store_true', help='enable notify to send email')
    parser.add_argument('--enable_tensorboard', action='store_true', help='enable tensorboard to save status')
    parser.add_argument('--enable_attention_check', action='store_true', help='check and save attention map')
    parser.add_argument('--enable_visualize_check', action='store_true', help='check and save pics')

    parser.add_argument('--data_augmentation_mode', default=0, type=int, help='data_augmentation_mode')

    # Prompt tuning
    parser.add_argument('--PromptTuning', default=None, type=str,
                        help='use Prompt Tuning strategy instead of Finetuning')
    parser.add_argument('--Prompt_Token_num', default=10, type=int, help='Prompt_Token_num')

    # Dataset-related settings
    parser.add_argument('--num_classes', default=0, type=int, help='classification number, default 0 for auto-fit')
    parser.add_argument('--edge_size', default=384, type=int, help='edge size of input image')

    # Inference settings
    parser.add_argument('--batch_size', default=1, type=int, help='testing batch_size default 1')

    return parser
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the CLI, parse sys.argv and run inference.
    cli_parser = get_args_parser()
    main(cli_parser.parse_args())
|
|
|
|
|
|