File size: 7,486 Bytes
201936b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

class split_white_and_gray():
    """Split an intensity tensor into a bright ("white") and a dark ("gray") part.

    Values are compared on the 0-255 scale, so the input is expected in [0, 1]
    (for example the output of ``transforms.ToTensor``). Positions whose level is
    ``>= threshold`` are kept in the first output and positions whose level is
    ``< threshold`` in the second; every other position is set to zero.
    """

    def __init__(self, threshold=120) -> None:
        """
        Initialize the class with a threshold value.

        Args:
            threshold (int, optional): Cut-off on the 0-255 intensity scale.
                Defaults to 120.
        """
        self.threshold = threshold

    def __repr__(self) -> str:
        # Makes ``print(transforms.Compose([...]))`` show the configured threshold.
        return f"{type(self).__name__}(threshold={self.threshold})"

    def __call__(self, tensor):
        """
        Threshold the input tensor into white-matter and gray-matter components.

        Parameters:
        tensor (torch.Tensor): Input tensor with values in [0, 1].

        Returns:
        torch.Tensor: float64 tensor holding the levels >= threshold (others 0), rescaled to [0, 1].
        torch.Tensor: float64 tensor holding the levels < threshold (others 0), rescaled to [0, 1].
        torch.Tensor: float64 copy of the input quantised to the 0-255 grid, rescaled to [0, 1].
        """
        # Round (rather than truncate) when converting back to integer levels:
        # in float32, (k / 255) * 255 can come out just below k, and truncation
        # would move that pixel down one level, possibly across the threshold.
        levels = torch.round(tensor * 255).to(torch.int64)

        white_matter = torch.where(levels >= self.threshold, levels, 0)
        gray_matter = torch.where(levels < self.threshold, levels, 0)

        return ((white_matter / 255).to(torch.float64),
                (gray_matter / 255).to(torch.float64),
                (levels / 255).to(torch.float64))
    
