nferruz committed on
Commit
045ef38
1 Parent(s): 18b580e

Update README.md

Files changed (1): README.md (+48 −76)
README.md CHANGED
@@ -160,19 +160,39 @@ ancestrally-reconstructed sets, or after searching against metagenomics database
  as it will learn new properties from your dataset and potentially improve the generation quality
  (especially for poorly populated EC classes).

- To fine-tune ZymCTRL, you will need to process your sequences quite a bit. The scripts below can exactly do that without any
- modifications. The only requisite is to start with an input file, 'sequences.fasta' which contains all the sequences in a fasta format.

  We recommend using at least 200 sequences to obtain the best results. But we've seen it work with fewer sequences, so if you don't have
  that many, still give it a go.

  ```
  import random
- import transformers
  from transformers import AutoTokenizer

- # 1. Read the source file
- with open('sequences.fasta', 'r') as fn:
      data = fn.readlines()
  fn.close()

@@ -181,49 +201,39 @@ sequences={}
  for line in data:
      if '>' in line:
          name = line.strip()
-         sequences[name] = ['2.7.3.12'] # modify with the actual EC class.
          continue
      sequences[name].append(line.strip())

- # Process fasta files to be in single string - run this part only if the fastas were formated to 60 characters
- processed_sequences = {}
- for name, sequence in sequences.items():
-     processed_sequences[f"{sequence[0]};{name}"] = ''.join([x for x in sequence[1:]])
-
- # Shuffle sequences
- sequences_list = [(key,value) for key,value in processed_sequences.items()]
  random.shuffle(sequences_list)

- # Load tokenizer
- tokenizer = AutoTokenizer.from_pretrained('/path/to/ZymCTRL')
-
- # the objective is to get here strings, that when tokenized, will span a window length of 1024.
- # for each sequence group its length and untokenized string
-
  print("processing dataset")
  processed_dataset = []
  for i in sequences_list:
      # length of the control code
-     label = i[0].split(';')[0]
      sequence = i[1].strip()
      separator = '<sep>'
-     control_code_length = len(tokenizer(label+separator)['input_ids'])
      available_space = 1021 - control_code_length # It is not 1024 because '<|endoftext|>', and start and end

-     # Option 1: the sequence is larger than the available space (3-4% of sequences in BRENDA are over 1024)
      if len(sequence) > available_space:
          total_length = control_code_length + len(sequence[:available_space]) + 1
-         seq = f"{label}{separator}{sequence[:available_space]}<|endoftext|>"
          processed_dataset.append((total_length, seq))

      # Option 2 & 3: the sequence fits in the block_size space, with or without padding
      else:
          total_length = control_code_length + len(sequence) + 3
          # here the sequence fits, so the start/end tokens are included
-         seq = f"{label}{separator}<start>{sequence}<end><|endoftext|>"
          processed_dataset.append((total_length, seq))

- # Helper function to group sequences
  def grouper(iterable):
      prev = None
      group = ''
@@ -241,50 +251,30 @@ def grouper(iterable):
      total_sum = 0
      yield group

- # Group sequences
  print("grouping processed dataset")
  grouped_dataset=dict(enumerate(grouper(processed_dataset),1))

- # Save the processed file out
- fn = open("./2.7.3.13_processed.txt",'w')
- for key,value in grouped_dataset.items():
-     fn.write(value)
-     fn.write("\n")
- fn.close()
-
- fn = open("./2.7.3.13_processed.txt",'w')
  for key,value in grouped_dataset.items():
      padding_len = 1024 - len(tokenizer(value)['input_ids'])
      padding = "<pad>"*padding_len
-     print(len(tokenizer(value+padding)['input_ids']))
      fn.write(value+padding)
      fn.write("\n")
- fn.close()
- ```
- The previous script will prepare a text file with the correct format for tokenization.
- Now we can use the tokenizer to convert its contents to tokens.
-
- ```
- from datasets import load_dataset
- import transformers
- from transformers.testing_utils import CaptureLogger
-
- # Load the tokenizer again
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained('/agh/projects/noelia/NLP/zymCTRL/dataset_preparation/tokenizer')
-

- #Load the data files
  data_files = {}
  dataset_args = {}
- validation_split_percentage = 10 # for a split 90/10
- data_files["train"] = './2.7.3.12_processed.txt'
  extension = "text"
- raw_datasets = load_dataset(extension, data_files=data_files, cache_dir='.', **dataset_args)
  tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

- # Load datasets using the HF datasets library:
  raw_datasets["train"] = load_dataset(extension,
      data_files=data_files,
      split=f"train[{validation_split_percentage}%:]",
@@ -298,7 +288,6 @@ raw_datasets["validation"] = load_dataset(extension,
      **dataset_args,)

  def tokenize_function(examples):
-     " This function tokenizes input"
      with CaptureLogger(tok_logger) as cl:
          output = tokenizer(examples["text"])
      # clm input could be much much longer than block_size
@@ -308,7 +297,6 @@ def tokenize_function(examples):
          )
      return output

- # tokenize in parallel
  tokenized_datasets = raw_datasets.map(
      tokenize_function,
      batched=True,
@@ -318,24 +306,6 @@ tokenized_datasets = raw_datasets.map(
      desc="Running tokenizer on dataset",
  )

- train_dataset = tokenized_datasets["train"]
- eval_dataset = tokenized_datasets["validation"]
-
- train_dataset.save_to_disk('./dataset/train')
- eval_dataset.save_to_disk('./dataset/eval')
-
- # This has saved the datasets tokenized. Now we need to group them into the block size of 1024
- from datasets import load_from_disk
-
- train_dataset = load_from_disk('./2.7.3.13/dataset/train')
- eval_dataset = load_from_disk('./2.7.3.13/dataset/eval')
-
- from datasets.dataset_dict import DatasetDict
- tokenized_datasets = DatasetDict()
-
- tokenized_datasets["train"] = train_dataset
- tokenized_datasets["validation"] = eval_dataset
-
  block_size = 1024
  def group_texts(examples):
      # Concatenate all texts.
@@ -364,8 +334,10 @@ lm_datasets = tokenized_datasets.map(
  train_dataset = lm_datasets["train"]
  eval_dataset = lm_datasets["validation"]

- train_dataset.save_to_disk('./dataset/train2')
- eval_dataset.save_to_disk('./dataset/eval2')

  ```
  The processed datasets will be inside the folder dataset/, called train2 and eval2.
  You could also put the two previous scripts into a single one and run it in one go (that is what we do).

  as it will learn new properties from your dataset and potentially improve the generation quality
  (especially for poorly populated EC classes).

+ To fine-tune ZymCTRL, you can use the script below to process your sequences. The only requirement is an input file,
+ 'sequences.fasta', which contains all the sequences in FASTA format. Please follow the format below: each sequence must sit on a
+ single line, with no line breaks ('\n') inside a sequence and no separators between records (a helper sketch for flattening
+ line-wrapped FASTA files follows the example). In the script, set the variable ec_label to the BRENDA EC class you'd like to
+ fine-tune on. The script will produce a file called {ec_label}_processed.txt and a folder with the training and validation
+ datasets (90/10 split).
+ ```
+ >Sequence1
+ MMMMYMPLKVCD..
+ >Sequence2
+ MQWMXMYMPLKVCD..
+ >Sequence3
+ MPLKVCWMXMYMPLD..
+ ```
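If your starting FASTA file wraps sequences over several lines (for example at 60 characters per line), a minimal sketch along the following lines can flatten each record onto a single line first. 'raw.fasta' is a hypothetical name for the wrapped input file; the output follows the 'sequences.fasta' format shown above.
```
# Minimal sketch (not part of the original script): flatten a line-wrapped FASTA
# file into one sequence per line. 'raw.fasta' is a placeholder input name.
records = {}
name = None
with open('raw.fasta', 'r') as fn:
    for line in fn:
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            name = line
            records[name] = []
        else:
            records[name].append(line)

with open('sequences.fasta', 'w') as out:
    for name, chunks in records.items():
        out.write(f"{name}\n{''.join(chunks)}\n")
```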
  We recommend using at least 200 sequences to obtain the best results. But we've seen it work with fewer sequences, so if you don't have
  that many, still give it a go.

  ```
  import random
  from transformers import AutoTokenizer

+ from datasets import load_dataset
+ import transformers
+ from transformers.testing_utils import CaptureLogger
+
+ ## DEFINE THESE VARIABLES
+ tokenizer = AutoTokenizer.from_pretrained('AI4PD/ZymCTRL')
+ ec_label = '1.1.1.1' # CHANGE TO YOUR EC LABEL
+ validation_split_percentage = 10 # change if needed
+ sequence_file = 'sequences.fasta'
+
+ # Load sequences: read the source file
+ with open(sequence_file, 'r') as fn:
      data = fn.readlines()
  fn.close()

  for line in data:
      if '>' in line:
          name = line.strip()
+         sequences[name] = []
          continue
      sequences[name].append(line.strip())

+ # Put the sequences into a list and shuffle their order randomly
+ sequences_list = [(key,value[0]) for key,value in sequences.items()]
  random.shuffle(sequences_list)

+ # The objective is to build strings that, when tokenized, span the 1024-token window.
+ # For each sequence, store its length together with the untokenized string.
  print("processing dataset")
  processed_dataset = []
  for i in sequences_list:
      # length of the control code
      sequence = i[1].strip()
      separator = '<sep>'
+     control_code_length = len(tokenizer(ec_label+separator)['input_ids'])
      available_space = 1021 - control_code_length # It is not 1024 because '<|endoftext|>', and start and end

+     # Option 1: the sequence is larger than the available space (3-4% of sequences)
      if len(sequence) > available_space:
          total_length = control_code_length + len(sequence[:available_space]) + 1
+         seq = f"{ec_label}{separator}{sequence[:available_space]}<|endoftext|>"
          processed_dataset.append((total_length, seq))

      # Option 2 & 3: the sequence fits in the block_size space, with or without padding
      else:
          total_length = control_code_length + len(sequence) + 3
          # here the sequence fits, so the start/end tokens are included
+         seq = f"{ec_label}{separator}<start>{sequence}<end><|endoftext|>"
          processed_dataset.append((total_length, seq))

+ # Group sequences
  def grouper(iterable):
      prev = None
      group = ''

      total_sum = 0
      yield group

  print("grouping processed dataset")
  grouped_dataset=dict(enumerate(grouper(processed_dataset),1))

+ # Write the processed file out for the tokenizer to read
+ fn = open(f"{ec_label}_processed.txt",'w')
  for key,value in grouped_dataset.items():
      padding_len = 1024 - len(tokenizer(value)['input_ids'])
      padding = "<pad>"*padding_len
      fn.write(value+padding)
      fn.write("\n")
+ fn.close()

+ ## TOKENIZE
+ # adapted from the trainer file
  data_files = {}
  dataset_args = {}
+
+ data_files["train"] = f"{ec_label}_processed.txt"
  extension = "text"
  tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

+ raw_datasets = load_dataset(extension, data_files=data_files, cache_dir='.', **dataset_args)
+
  raw_datasets["train"] = load_dataset(extension,
      data_files=data_files,
      split=f"train[{validation_split_percentage}%:]",
      **dataset_args,)

  def tokenize_function(examples):
      with CaptureLogger(tok_logger) as cl:
          output = tokenizer(examples["text"])
      # clm input could be much much longer than block_size
          )
      return output

  tokenized_datasets = raw_datasets.map(
      tokenize_function,
      batched=True,
      desc="Running tokenizer on dataset",
  )

  block_size = 1024
  def group_texts(examples):
      # Concatenate all texts.
  train_dataset = lm_datasets["train"]
  eval_dataset = lm_datasets["validation"]

+ train_dataset.save_to_disk('./dataset/train')
+ eval_dataset.save_to_disk('./dataset/eval')
  ```
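For reference, each line that the script writes to {ec_label}_processed.txt packs one or more control-code/sequence records (EC label, '<sep>', '<start>', sequence, '<end>', '<|endoftext|>') and is padded with '<pad>' up to the 1024-token window. The sequences below are placeholders:
```
1.1.1.1<sep><start>MAVKL..<end><|endoftext|>1.1.1.1<sep><start>MTTQR..<end><|endoftext|><pad><pad>..<pad>
```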
  The processed datasets will be inside the folder dataset/, called train and eval (a quick way to sanity-check them is sketched below).
  You can also split the processing and tokenization steps into separate scripts, but the script above runs everything in one go (that is what we do).
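As a quick sanity check (a sketch, not part of the script above), the saved datasets can be reloaded with datasets.load_from_disk and inspected:
```
from datasets import load_from_disk

# Reload the datasets written by save_to_disk above.
train_dataset = load_from_disk('./dataset/train')
eval_dataset = load_from_disk('./dataset/eval')

print(train_dataset)                       # number of examples and column names
print(len(train_dataset[0]['input_ids']))  # each example should span block_size (1024) tokens
```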