lvwerra HF staff committed on
Commit
74c33a2
·
1 Parent(s): 24790fd

Update codeparrot_training.py

Browse files
Files changed (1) hide show
  1. codeparrot_training.py +11 -6
codeparrot_training.py CHANGED
@@ -15,7 +15,7 @@ import wandb
15
 
16
  class ConstantLengthDataset(IterableDataset):
17
 
18
- def __init__(self, tokenizer, dataset, seq_length=1024,
19
  num_of_sequences=1024, chars_per_token=3.6):
20
  self.tokenizer = tokenizer
21
  self.concat_token_id = tokenizer.bos_token_id
@@ -23,6 +23,7 @@ class ConstantLengthDataset(IterableDataset):
23
  self.seq_length = seq_length
24
  self.input_characters = seq_length * chars_per_token * num_of_sequences
25
  self.epoch = 0
 
26
 
27
  def __iter__(self):
28
  iterator = iter(self.dataset)
@@ -36,9 +37,13 @@ class ConstantLengthDataset(IterableDataset):
36
  buffer.append(next(iterator)['content'])
37
  buffer_len += len(buffer[-1])
38
  except StopIteration:
39
- iterator = iter(self.dataset)
40
- self.epoch += 1
41
- logger.info(f"Dataset epoch: {self.epoch}")
 
 
 
 
42
  tokenized_inputs = tokenizer(buffer, truncation=False)['input_ids']
43
  all_token_ids = []
44
  for tokenized_input in tokenized_inputs:
@@ -77,9 +82,9 @@ def create_dataloaders(dataset_name, args):
77
  train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
78
  seed=args.seed)
79
  valid_data = load_dataset(dataset_name+'-valid', split="train", **ds_kwargs)
80
- train_dataset = ConstantLengthDataset(tokenizer, train_data,
81
  seq_length=args.seq_length)
82
- valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
83
  seq_length=args.seq_length)
84
  train_dataloader=DataLoader(train_dataset, batch_size=args.train_batch_size)
85
  eval_dataloader=DataLoader(valid_dataset, batch_size=args.valid_batch_size)
 
15
 
16
  class ConstantLengthDataset(IterableDataset):
17
 
18
+ def __init__(self, tokenizer, dataset, infinite=False, seq_length=1024,
19
  num_of_sequences=1024, chars_per_token=3.6):
20
  self.tokenizer = tokenizer
21
  self.concat_token_id = tokenizer.bos_token_id
 
23
  self.seq_length = seq_length
24
  self.input_characters = seq_length * chars_per_token * num_of_sequences
25
  self.epoch = 0
26
+ self.infinite = infinite
27
 
28
  def __iter__(self):
29
  iterator = iter(self.dataset)
 
37
  buffer.append(next(iterator)['content'])
38
  buffer_len += len(buffer[-1])
39
  except StopIteration:
40
+ if self.infinite:
41
+ iterator = iter(self.dataset)
42
+ self.epoch += 1
43
+ logger.info(f"Dataset epoch: {self.epoch}")
44
+ else:
45
+ more_examples = False
46
+ break
47
  tokenized_inputs = tokenizer(buffer, truncation=False)['input_ids']
48
  all_token_ids = []
49
  for tokenized_input in tokenized_inputs:
 
82
  train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
83
  seed=args.seed)
84
  valid_data = load_dataset(dataset_name+'-valid', split="train", **ds_kwargs)
85
+ train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True,
86
  seq_length=args.seq_length)
87
+ valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False,
88
  seq_length=args.seq_length)
89
  train_dataloader=DataLoader(train_dataset, batch_size=args.train_batch_size)
90
  eval_dataloader=DataLoader(valid_dataset, batch_size=args.valid_batch_size)