strexp / captum /attr /_utils /custom_modules.py
markytools's picture
added strexp
d61b9c7
raw
history blame
394 Bytes
#!/usr/bin/env python3
import torch.nn as nn
class Addition_Module(nn.Module):
    """Custom addition module for correct relevance propagation.

    The module takes both summands as explicit inputs so that relevance can
    be propagated correctly. Any addition in a forward function needs to be
    replaced with this module before using LRP.
    """

    def __init__(self) -> None:
        # No parameters or buffers of its own; only the base-class setup runs.
        super().__init__()

    def forward(self, x1, x2):
        """Return the sum of the two inputs.

        Args:
            x1: First summand.
            x2: Second summand.

        Returns:
            The result of ``x1 + x2``.
        """
        return x1 + x2