from functools import cached_property
from typing import Sequence

import torch

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class SecondaryStructureTokenizer(EsmTokenizerBase):
    """Tokenizer for secondary structure strings."""

    def __init__(self, kind: str = "ss8"):
        assert kind in ("ss8", "ss3")
        self.kind = kind

    @property
    def special_tokens(self) -> list[str]:
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self):
        """Tokenzier vocabulary list."""
        match self.kind:
            case "ss8":
                nonspecial_tokens = list(C.SSE_8CLASS_VOCAB)  # "GHITEBSC"
            case "ss3":
                nonspecial_tokens = list(C.SSE_3CLASS_VOCAB)  # HEC
            case _:
                raise ValueError(self.kind)

        # The non-special token ids match amino acid token ids when possible.
        return [*self.special_tokens, *nonspecial_tokens]

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]
        Returns:
            <bool>[length] tensor, true where special tokens are located in the input.
        """
        return tokens < len(self.special_tokens)

    def encode(
        self, sequence: str | Sequence[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encode secondary structure string

        Args:
            string: secondary structure string e.g. "GHHIT", or as token listk.
        Returns:
            <int>[sequence_length] token ids representing. Will add <cls>/<eos>.
        """
        ids = []
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # cls
        for char in sequence:
            ids.append(self.vocab_to_index[char])
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # eos
        return torch.tensor(ids, dtype=torch.int64)

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes token ids into secondary structure string.

        Args:
            encoded: <int>[length] token id array.
        Returns:
            Decoded secondary structure string.
        """
        return "".join(self.vocab[i] for i in encoded)

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
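

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): round-trips a
    # short ss8 string through the tokenizer. Assumes the esm package and its
    # constants are importable as above.
    tokenizer = SecondaryStructureTokenizer(kind="ss8")
    token_ids = tokenizer.encode("GHHIT")  # "<pad>" is prepended/appended as BOS/EOS
    print(token_ids)
    # Strip special-token positions before decoding back to the original string.
    mask = tokenizer.get_special_tokens_mask(token_ids)
    print(tokenizer.decode(token_ids[~mask]))  # -> "GHHIT"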