def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'):
    """Multiply an image by its CAM mask and plot image, mask and normalised product.

    :param original_image: Image tensor. After ``torch.squeeze`` it is plotted via
        ``transpose(1, 2, 0)``, so a (C, H, W) layout is expected.
    :param grayscale_cam: CAM mask tensor. ``grayscale_cam[0]`` is plotted as the
        mask, so a leading singleton dimension such as (1, H, W) is expected.
    :param image_title: Title of the first panel.

    :return: Matplotlib Figure with three panels (image, CAM mask, overlay).
    """
    # Move both tensors to the CPU and convert them to NumPy arrays.
    image_np = torch.squeeze(original_image).cpu().numpy()
    cam_np = grayscale_cam.cpu().numpy()

    # A (C, H, W) image times a (1, H, W) mask broadcasts the mask over every channel.
    masked = image_np * cam_np

    # Min-max normalise for display. A constant product (e.g. an all-zero CAM)
    # would otherwise compute 0 / 0 and produce an all-NaN image.
    low, high = np.min(masked), np.max(masked)
    span = high - low
    if span > 0:
        masked_norm = (masked - low) / span
    else:
        masked_norm = np.zeros_like(masked)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    axes[0].imshow(image_np.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    axes[0].set_title(image_title)

    axes[1].imshow(cam_np[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    axes[2].imshow(masked_norm.transpose(1, 2, 0))
    axes[2].set_title('Overlay (Normalized)')

    return fig

def showcam_withmask(img_tensor: torch.Tensor,
                     mask_tensor: torch.Tensor,
                     use_rgb: bool = False,
                     colormap: int = cv2.COLORMAP_JET,
                     image_weight: float = 0.5,
                     image_title: str = 'Original Image') -> plt.Figure:
    """Overlay a CAM mask on an image as a heatmap and return a three-panel Figure.

    OpenCV colormaps produce BGR images; by default the heatmap is left in BGR.

    :param img_tensor: Base image tensor in (C, H, W) layout with values in [0, 1].
    :param mask_tensor: CAM mask tensor; ``mask_tensor[0]`` is used as the 2-D mask,
        with values expected in [0, 1] (values outside are clipped).
    :param use_rgb: Convert the heatmap to RGB; set to True when the image is RGB
        (matplotlib always interprets arrays as RGB).
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The overlay is image_weight * img + (1 - image_weight) * heatmap.
    :param image_title: Title of the first panel.

    :return: Matplotlib Figure with three panels (image, CAM mask, overlay).
    :raises ValueError: if the image has values above 1 or image_weight is outside [0, 1].
    """
    img = img_tensor.cpu().numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    mask = mask_tensor.cpu().numpy()

    # Validate inputs before doing any colour-mapping work. ValueError subclasses
    # Exception, so callers catching the previous generic Exception still work.
    if np.max(img) > 1:
        raise ValueError("The input image should be in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    # Clip before the uint8 cast so values slightly outside [0, 1] saturate
    # instead of wrapping around (e.g. 256 -> 0).
    mask_single_channel = np.uint8(np.clip(255 * mask[0], 0, 255))

    heatmap = cv2.applyColorMap(mask_single_channel, colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Guard the rescale: with image_weight == 1 and an all-black image the maximum is 0.
    cam_max = np.max(cam)
    if cam_max > 0:
        cam = cam / cam_max

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img)
    axes[0].set_title(image_title)

    axes[1].imshow(mask[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    axes[2].imshow(cam)
    axes[2].set_title('Overlay')

    return fig

def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'):
    """Classify an image with the three-branch model and build Grad-CAM figures.

    The image is resized to 150x150, converted to 3-channel grayscale and split by
    ``split_white_and_gray(120)`` into white-matter, gray-matter and original inputs.
    A Grad-CAM map is computed for each input on the matching sub-model, and the
    ensemble ``model`` predicts the class.

    :param pil_image: PIL image accepted by the torchvision transforms.
    :param model: Object exposing ``resnet18_model``, ``whitematter_resnet18_model``
        and ``graymatter_resnet18_model`` (each with a ``layer4`` attribute), and
        callable as ``model(white, gray, original)``.
    :param target: Output index of each sub-model that Grad-CAM explains.
        NOTE(review): 100 only makes sense if the sub-model heads have more than
        100 outputs — confirm against how the sub-models are built.
    :param plot_type: 'withmask' (heatmap overlay) or 'withoutmask' (mask-weighted image).

    :return: (predicted class label, original figure, white-matter figure, gray-matter figure)
    :raises ValueError: if plot_type is not 'withmask' or 'withoutmask'.
    """
    # Resolve the plotting function first so an invalid plot_type fails fast,
    # before three Grad-CAM passes and a forward pass are spent on it.
    if plot_type == 'withmask':
        def plot(image, cam, **kwargs):
            # cv2.applyColorMap yields BGR, but matplotlib draws RGB; without the
            # conversion the overlay colours are swapped relative to the 'jet' mask panel.
            return showcam_withmask(image, cam, use_rgb=True, **kwargs)
    elif plot_type == 'withoutmask':
        plot = showcam_withoutmask
    else:
        raise ValueError("plot_type must be either 'withmask' or 'withoutmask'")

    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        split_white_and_gray(120),
    ])
    # Add a batch dimension and cast the float64 outputs back to float32 for the models.
    white_matter_tensor, gray_matter_tensor, origin_tensor = (
        t.unsqueeze(0).to(torch.float32) for t in transform(pil_image)
    )

    def calculate_gradcammask(model_grad, input_tensor):
        """Return the Grad-CAM map of ``model_grad`` for ``input_tensor`` as a tensor."""
        target_layers = [model_grad.layer4[-1]]
        targets = [ClassifierOutputTarget(target)]
        # The context manager releases the hooks GradCAM registers on the model,
        # so repeated calls do not leave hooks attached to the shared sub-models.
        with GradCAM(model=model_grad, target_layers=target_layers) as gradcam:
            grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets,
                                    aug_smooth=True, eigen_smooth=True)
        return torch.tensor(grayscale_cam)

    origin_cam = calculate_gradcammask(model.resnet18_model, origin_tensor)
    white_cam = calculate_gradcammask(model.whitematter_resnet18_model, white_matter_tensor)
    gray_cam = calculate_gradcammask(model.graymatter_resnet18_model, gray_matter_tensor)

    class_idx = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'}
    # Gradients are only needed for Grad-CAM above, not for the class prediction.
    with torch.no_grad():
        prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor)
    predicted_class_label = class_idx[torch.argmax(prediction).item()]

    return (predicted_class_label,
            plot(torch.squeeze(origin_tensor), origin_cam),
            plot(torch.squeeze(white_matter_tensor), white_cam, image_title='White Matter'),
            plot(torch.squeeze(gray_matter_tensor), gray_cam, image_title='Gray Matter'))