|
import torch |
|
import numpy as np |
|
from typing import List, Callable |
|
from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric |
|
|
|
|
|
def multiply_tensor_with_cam(input_tensor: torch.Tensor,
                             cam: torch.Tensor):
    """Scale an (already normalized) input tensor by a pixel attribution map.

    The two operands are combined with an element-wise product, so the usual
    torch broadcasting rules decide how ``cam`` is expanded over
    ``input_tensor``.

    NOTE(review): presumably ``input_tensor`` is (C, H, W) and ``cam`` is
    (H, W), so the map is broadcast over channels; confirm against the caller.

    Args:
        input_tensor: the normalized input to perturb.
        cam: the attribution map used as a multiplicative mask.

    Returns:
        The element-wise product ``input_tensor * cam``.
    """
    masked = input_tensor * cam
    return masked
|
|
|
|
|
class CamMultImageConfidenceChange(PerturbationConfidenceMetric):
    """Perturbation confidence metric that multiplies the input by the CAM.

    The parent ``PerturbationConfidenceMetric`` is constructed with
    ``multiply_tensor_with_cam`` as its perturbation function. Scoring is
    left entirely to the parent class.
    """

    def __init__(self):
        super().__init__(multiply_tensor_with_cam)
|
|
|
|
|
class DropInConfidence(CamMultImageConfidenceChange):
    """Report how much confidence drops after multiplying the input by the CAM.

    The parent's scores are negated and clamped at zero. Positive parent
    scores therefore become 0, and negative ones become their magnitude.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        # NOTE(review): assumes the parent returns an array of score changes
        # (new - old); a non-array return (e.g. a tuple) is not handled here.
        confidence_change = super().__call__(*args, **kwargs)
        return np.maximum(-confidence_change, 0)
|
|
|
|
|
class IncreaseInConfidence(CamMultImageConfidenceChange):
    """Flag whether confidence increases after multiplying the input by the CAM.

    Each parent score that is strictly positive maps to 1.0, and every other
    score maps to 0.0. The result is returned as float32.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        confidence_change = super().__call__(*args, **kwargs)
        increased = confidence_change > 0
        return np.float32(increased)
|
|