Update README.md

README.md CHANGED

@@ -160,19 +160,39 @@ ancestrally-reconstructed sets, or after searching against metagenomics database
 as it will learn new properties from your dataset and potentially improve the generation quality
 (especially for poorly populated EC classes).
 
-To fine-tune ZymCTRL, you
-
+To fine-tune ZymCTRL, you can use the script below to process your sequences. The only requirement is an input file,
+'sequences.fasta', which contains all the sequences in FASTA format. Please follow the format below: each sequence must sit on a single line,
+with no line breaks '\n' or other separators inside it. In the script, change the variable ec_label to the specific BRENDA class you'd like to fine-tune on.
+The script will produce a file called {ec_label}_processed.txt and a folder with the training and validation datasets (10% validation split).
+```
+>Sequence1
+MMMMYMPLKVCD..
+>Sequence2
+MQWMXMYMPLKVCD..
+>Sequence3
+MPLKVCWMXMYMPLD..
+```
 We recommend using at least 200 sequences to obtain the best results, but we've seen it work with fewer sequences, so if you don't have
 that many, still give it a go.
 
 
 ```
 import random
-import transformers
 from transformers import AutoTokenizer
 
-
-
+from datasets import load_dataset
+import transformers
+from transformers.testing_utils import CaptureLogger
+
+## DEFINE THESE VARIABLES
+tokenizer = AutoTokenizer.from_pretrained('AI4PD/ZymCTRL')
+ec_label = '1.1.1.1' # CHANGE TO YOUR LABEL
+validation_split_percentage = 10 # change if you want
+sequence_file = 'sequences.fasta'
+
+
+# Load sequences: read the source FASTA file
+with open(sequence_file, 'r') as fn:
     data = fn.readlines()
     fn.close()
 
@@ -181,49 +201,39 @@ sequences={}
 for line in data:
     if '>' in line:
         name = line.strip()
-        sequences[name] = [
+        sequences[name] = []
         continue
     sequences[name].append(line.strip())
 
-#
-
-for name, sequence in sequences.items():
-    processed_sequences[f"{sequence[0]};{name}"] = ''.join([x for x in sequence[1:]])
-
-# Shuffle sequences
-sequences_list = [(key,value) for key,value in processed_sequences.items()]
+# Pass sequences to a list and shuffle their order randomly
+sequences_list = [(key,value[0]) for key,value in sequences.items()]
 random.shuffle(sequences_list)
 
-#
-
-
-# the objective is to get here strings, that when tokenized, will span a window length of 1024.
-# for each sequence group its length and untokenized string
-
+# The objective is to build strings that, once tokenized, span a window of 1024 tokens.
+# For each sequence, store its length together with its untokenized string.
 print("processing dataset")
 processed_dataset = []
 for i in sequences_list:
     # length of the control code
-    label = i[0].split(';')[0]
     sequence = i[1].strip()
     separator = '<sep>'
-    control_code_length = len(tokenizer(
+    control_code_length = len(tokenizer(ec_label+separator)['input_ids'])
     available_space = 1021 - control_code_length # It is not 1024 because '<|endoftext|>', and start and end
 
-    # Option 1: the sequence is larger than the available space (3-4% of sequences
+    # Option 1: the sequence is larger than the available space (3-4% of sequences)
     if len(sequence) > available_space:
         total_length = control_code_length + len(sequence[:available_space]) + 1
-        seq = f"{
+        seq = f"{ec_label}{separator}{sequence[:available_space]}<|endoftext|>"
         processed_dataset.append((total_length, seq))
 
     # Option 2 & 3: The sequence fits in the block_size space with or without padding
     else:
         total_length = control_code_length + len(sequence) + 3
         # in this case the sequence does not fit with the start/end tokens
-        seq = f"{
+        seq = f"{ec_label}{separator}<start>{sequence}<end><|endoftext|>"
         processed_dataset.append((total_length, seq))
 
-#
+# Group sequences into 1024-token blocks
 def grouper(iterable):
     prev = None
     group = ''
@@ -241,50 +251,30 @@ def grouper(iterable):
         total_sum = 0
     yield group
 
-# Group sequences
 print("grouping processed dataset")
 grouped_dataset=dict(enumerate(grouper(processed_dataset),1))
 
-#
-fn = open("
-for key,value in grouped_dataset.items():
-    fn.write(value)
-    fn.write("\n")
-fn.close()
-
-fn = open("./2.7.3.13_processed.txt",'w')
+# Write the grouped blocks out to a file for the tokenizer to read
+fn = open(f"{ec_label}_processed.txt",'w')
 for key,value in grouped_dataset.items():
     padding_len = 1024 - len(tokenizer(value)['input_ids'])
     padding = "<pad>"*padding_len
-    print(len(tokenizer(value+padding)['input_ids']))
     fn.write(value+padding)
     fn.write("\n")
-fn.close()
-```
-The previous script will prepare a text file with the correct format for tokenization.
-Now we can use the tokenizer to convert its contents to tokens.
-
-```
-from datasets import load_dataset
-import transformers
-from transformers.testing_utils import CaptureLogger
-
-# Load the tokenizer again
-from transformers import AutoTokenizer
-tokenizer = AutoTokenizer.from_pretrained('/agh/projects/noelia/NLP/zymCTRL/dataset_preparation/tokenizer')
-
 
-
+fn.close()
+
+## TOKENIZE
+# adapted from the HuggingFace run_clm.py example script
 data_files = {}
 dataset_args = {}
-
-data_files["train"] =
+
+data_files["train"] = f"{ec_label}_processed.txt"
 extension = "text"
-raw_datasets = load_dataset(extension, data_files=data_files, cache_dir='.', **dataset_args)
 tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
 
-
+raw_datasets = load_dataset(extension, data_files=data_files, cache_dir='.', **dataset_args)
+
 raw_datasets["train"] = load_dataset(extension,
     data_files=data_files,
     split=f"train[{validation_split_percentage}%:]",
@@ -298,7 +288,6 @@ raw_datasets["validation"] = load_dataset(extension,
     **dataset_args,)
 
 def tokenize_function(examples):
-    " This function tokenizes input"
     with CaptureLogger(tok_logger) as cl:
         output = tokenizer(examples["text"])
         # clm input could be much much longer than block_size
@@ -308,7 +297,6 @@ def tokenize_function(examples):
         )
     return output
 
-# tokenize in parallel
 tokenized_datasets = raw_datasets.map(
     tokenize_function,
     batched=True,
@@ -318,24 +306,6 @@ tokenized_datasets = raw_datasets.map(
     desc="Running tokenizer on dataset",
 )
 
-train_dataset = tokenized_datasets["train"]
-eval_dataset = tokenized_datasets["validation"]
-
-train_dataset.save_to_disk('./dataset/train')
-eval_dataset.save_to_disk('./dataset/eval')
-
-# This has saved the datasets tokenized. Now we need to group them into the block size of 1024
-from datasets import load_from_disk
-
-train_dataset = load_from_disk('./2.7.3.13/dataset/train')
-eval_dataset = load_from_disk('./2.7.3.13/dataset/eval')
-
-from datasets.dataset_dict import DatasetDict
-tokenized_datasets = DatasetDict()
-
-tokenized_datasets["train"] = train_dataset
-tokenized_datasets["validation"] = eval_dataset
-
 block_size = 1024
 def group_texts(examples):
     # Concatenate all texts.
@@ -364,8 +334,10 @@ lm_datasets = tokenized_datasets.map(
 train_dataset = lm_datasets["train"]
 eval_dataset = lm_datasets["validation"]
 
-train_dataset.save_to_disk('./dataset/
-eval_dataset.save_to_disk('./dataset/
+train_dataset.save_to_disk('./dataset/train')
+eval_dataset.save_to_disk('./dataset/eval')
+
+
 ```
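As a quick sanity check (this snippet is not part of the script above, and the file name is a placeholder for your own {ec_label}_processed.txt), you can confirm that every line written to the processed file spans the full 1024-token window:
```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('AI4PD/ZymCTRL')

# Placeholder path: replace with your own {ec_label}_processed.txt
with open('1.1.1.1_processed.txt') as fh:
    lengths = [len(tokenizer(line.strip())['input_ids']) for line in fh if line.strip()]

# Every block was padded up to the 1024-token window, so min and max should both be close to 1024
print(len(lengths), min(lengths), max(lengths))
```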
The processed datasets will be saved inside the dataset/ folder, as train and eval.
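They can be loaded back with `datasets.load_from_disk` to check that the block structure looks right (a minimal sketch, assuming the paths used above):
```
from datasets import load_from_disk

train_dataset = load_from_disk('./dataset/train')
eval_dataset = load_from_disk('./dataset/eval')

# Each example should be a 1024-token block (e.g. an 'input_ids' entry of length 1024)
print(train_dataset)
print(len(train_dataset[0]['input_ids']))
```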
Since the processing and tokenization steps are combined in the single script above, you can run everything in one go (that is what we do).
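From there, the saved train and eval datasets can be fed to any causal-language-modelling training loop. As a minimal sketch (this is not the repository's own training setup; the collator choice and all hyperparameters below are placeholders to adapt to your data and hardware), the standard `transformers` Trainer works directly on the saved datasets:
```
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained('AI4PD/ZymCTRL')
model = AutoModelForCausalLM.from_pretrained('AI4PD/ZymCTRL')

if tokenizer.pad_token is None:
    tokenizer.pad_token = '<pad>'  # '<pad>' is already part of the ZymCTRL vocabulary

train_dataset = load_from_disk('./dataset/train')
eval_dataset = load_from_disk('./dataset/eval')

# mlm=False builds causal-LM labels from input_ids and ignores the <pad> positions in the loss
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./zymctrl-finetuned',  # placeholder output directory
    per_device_train_batch_size=1,     # 1024-token blocks are large; adjust to your GPU
    num_train_epochs=3,                # placeholder; tune for your dataset size
    learning_rate=1e-5,                # placeholder; tune for your dataset
    save_steps=500,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

trainer.train()
trainer.save_model('./zymctrl-finetuned')
```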