markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
import os
from typing import Any, Dict, List, Optional

import torch
from captum.concept._core.concept import Concept
from captum.concept._utils.common import concepts_to_str
class CAV:
    r"""
    Concept Activation Vector (CAV) is a vector orthogonal to the decision
    boundary of a classifier which distinguishes between activation
    vectors produced by different concepts.
    More details can be found in the paper:
    https://arxiv.org/pdf/1711.11279.pdf
    """

    def __init__(
        self,
        concepts: List[Concept],
        layer: str,
        stats: Optional[Dict[str, Any]] = None,
        save_path: str = "./cav/",
        model_id: str = "default_model_id",
    ) -> None:
        r"""
        This class encapsulates the instances of CAVs objects, saves them in
        and loads them from the disk (storage).

        Args:
            concepts (list[Concept]): a List of Concept objects. Only their
                        names and ids will be saved and loaded.
            layer (str): The layer where concept activation vectors are
                        computed using a predefined classifier.
            stats (dict, optional): a dictionary that retains information about
                        the CAV classifier such as CAV weights and accuracies.
                        Ex.: stats = {"weights": weights, "classes": classes,
                                      "accs": accs}, where "weights" are learned
                        model parameters, "classes" are a list of classes used
                        by the model to generate the "weights" and "accs"
                        the classifier training or validation accuracy.
            save_path (str, optional): The path where the CAV objects are stored.
            model_id (str, optional): A unique model identifier associated with
                        this CAV instance.
        """
        self.concepts = concepts
        self.layer = layer
        self.stats = stats
        self.save_path = save_path
        self.model_id = model_id

    @staticmethod
    def assemble_save_path(
        path: str, model_id: str, concepts: List[Concept], layer: str
    ) -> str:
        r"""
        A utility method for assembling filename and its path, from
        a concept list and a layer name.

        Args:
            path (str): A path to be concatenated with the concepts key and
                        layer name.
            model_id (str): A unique model identifier associated with input
                        `layer` and `concepts`
            concepts (list(Concept)): A list of concepts that are concatenated
                        together and used as a concept key using their ids. These
                        concept ids are retrieved from TCAV's `Concept` objects.
            layer (str): The name of the layer for which the activations are
                        computed.

        Returns:
            cav_path(str): A string containing the path where the computed CAVs
                        will be stored.
                        For example, given:
                            concept_ids = [0, 1, 2]
                            concept_names = ["striped", "random_0", "random_1"]
                            layer = "inception4c"
                            path = "/cavs",
                        the resulting save path will be:
                            "/cavs/default_model_id/0-1-2-inception4c.pkl"
        """
        file_name = concepts_to_str(concepts) + "-" + layer + ".pkl"
        return os.path.join(path, model_id, file_name)

    def save(self) -> None:
        r"""
        Saves a dictionary of the CAV computed values into a pickle file in the
        location returned by the "assemble_save_path" static methods. The
        dictionary contains the concept names list, the layer name for which
        the activations are computed for, the stats dictionary which contains
        information about the classifier train/eval statistics such as the
        weights and training accuracies. Ex.:

        save_dict = {
            "concept_ids": [0, 1, 2],
            "concept_names": ["striped", "random_0", "random_1"],
            "layer": "inception4c",
            "stats": {"weights": weights, "classes": classes, "accs": accs}
        }

        The target directory (`save_path`/`model_id`) is created first if it
        does not exist yet.
        """
        save_dict = {
            "concept_ids": [c.id for c in self.concepts],
            "concept_names": [c.name for c in self.concepts],
            "layer": self.layer,
            "stats": self.stats,
        }

        # torch.save does not create missing parent directories, so make sure
        # "<save_path>/<model_id>" exists before writing.
        CAV.create_cav_dir_if_missing(self.save_path, self.model_id)

        cavs_path = CAV.assemble_save_path(
            self.save_path, self.model_id, self.concepts, self.layer
        )
        torch.save(save_dict, cavs_path)

    @staticmethod
    def create_cav_dir_if_missing(save_path: str, model_id: str) -> None:
        r"""
        A utility function for creating the directories where the CAVs will
        be stored. CAVs are saved in a folder named by `model_id`
        under `save_path`.

        Args:
            save_path (str): A root path where the CAVs will be stored
            model_id (str): A unique model identifier associated with the
                        CAVs. A folder named `model_id` is created under
                        `save_path`. The CAVs are later stored there.
        """
        cav_model_id_path = os.path.join(save_path, model_id)
        # exist_ok=True avoids the check-then-create race of a separate
        # os.path.exists() test when several processes share the directory.
        os.makedirs(cav_model_id_path, exist_ok=True)

    @staticmethod
    def load(
        cavs_path: str, model_id: str, concepts: List[Concept], layer: str
    ) -> Optional["CAV"]:
        r"""
        Loads CAV dictionary from a pickle file for given input
        `layer` and `concepts`.

        Args:
            cavs_path (str): The root path where the cavs are stored
                        in the storage (on the disk).
                        Ex.: "/cavs"
            model_id (str): A unique model identifier associated with the
                        CAVs. There exists a folder named `model_id` under
                        `cavs_path` path. The CAVs are loaded from this folder.
            concepts (list[Concept]): A List of concepts for which
                        we would like to load the cavs.
            layer (str): The layer name. Ex.: "inception4c". In case of nested
                        layers we use dots to specify the depth / hierarchy.
                        Ex.: "layer.sublayer.subsublayer"

        Returns:
            cav(CAV): An instance of a CAV class, containing the respective CAV
                        score per concept and layer, or None when no file
                        exists at the assembled path. An example of a path
                        where the cavs are loaded from is:
                        "/cavs/default_model_id/0-1-2-inception4c.pkl"
                        The returned instance keeps `cavs_path` as its
                        `save_path` and the given `model_id`, so calling
                        `save()` on it writes back to the same location.
        """
        cav_file = CAV.assemble_save_path(cavs_path, model_id, concepts, layer)

        if not os.path.exists(cav_file):
            return None

        # NOTE(security): torch.load deserializes with pickle; only load CAV
        # files that come from a trusted location.
        save_dict = torch.load(cav_file)

        # Only ids and names are persisted by save(), so the third Concept
        # argument is passed as None.
        loaded_concepts = [
            Concept(concept_id, concept_name, None)
            for concept_id, concept_name in zip(
                save_dict["concept_ids"], save_dict["concept_names"]
            )
        ]

        # Carry the storage location over to the loaded instance instead of
        # silently falling back to the constructor defaults.
        return CAV(
            loaded_concepts,
            save_dict["layer"],
            save_dict["stats"],
            save_path=cavs_path,
            model_id=model_id,
        )