#!/usr/bin/env python3 import torch.nn as nn class Addition_Module(nn.Module): """Custom addition module that uses multiple inputs to assure correct relevance propagation. Any addition in a forward function needs to be replaced with the module before using LRP.""" def __init__(self) -> None: super().__init__() def forward(self, x1, x2): return x1 + x2