File size: 4,187 Bytes
224a33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from functools import cached_property

import torch

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class SASADiscretizingTokenizer(EsmTokenizerBase):
    """Tokenizer for Solvent Accessible Surface Area (SASA).

    Continuous SASA values are discretized into half-open buckets delimited
    by ``boundaries``; each bucket is one token of the form ``<low-high>``.
    The first three vocabulary slots are the special tokens, so bucket ``i``
    maps to token id ``len(special_tokens) + i``.
    """

    def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
        # sorted() also copies the default list, so the shared module-level
        # constant is never aliased or mutated by this instance.
        self._boundaries = sorted(boundaries)

    @cached_property
    def _boundaries_tensor(self) -> torch.Tensor:
        # Cached so encode() does not rebuild this tensor for every value.
        return torch.tensor(self._boundaries)

    @cached_property
    def special_tokens(self) -> list[str]:
        """Special tokens, occupying ids [0, 3) of the vocabulary."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Discrete token vocabulary.

        Returns:
            token vocabulary with ranges represented as "<low-high>".
        """
        # Boundary labels: "0" and "inf" close the open-ended outer buckets.
        boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
        range_tokens = [
            f"<{low}-{high}>"
            for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
        ]
        return self.special_tokens + range_tokens

    @cached_property
    def midpoints(self) -> list[float]:
        """Midpoints of the SASA token ranges, indexed by token id.

        The three leading entries are NaN placeholders for the special
        tokens. The unbounded top bucket is given a synthetic upper edge
        at twice the last boundary so its midpoint is finite.
        NOTE(review): raises IndexError if ``boundaries`` is empty — the
        default constant is non-empty, so presumably that never happens.
        """
        boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
        midpoint_tokens = [
            (float(high) + float(low)) / 2
            for low, high in zip(boundaries[:-1], boundaries[1:])
        ]
        # NaN for the special tokens so decode_float marks them as non-values.
        midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
        return midpoint_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]
        Returns:
            <bool>[length] tensor, true where special tokens are located in the input.
        """
        # Special tokens are exactly the ids below len(special_tokens).
        return tokens < len(self.special_tokens)

    def encode(
        self, values: list[float | str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes SASA values as discrete tokens.

        Args:
            values: list of either SASA values or individual tokens. For example
                [1.2, "<pad>", 10.3, "<pad>", 0.]
            add_special_tokens: if True, wrap the sequence in BOS and EOS
                tokens (both represented by "<pad>" in this vocabulary).
        Returns:
            <int64>[len(values) (+2 if add_special_tokens)] token id tensor.
        Raises:
            TypeError: if a value is neither a number nor a string.
            KeyError: if a string value is not in the vocabulary.
        """
        ids: list[int] = []
        if add_special_tokens:
            ids.append(self.bos_token_id)  # BOS
        # Hoisted loop invariant: build the boundary tensor once, not per value.
        boundaries = self._boundaries_tensor
        for value in values:
            if isinstance(value, (float, int)):
                # bucketize returns a 0-dim tensor; convert to a plain int so
                # ids stays a homogeneous list for torch.tensor below.
                bucket = int(torch.bucketize(value, boundaries))
                ids.append(len(self.special_tokens) + bucket)
            elif isinstance(value, str):
                ids.append(self.vocab_to_index[value])
            else:
                raise TypeError(value)
        if add_special_tokens:
            ids.append(self.eos_token_id)  # EOS

        return torch.tensor(ids, dtype=torch.int64)

    def decode_float(self, encoded: torch.Tensor) -> list[float]:
        """Decodes SASA token ids into float values (NaN for special tokens)."""
        return [self.midpoints[token_id] for token_id in encoded]

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes SASA token ids into a comma-joined token string."""
        return ",".join(self.vocab[i] for i in encoded)

    def decode_list(self, encoded: torch.Tensor) -> list[str]:
        """Decodes SASA token ids into a list of token strings."""
        return [self.vocab[i] for i in encoded]

    @property
    def mask_token(self) -> str:
        # "<pad>" doubles as the mask/BOS/EOS/pad token in this vocabulary.
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]