markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
from abc import ABC, abstractmethod
from typing import Any
from torch.nn import Module
from torch.utils.data import Dataset
class DataInfluence(ABC):
r"""
An abstract class to define model data influence skeleton.
"""
def __init_(
self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
) -> None:
r"""
Args:
model (torch.nn.Module): An instance of pytorch model.
influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
used to create a PyTorch Dataloader to iterate over the dataset and
its labels. This is the dataset for which we will be seeking for
influential instances. In most cases this is the training dataset.
**kwargs: Additional key-value arguments that are necessary for specific
implementation of `DataInfluence` abstract class.
"""
self.model = model
self.influence_src_dataset = influence_src_dataset
@abstractmethod
def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
r"""
Args:
inputs (Any): Batch of examples for which influential
instances are computed. They are passed to the forward_func. If
`inputs` if a tensor or tuple of tensors, the first dimension
of a tensor corresponds to the batch dimension.
**kwargs: Additional key-value arguments that are necessary for specific
implementation of `DataInfluence` abstract class.
Returns:
influences (Any): We do not add restrictions on the return type for now,
though this may change in the future.
"""
pass