|
|
|
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM). |
|
|
|
requirement: pip install grad-cam |
|
""" |
|
|
|
from argparse import ArgumentParser |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from mmengine import Config |
|
from mmengine.model import revert_sync_batchnorm |
|
from PIL import Image |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image |
|
|
|
from mmseg.apis import inference_model, init_model, show_result_pyplot |
|
from mmseg.utils import register_all_modules |
|
|
|
|
|
class SemanticSegmentationTarget:
    """Grad-CAM target that scores one class inside a masked region.

    Calling an instance resizes the given model output to ``size``, selects
    the channel ``category`` and returns the sum of that channel weighted by
    ``mask``. The resulting scalar is what pytorch-grad-cam differentiates.

    requirement: pip install grad-cam

    Args:
        category (int): Visualization class (channel index in the output).
        mask (ndarray): Mask of class. It is multiplied element-wise with
            the resized channel, so it must broadcast against ``size``.
        size (tuple): Image size ``(height, width)`` the output is resized
            to before masking.
    """

    def __init__(self, category, mask, size):
        self.category = category
        self.mask = torch.from_numpy(mask)
        self.size = size
        # Eagerly place the mask on the GPU when one exists; ``__call__``
        # still re-aligns it in case the model output lives elsewhere.
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        """Return the masked sum of the ``category`` channel.

        Args:
            model_output (Tensor): Output for a single sample, indexed as
                ``[class, row, col]`` (a batch axis is added only for the
                resize and removed again afterwards).

        Returns:
            Tensor: Scalar tensor with the masked class score.
        """
        # F.interpolate expects a batch dimension for 2-D bilinear resizing.
        resized = F.interpolate(
            model_output[None],
            size=self.size,
            mode='bilinear',
            align_corners=False)[0]

        # Align the mask with the output's device so that CPU inference on a
        # CUDA-capable machine (or a non-default GPU) does not fail with a
        # device-mismatch error. This is a no-op when devices already match.
        mask = self.mask.to(resized.device)

        return (resized[self.category, :, :] * mask).sum()
|
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file', default='vis_cam.png', help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    parser.add_argument(
        '--category-index',
        type=int,
        default=7,
        help='Category to visualize CAM')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    return parser.parse_args()


def main():
    """Segment one image and save a Grad-CAM overlay for a chosen class.

    Steps:
        1. Build the model from ``config`` / ``checkpoint`` and run inference
           on ``img``; the rendered prediction is written to ``--out-file``.
        2. Build a binary mask of the pixels predicted as
           ``--category-index``.
        3. Run Grad-CAM on ``--target-layers`` with that mask as the target
           and write the heat-map overlay to ``--cam-file``.
    """
    # Local import: only needed to probe the installed GradCAM signature.
    import inspect

    args = _parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)

    # render the prediction; only pop up a window when no file is requested
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=args.out_file is None,
        out_file=args.out_file)

    # predicted label map with the leading singleton axis removed
    prediction_data = result.pred_sem_seg.data
    pre_np_data = prediction_data.cpu().numpy().squeeze(0)

    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line (e.g. ``backbone.layer4[2]`` -> ``model.backbone.layer4[2]``).
    # Only run this script with trusted ``--target-layers`` values.
    target_layers = [eval(f'model.{args.target_layers}')]

    category = int(args.category_index)
    mask_float = np.float32(pre_np_data == category)

    # load the image as RGB and release the file handle right away
    with Image.open(args.img) as img_file:
        image = np.array(img_file.convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255

    # The config statistics are divided by 255 because ``rgb_img`` has
    # already been scaled to [0, 1].
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM (Class Activation Maps)
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]

    cam_kwargs = dict(model=model, target_layers=target_layers)
    # Newer grad-cam releases removed the ``use_cuda`` keyword; passing it
    # there raises TypeError. Forward it only when the installed version
    # still accepts it.
    if 'use_cuda' in inspect.signature(GradCAM.__init__).parameters:
        cam_kwargs['use_cuda'] = torch.cuda.is_available()

    with GradCAM(**cam_kwargs) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        # save cam file
        Image.fromarray(cam_image).save(args.cam_file)
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|