File size: 1,814 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python3

from abc import ABC, abstractmethod
from typing import Any

from torch.nn import Module
from torch.utils.data import Dataset


class DataInfluence(ABC):
    r"""
    An abstract class that defines the skeleton of a model data influence
    method.

    Concrete subclasses must implement :meth:`influence`; the base constructor
    stores the model and the source dataset on the instance.
    """

    def __init__(
        self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
    ) -> None:
        r"""
        Args:
            model (torch.nn.Module): An instance of pytorch model.
            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                    used to create a PyTorch Dataloader to iterate over the dataset and
                    its labels. This is the dataset in which we search for
                    influential instances. In most cases this is the training dataset.
            **kwargs: Additional key-value arguments that are necessary for specific
                    implementation of `DataInfluence` abstract class. The base
                    constructor accepts them but neither uses nor stores them.
        """
        # This must be spelled exactly ``__init__``. A misspelling such as
        # ``__init_`` is silently treated as an ordinary (name-mangled) method,
        # leaving object.__init__ in place, which rejects these arguments.
        self.model = model
        self.influence_src_dataset = influence_src_dataset

    @abstractmethod
    def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
        r"""
        Args:
            inputs (Any): Batch of examples for which influential
                    instances are computed. They are passed to the forward_func. If
                    `inputs` is a tensor or tuple of tensors, the first dimension
                    of a tensor corresponds to the batch dimension.
            **kwargs: Additional key-value arguments that are necessary for specific
                    implementation of `DataInfluence` abstract class.

        Returns:
            influences (Any): We do not add restrictions on the return type for now,
                    though this may change in the future.
        """
        pass