"""
Defines the dictionary classes
"""

import torch
import torch.nn as nn
from tensordict import TensorDict


class SparseAutoEncoder(nn.Module):
    """
    A 2-layer sparse autoencoder.
    """

    def __init__(
        self,
        activation_dim,
        dict_size,
        pre_bias=False,
        init_normalise_dict=None,
    ):
        """
        activation_dim : dimension of the activation vectors being autoencoded
        dict_size : number of dictionary features
        pre_bias : if True, subtract the decoder bias from the input before encoding
        init_normalise_dict : how to normalise the decoder rows at initialisation
            (None, "l2", or "less_than_1")
        """
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.pre_bias = pre_bias
        self.init_normalise_dict = init_normalise_dict

        self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
        self.relu = nn.ReLU()

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.dict_size,
                    self.activation_dim,
                )
            )
        )
        if init_normalise_dict == "l2":
            self.normalize_dict_(less_than_1=False)
            self.W_dec *= 0.1
        elif init_normalise_dict == "less_than_1":
            self.normalize_dict_(less_than_1=True)

        # Initialise the encoder as the transpose of the decoder; clone() so the
        # two weight matrices do not share storage and can be trained independently.
        self.W_enc = nn.Parameter(self.W_dec.t().clone())
        self.b_dec = nn.Parameter(
            torch.zeros(
                self.activation_dim,
            )
        )

    @torch.no_grad()
    def normalize_dict_(
        self,
        less_than_1=False,
    ):
        """
        Normalise the decoder rows in place. If less_than_1 is True, only rows
        with norm greater than 1 are rescaled to unit norm; otherwise every
        non-zero row is rescaled to unit norm.
        """
        norm = self.W_dec.norm(dim=1)
        positive_mask = norm != 0
        if less_than_1:
            greater_than_1_mask = (norm > 1) & positive_mask
            self.W_dec[greater_than_1_mask] /= norm[greater_than_1_mask].unsqueeze(1)
        else:
            self.W_dec[positive_mask] /= norm[positive_mask].unsqueeze(1)

    def encode(self, x):
        return x @ self.W_enc + self.b_enc

    def decode(self, f):
        return f @ self.W_dec + self.b_dec

    def forward(self, x, output_features=False, ghost_mask=None):
        """
        Forward pass of the autoencoder.
        x : activations to be autoencoded
        output_features : if True, return the encoded features as well
            as the decoded x
        ghost_mask : if not None, also run this autoencoder in "ghost mode",
            decoding exp(pre-activations) for the masked features and returning
            the result as "x_ghost"
        """
        if self.pre_bias:
            x = x - self.b_dec
        f_pre = self.encode(x)
        out = TensorDict({}, batch_size=x.shape[0])
        if ghost_mask is not None:
            f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
            x_ghost = f_ghost @ self.W_dec
            out["x_ghost"] = x_ghost
        f = self.relu(f_pre)
        if output_features:
            out["features"] = f
        x_hat = self.decode(f)
        out["x_hat"] = x_hat
        return out
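

# Minimal usage sketch (illustrative only, not part of the original module):
# instantiate the autoencoder on a fake batch of activations and inspect the
# reconstruction, features, and ghost-mode output. The dimensions below are
# arbitrary assumptions chosen for the example.
if __name__ == "__main__":
    activation_dim, dict_size = 512, 4096
    sae = SparseAutoEncoder(activation_dim, dict_size, init_normalise_dict="l2")

    x = torch.randn(8, activation_dim)  # fake batch of 8 activation vectors
    out = sae(x, output_features=True)
    print(out["x_hat"].shape)     # (8, activation_dim) reconstruction
    print(out["features"].shape)  # (8, dict_size) sparse feature activations

    # Ghost mode: decode exp(pre-activations) for a masked subset of features.
    ghost_mask = torch.zeros(dict_size, dtype=torch.bool)
    ghost_mask[:16] = True
    out_ghost = sae(x, ghost_mask=ghost_mask)
    print(out_ghost["x_ghost"].shape)  # (8, activation_dim)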