from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    """Custom text-generation pipeline for ChatNT.

    Loads an English tokenizer and a bio-sequence tokenizer from the
    InstaDeepAI/ChatNT repository and greedily decodes an English answer
    conditioned on the provided DNA sequences.
    """

    def __init__(self, model, **kwargs):  # type: ignore
        super().__init__(model=model, **kwargs)
        # Load the English and bio-sequence tokenizers from the model repository
        model_name = "InstaDeepAI/ChatNT"
        self.english_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="english_tokenizer"
        )
        self.bio_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict, dict, dict]:
        # Route call-time keyword arguments to the right pipeline stage.
        preprocess_kwargs: dict = {}
        forward_kwargs: dict = {}
        postprocess_kwargs: dict = {}

        if "max_num_tokens_to_decode" in kwargs:
            forward_kwargs["max_num_tokens_to_decode"] = kwargs[
                "max_num_tokens_to_decode"
            ]
        if "english_tokens_max_length" in kwargs:
            preprocess_kwargs["english_tokens_max_length"] = kwargs[
                "english_tokens_max_length"
            ]
        if "bio_tokens_max_length" in kwargs:
            preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        """Tokenize the English question and the DNA sequences.

        `inputs` must contain an "english_sequence" string and a
        "dna_sequences" list of strings.
        """
        english_sequence = inputs["english_sequence"]
        dna_sequences = inputs["dna_sequences"]

        # Wrap the user question in the ChatNT conversation template.
        context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "  # noqa
        space = "" if english_sequence.endswith(" ") else " "
        english_sequence = context + english_sequence + space + "ASSISTANT:"

        english_tokens = self.english_tokenizer(
            english_sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=english_tokens_max_length,
        ).input_ids
        bio_tokens = self.bio_tokenizer(
            dna_sequences,
            return_tensors="pt",
            padding="max_length",
            max_length=bio_tokens_max_length,
            truncation=True,
        ).input_ids.unsqueeze(0)

        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        """Greedily decode up to `max_num_tokens_to_decode` English tokens.

        Each predicted token is written into the first padding position of
        the English prompt; decoding stops early when no padding is left or
        when the <eos> token is predicted.
        """
        english_tokens = model_inputs["english_tokens"].clone()
        bio_tokens = model_inputs["bio_tokens"].clone()
        projected_bio_embeddings = None

        actual_num_steps = 0
        with torch.no_grad():
            for _ in range(max_num_tokens_to_decode):
                # Stop if the prompt has no padding positions left to fill
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break

                # Predictions
                outs = self.model(
                    multi_omics_tokens_ids=(english_tokens, bio_tokens),
                    projection_english_tokens_ids=english_tokens,
                    projected_bio_embeddings=projected_bio_embeddings,
                )
                projected_bio_embeddings = outs["projected_bio_embeddings"]
                logits = outs["logits"].detach().cpu().numpy()

                # Greedy pick: argmax of the logits at the last filled position
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])

                # If it's <eos> then stop, else add the predicted token
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                else:
                    english_tokens[0, first_idx_pad_token] = predicted_token
                    actual_num_steps += 1

            # Get the position where generation started
            idx_begin_generation = np.where(
                model_inputs["english_tokens"][0].cpu()
                == self.english_tokenizer.pad_token_id
            )[0][0]

            # Get generated tokens
            generated_tokens = english_tokens[
                0, idx_begin_generation : idx_begin_generation + actual_num_steps
            ]

        return {
            "generated_tokens": generated_tokens,
        }

    def postprocess(self, model_outputs: dict) -> str:
        """Decode the generated token ids back into an English string."""
        generated_tokens = model_outputs["generated_tokens"]
        generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
        return generated_sequence
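

# Illustrative usage sketch (not part of the pipeline itself): one way this
# pipeline could be driven after loading the ChatNT model with remote code
# enabled. The question text, DNA string, and keyword values below are
# placeholders for illustration, not values taken from the repository.
if __name__ == "__main__":
    from transformers import AutoModel

    chatnt_model = AutoModel.from_pretrained(
        "InstaDeepAI/ChatNT", trust_remote_code=True
    )
    pipe = TextGenerationPipeline(model=chatnt_model)
    answer = pipe(
        {
            "english_sequence": "Is there a promoter in this sequence <DNA> ?",
            "dna_sequences": ["ATTCCGATTCCGATTCCG" * 10],
        },
        max_num_tokens_to_decode=50,
    )
    print(answer)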