Yanisadel committed
Commit 45667ce · verified · 1 Parent(s): 36fb493

Upload text_generation.py

Files changed (1)
  1. text_generation.py +124 -0
text_generation.py ADDED
@@ -0,0 +1,124 @@
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, Pipeline
+
+
+ class TextGenerationPipeline(Pipeline):
+     def __init__(self, model, **kwargs):  # type: ignore
+         super().__init__(model=model, **kwargs)
+         # Load tokenizers
+         # TODO: Maybe do this in a better way (for now the easiest way was done)
+         model_name = "InstaDeepAI/ChatNT"
+         self.english_tokenizer = AutoTokenizer.from_pretrained(
+             model_name, subfolder="english_tokenizer"
+         )
+         self.bio_tokenizer = AutoTokenizer.from_pretrained(
+             model_name, subfolder="bio_tokenizer"
+         )
+
+     def _sanitize_parameters(self, **kwargs: dict) -> tuple[dict, dict, dict]:
+         preprocess_kwargs = {}
+         forward_kwargs = {}
+         postprocess_kwargs = {}  # type: ignore
+
+         if "max_num_tokens_to_decode" in kwargs:
+             forward_kwargs["max_num_tokens_to_decode"] = kwargs[
+                 "max_num_tokens_to_decode"
+             ]
+         if "english_tokens_max_length" in kwargs:
+             preprocess_kwargs["english_tokens_max_length"] = kwargs[
+                 "english_tokens_max_length"
+             ]
+         if "bio_tokens_max_length" in kwargs:
+             preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]
+
+         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+     def preprocess(
+         self,
+         inputs: dict,
+         english_tokens_max_length: int = 512,
+         bio_tokens_max_length: int = 512,
+     ) -> dict:
+         english_sequence = inputs["english_sequence"]
+         dna_sequences = inputs["dna_sequences"]
+
+         context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "  # noqa
+         space = " "
+         if english_sequence[-1] == " ":
+             space = ""
+         english_sequence = context + english_sequence + space + "ASSISTANT:"
+
+         english_tokens = self.english_tokenizer(
+             english_sequence,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=english_tokens_max_length,
+         ).input_ids
+         bio_tokens = self.bio_tokenizer(
+             dna_sequences,
+             return_tensors="pt",
+             padding="max_length",
+             max_length=bio_tokens_max_length,
+             truncation=True,
+         ).input_ids.unsqueeze(0)
+
+         return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
+
+     def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
+         english_tokens = model_inputs["english_tokens"].clone()
+         bio_tokens = model_inputs["bio_tokens"].clone()
+         projected_bio_embeddings = None
+
+         actual_num_steps = 0
+         with torch.no_grad():
+             for _ in range(max_num_tokens_to_decode):
+                 # Check if no more pad token id
+                 if (
+                     self.english_tokenizer.pad_token_id
+                     not in english_tokens[0].cpu().numpy()
+                 ):
+                     break
+
+                 # Predictions
+                 outs = self.model(
+                     multi_omics_tokens_ids=(english_tokens, bio_tokens),
+                     projection_english_tokens_ids=english_tokens,
+                     projected_bio_embeddings=projected_bio_embeddings,
+                 )
+                 projected_bio_embeddings = outs["projected_bio_embeddings"]
+                 logits = outs["logits"].detach().cpu().numpy()
+
+                 # Get predicted token
+                 first_idx_pad_token = np.where(
+                     english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
+                 )[0][0]
+                 predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])
+
+                 # If it's <eos> then stop, else add the predicted token
+                 if predicted_token == self.english_tokenizer.eos_token_id:
+                     break
+                 else:
+                     english_tokens[0, first_idx_pad_token] = predicted_token
+                     actual_num_steps += 1
+
+         # Get the position where generation started
+         idx_begin_generation = np.where(
+             model_inputs["english_tokens"][0].cpu()
+             == self.english_tokenizer.pad_token_id
+         )[0][0]
+
+         # Get generated tokens
+         generated_tokens = english_tokens[
+             0, idx_begin_generation : idx_begin_generation + actual_num_steps
+         ]
+
+         return {
+             "generated_tokens": generated_tokens,
+         }
+
+     def postprocess(self, model_outputs: dict) -> str:
+         generated_tokens = model_outputs["generated_tokens"]
+         generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
+         return generated_sequence
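
For context, a minimal usage sketch of the pipeline added in this commit (not part of the committed file). It assumes the ChatNT checkpoint is loaded with trust_remote_code=True, that the uploaded module is importable as text_generation, and that the question text and DNA sequence below are purely illustrative placeholders.

from transformers import AutoModel

from text_generation import TextGenerationPipeline  # the class uploaded in this commit

# Assumption: the custom ChatNT architecture is exposed via trust_remote_code.
model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)

# The pipeline loads its own English and bio tokenizers in __init__.
pipeline = TextGenerationPipeline(model=model)

# "english_sequence" carries the question, "dna_sequences" the DNA inputs it refers to.
# The question wording and the sequence are illustrative, not taken from the repo.
output = pipeline(
    {
        "english_sequence": "Does this sequence contain a promoter region?",
        "dna_sequences": ["ATGCGTACCTGA" * 10],
    },
    max_num_tokens_to_decode=50,  # routed to _forward via _sanitize_parameters
)
print(output)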