plenz committed
Commit: 8cc4283
Parent: 5866762

bugfix: Add attention mask for generation

Files changed (2):
  1. README.md +6 -5
  2. wrapper_functions.py +10 -2
README.md CHANGED
@@ -76,7 +76,7 @@ model_generation = T5ForConditionalGeneration.from_pretrained(modelcard_generati
  del model_generation.encoder # we only need the decoder for generation. Deleting the encoder is optional, but saves memory.
  model = AutoModel.from_pretrained(modelcard, trust_remote_code=True, revision='main')
  tokenizer = AutoTokenizer.from_pretrained(modelcard)
-
+ model_generation.shared = model.shared # share embeddings between encoder and decoder. This mimics the T5 architecture.

  print('get dummy input (2 instances to show batching)')
  graph_1 = [
@@ -100,17 +100,17 @@ how = 'global' # can be 'global' or 'local', depending on whether the local or
  data_1 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_1, text=text_1, how=how)
  data_2 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_2, text=text_2, how=how)
  datas = [data_1, data_2]
- model_inputs = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu')
+ model_inputs, attention_mask = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu', return_attention_mask=True)

  print('compute token encodings')
  outputs = model(**model_inputs)

  print('generate conditional on encoded graph and text')
- outputs = model_generation.generate(encoder_outputs=outputs, max_new_tokens=10)
+ outputs = model_generation.generate(encoder_outputs=outputs, max_new_tokens=10, attention_mask=attention_mask)
  print('generation 1:', tokenizer.decode(outputs[0], skip_special_tokens=True))
  print('generation 2:', tokenizer.decode(outputs[1], skip_special_tokens=False))
  ```
-
+ Note that the embedding to map from the vocabulary to T5's hidden dimension is shared by the encoder and the decoder in T5. To mimic the T5 architecture, we run `model_generation.shared = model.shared` after loading the models. For inference this has no effect, since the embeddings are not updated during inference. However, during training / finetuning, the embeddings can become different for the encoder and decoder if they are not shared.

  ## Contact
  More information can be found in our paper [Graph Language Models](https://arxiv.org/abs/2401.07105) or our [GitHub repository](https://github.com/Heidelberg-NLP/GraphLanguageModels).
@@ -129,4 +129,5 @@ If this model is helpful for your work, please consider citing the paper:
  }
  ```

-
+ ## Acknowledgments
+ Many thanks to Moritz Blum for his help on the generation part.
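A note on the README change above: when a batch is padded to a common length, `generate` needs the encoder-side attention mask so that cross-attention ignores the padded positions, which is what the new `attention_mask=attention_mask` argument supplies. Below is a minimal standalone sketch of the same pattern with plain T5 from `transformers` (illustration only, not the GLM wrapper; the `t5-small` checkpoint and the example sentences are arbitrary choices):

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Illustration only: batch two inputs of different lengths, run the encoder once,
# then hand the padding mask to generate() so cross-attention skips padded positions.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

batch = tokenizer(
    ['translate English to German: Hello.',
     'translate English to German: How are you doing today?'],
    padding=True, return_tensors='pt',
)
encoder_outputs = model.encoder(input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'])

# Same call pattern as in the README diff: encoder outputs plus the encoder attention mask.
generated = model.generate(encoder_outputs=encoder_outputs,
                           attention_mask=batch['attention_mask'],
                           max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```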
wrapper_functions.py CHANGED
@@ -416,13 +416,14 @@ class DataProcessor():
          return data

      @staticmethod
-     def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', **kwargs)->dict:
+     def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', return_attention_mask:bool=False, **kwargs)->dict:
          """
          converts list of data instances to batched inputs for GLM forward call.
          :param datas: list of Data instances
-         :param max_seq_len: maximum sequence length
          :param tokenizer: tokenizer
+         :param max_seq_len: maximum sequence length
          :param device: device
+         :param return_attention_mask: whether to return attention mask. The attention mask is not used by the GLM encoder, but the decoder needs it to mask out padding tokens in cross attention.
          :return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket'
          """
          current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
@@ -451,6 +452,9 @@ class DataProcessor():
          sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
          use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)

+         if return_attention_mask:
+             attention_mask = torch.zeros((len(data_instances), max_seq_len), dtype=torch.bool, device=device)
+
          # fill tensors
          for i, data in enumerate(data_instances):
              instance_len = min(data.input_ids.shape[1], max_seq_len)
@@ -459,6 +463,8 @@ class DataProcessor():
              relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
              sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
              use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]
+             if return_attention_mask:
+                 attention_mask[i, :instance_len] = 1

          model_input = {
              'input_ids': input_ids,
@@ -467,6 +473,8 @@ class DataProcessor():
              'use_additional_bucket': use_additional_bucket,
              **kwargs
          }
+         if return_attention_mask:
+             return model_input, attention_mask
          return model_input

      @staticmethod
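Shape-wise, the GLM-specific tensors filled here (`relative_position`, `sparsity_mask`, `use_additional_bucket`) are indexed per token pair, i.e. `(batch, seq_len, seq_len)`, while the new attention mask is the usual per-token `(batch, seq_len)` padding mask that the decoder expects. A toy sketch of the same fill logic, using hypothetical stand-in tensors instead of the repo's `Data` instances:

```python
import torch

# Hypothetical stand-ins for data.input_ids of two Data instances (lengths 4 and 2).
instance_input_ids = [torch.tensor([[11, 12, 13, 14]]), torch.tensor([[21, 22]])]
pad_id = 0  # assumed pad token id, for illustration only
max_seq_len = max(ids.shape[1] for ids in instance_input_ids)

input_ids = torch.full((len(instance_input_ids), max_seq_len), pad_id, dtype=torch.long)
attention_mask = torch.zeros((len(instance_input_ids), max_seq_len), dtype=torch.bool)

for i, ids in enumerate(instance_input_ids):
    instance_len = min(ids.shape[1], max_seq_len)
    input_ids[i, :instance_len] = ids[0, :instance_len]
    attention_mask[i, :instance_len] = 1  # real tokens -> True, padding stays False

print(input_ids)       # tensor([[11, 12, 13, 14], [21, 22,  0,  0]])
print(attention_mask)  # tensor([[ True,  True,  True,  True], [ True,  True, False, False]])
```

This `(batch, seq_len)` mask is what the README example then passes to `model_generation.generate(..., attention_mask=attention_mask)`.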