Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
from abc import ABC, abstractmethod | |
from typing import Any | |
from torch.nn import Module | |
from torch.utils.data import Dataset | |
class DataInfluence(ABC): | |
r""" | |
An abstract class to define model data influence skeleton. | |
""" | |
def __init_( | |
self, model: Module, influence_src_dataset: Dataset, **kwargs: Any | |
) -> None: | |
r""" | |
Args: | |
model (torch.nn.Module): An instance of pytorch model. | |
influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is | |
used to create a PyTorch Dataloader to iterate over the dataset and | |
its labels. This is the dataset for which we will be seeking for | |
influential instances. In most cases this is the training dataset. | |
**kwargs: Additional key-value arguments that are necessary for specific | |
implementation of `DataInfluence` abstract class. | |
""" | |
self.model = model | |
self.influence_src_dataset = influence_src_dataset | |
def influence(self, inputs: Any = None, **kwargs: Any) -> Any: | |
r""" | |
Args: | |
inputs (Any): Batch of examples for which influential | |
instances are computed. They are passed to the forward_func. If | |
`inputs` if a tensor or tuple of tensors, the first dimension | |
of a tensor corresponds to the batch dimension. | |
**kwargs: Additional key-value arguments that are necessary for specific | |
implementation of `DataInfluence` abstract class. | |
Returns: | |
influences (Any): We do not add restrictions on the return type for now, | |
though this may change in the future. | |
""" | |
pass | |