bugfix: Add attention mask for generation

- README.md: +6 -5
- wrapper_functions.py: +10 -2
README.md CHANGED

@@ -76,7 +76,7 @@ model_generation = T5ForConditionalGeneration.from_pretrained(modelcard_generati
 del model_generation.encoder # we only need the decoder for generation. Deleting the encoder is optional, but saves memory.
 model = AutoModel.from_pretrained(modelcard, trust_remote_code=True, revision='main')
 tokenizer = AutoTokenizer.from_pretrained(modelcard)
-
+model_generation.shared = model.shared # share embeddings between encoder and decoder. This mimics the T5 architecture.
 
 print('get dummy input (2 instances to show batching)')
 graph_1 = [
@@ -100,17 +100,17 @@ how = 'global' # can be 'global' or 'local', depending on whether the local or
 data_1 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_1, text=text_1, how=how)
 data_2 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_2, text=text_2, how=how)
 datas = [data_1, data_2]
-model_inputs = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu')
+model_inputs, attention_mask = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu', return_attention_mask=True)
 
 print('compute token encodings')
 outputs = model(**model_inputs)
 
 print('generate conditional on encoded graph and text')
-outputs = model_generation.generate(encoder_outputs=outputs, max_new_tokens=10)
+outputs = model_generation.generate(encoder_outputs=outputs, max_new_tokens=10, attention_mask=attention_mask)
 print('generation 1:', tokenizer.decode(outputs[0], skip_special_tokens=True))
 print('generation 2:', tokenizer.decode(outputs[1], skip_special_tokens=False))
 ```
-
+Note that the embedding to map from the vocabulary to T5's hidden dimension is shared by the encoder and the decoder in T5. To mimic the T5 architecture, we run `model_generation.shared = model.shared` after loading the models. For inference this has no effect, since the embeddings are not updated during inference. However, during training / finetuning, the embeddings can become different for the encoder and decoder if they are not shared.
 
 ## Contact
 More information can be found in our paper [Graph Language Models](https://arxiv.org/abs/2401.07105) or our [GitHub repository](https://github.com/Heidelberg-NLP/GraphLanguageModels).
@@ -129,4 +129,5 @@ If this model is helpful for your work, please consider citing the paper:
 }
 ```
 
-
+## Acknowledgments
+Many thanks to Moritz Blum for his help on the generation part.
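The note added to the README describes weight tying: after `model_generation.shared = model.shared`, both models hold one and the same embedding module, so a finetuning step updates a single weight tensor instead of two copies that could drift apart. Below is a minimal, self-contained sketch of that effect using plain `torch.nn` modules; the toy `encoder`/`decoder` objects and sizes are illustrative stand-ins, not part of the repository.

```python
import torch
import torch.nn as nn

# toy stand-ins for `model` (encoder) and `model_generation` (decoder) from the README
encoder = nn.Module()
decoder = nn.Module()
encoder.shared = nn.Embedding(num_embeddings=10, embedding_dim=4)
decoder.shared = nn.Embedding(num_embeddings=10, embedding_dim=4)

# before tying: two independent weight tensors that can drift apart during training
assert encoder.shared.weight is not decoder.shared.weight

# tie the embeddings, analogous to `model_generation.shared = model.shared`
decoder.shared = encoder.shared
assert decoder.shared.weight is encoder.shared.weight  # now a single tensor

# a gradient step through the "encoder" embedding is therefore also seen by the "decoder"
optimizer = torch.optim.SGD(encoder.shared.parameters(), lr=0.1)
loss = encoder.shared(torch.tensor([1, 2, 3])).sum()
loss.backward()
optimizer.step()
assert torch.equal(decoder.shared.weight, encoder.shared.weight)
```

As the note says, this makes no difference at inference time, because the weights are only read; it only matters once gradients update the embedding during training or finetuning.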
wrapper_functions.py CHANGED

@@ -416,13 +416,14 @@ class DataProcessor():
         return data
 
     @staticmethod
-    def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', **kwargs)->dict:
+    def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', return_attention_mask:bool=False, **kwargs)->dict:
         """
         converts list of data instances to batched inputs for GLM forward call.
         :param datas: list of Data instances
-        :param max_seq_len: maximum sequence length
         :param tokenizer: tokenizer
+        :param max_seq_len: maximum sequence length
         :param device: device
+        :param return_attention_mask: whether to return attention mask. The attention mask is not used by the GLM encoder, but the decoder needs it to mask out padding tokens in cross attention.
         :return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket'
         """
         current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
@@ -451,6 +452,9 @@ class DataProcessor():
         sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
         use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
 
+        if return_attention_mask:
+            attention_mask = torch.zeros((len(data_instances), max_seq_len), dtype=torch.bool, device=device)
+
         # fill tensors
         for i, data in enumerate(data_instances):
             instance_len = min(data.input_ids.shape[1], max_seq_len)
@@ -459,6 +463,8 @@ class DataProcessor():
             relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
             sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
             use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]
+            if return_attention_mask:
+                attention_mask[i, :instance_len] = 1
 
         model_input = {
             'input_ids': input_ids,
@@ -467,6 +473,8 @@ class DataProcessor():
             'use_additional_bucket': use_additional_bucket,
             **kwargs
         }
+        if return_attention_mask:
+            return model_input, attention_mask
         return model_input
 
     @staticmethod
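To make the new `return_attention_mask` flag concrete: for a padded batch the returned mask has shape `(batch_size, max_seq_len)`, with ones over real tokens and zeros over padding, which is what the decoder needs to ignore padded positions in cross-attention during generation. The standalone sketch below mirrors the loop added above with plain tensors; the two toy instance lengths are made up for illustration.

```python
import torch

# pretend batch: two instances with 5 and 3 tokens, padded to a common length
instance_lens = [5, 3]
max_seq_len = max(instance_lens)

# same construction as the new code path in to_batch(..., return_attention_mask=True)
attention_mask = torch.zeros((len(instance_lens), max_seq_len), dtype=torch.bool)
for i, instance_len in enumerate(instance_lens):
    attention_mask[i, :instance_len] = 1

print(attention_mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])
# the second row masks out the two padding positions, so cross-attention
# in the decoder does not attend to them during generation
```

In the README example this mask is then passed directly to `model_generation.generate(..., attention_mask=attention_mask)`.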